Added pycrypto dependency. Adding callback to msglocal aioread method
[osm/common.git] / osm_common / msglocal.py
index c774f85..b0abb89 100644 (file)
@@ -1,26 +1,46 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import logging
 import os
 import yaml
 import asyncio
-from msgbase import MsgBase, MsgException
+from osm_common.msgbase import MsgBase, MsgException
 from time import sleep
+from http import HTTPStatus
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
 """
 This emulated kafka bus by just using a shared file system. Useful for testing or devops.
-One file is used per topic. Only one producer and one consumer is allowed per topic. Both consumer and producer 
+One file is used per topic. Only one producer and one consumer is allowed per topic. Both consumer and producer
 access to the same file. e.g. same volume if running with docker.
 One text line per message is used in yaml format.
 """
 
+
 class MsgLocal(MsgBase):
 
     def __init__(self, logger_name='msg'):
         self.logger = logging.getLogger(logger_name)
         self.path = None
         # create a different file for each topic
-        self.files = {}
+        self.files_read = {}
+        self.files_write = {}
         self.buffer = {}
 
     def connect(self, config):
@@ -35,13 +55,18 @@ class MsgLocal(MsgBase):
         except MsgException:
             raise
         except Exception as e:  # TODO refine
-            raise MsgException(str(e))
+            raise MsgException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
 
     def disconnect(self):
-        for f in self.files.values():
+        for f in self.files_read.values():
             try:
                 f.close()
-            except Exception as e:  # TODO refine
+            except Exception:  # TODO refine
+                pass
+        for f in self.files_write.values():
+            try:
+                f.close()
+            except Exception:  # TODO refine
                 pass
 
     def write(self, topic, key, msg):
@@ -53,12 +78,12 @@ class MsgLocal(MsgBase):
         :return: None or raises and exception
         """
         try:
-            if topic not in self.files:
-                self.files[topic] = open(self.path + topic, "a+")
-            yaml.safe_dump({key: msg}, self.files[topic], default_flow_style=True, width=20000)
-            self.files[topic].flush()
+            if topic not in self.files_write:
+                self.files_write[topic] = open(self.path + topic, "a+")
+            yaml.safe_dump({key: msg}, self.files_write[topic], default_flow_style=True, width=20000)
+            self.files_write[topic].flush()
         except Exception as e:  # TODO refine
-            raise MsgException(str(e))
+            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
 
     def read(self, topic, blocks=True):
         """
@@ -74,10 +99,10 @@ class MsgLocal(MsgBase):
                 topic_list = (topic, )
             while True:
                 for single_topic in topic_list:
-                    if single_topic not in self.files:
-                        self.files[single_topic] = open(self.path + single_topic, "a+")
+                    if single_topic not in self.files_read:
+                        self.files_read[single_topic] = open(self.path + single_topic, "a+")
                         self.buffer[single_topic] = ""
-                    self.buffer[single_topic] += self.files[single_topic].readline()
+                    self.buffer[single_topic] += self.files_read[single_topic].readline()
                     if not self.buffer[single_topic].endswith("\n"):
                         continue
                     msg_dict = yaml.load(self.buffer[single_topic])
@@ -89,9 +114,9 @@ class MsgLocal(MsgBase):
                     return None
                 sleep(2)
         except Exception as e:  # TODO refine
-            raise MsgException(str(e))
+            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
 
-    async def aioread(self, topic, loop):
+    async def aioread(self, topic, loop=None, callback=None, aiocallback=None, **kwargs):
         """
         Asyncio read from one or several topics. It blocks
         :param topic: can be str: single topic; or str list: several topics
@@ -102,10 +127,25 @@ class MsgLocal(MsgBase):
             while True:
                 msg = self.read(topic, blocks=False)
                 if msg:
-                    return msg
+                    if callback:
+                        callback(*msg, **kwargs)
+                    elif aiocallback:
+                        await aiocallback(*msg, **kwargs)
+                    else:
+                        return msg
                 await asyncio.sleep(2, loop=loop)
         except MsgException:
             raise
         except Exception as e:  # TODO refine
-            raise MsgException(str(e))
+            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
 
+    async def aiowrite(self, topic, key, msg, loop=None):
+        """
+        Asyncio write. It blocks
+        :param topic: str
+        :param key: str
+        :param msg: message, can be str or yaml
+        :param loop: asyncio loop
+        :return: nothing if ok or raises an exception
+        """
+        return self.write(topic, key, msg)