Make common methods threading safe. pytest enhancements
[osm/common.git] / osm_common / msglocal.py
index 8fae7a2..1e8e089 100644 (file)
@@ -1,9 +1,27 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import logging
 import os
 import yaml
 import asyncio
 from osm_common.msgbase import MsgBase, MsgException
 from time import sleep
+from http import HTTPStatus
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
@@ -17,8 +35,8 @@ One text line per message is used in yaml format.
 
 class MsgLocal(MsgBase):
 
-    def __init__(self, logger_name='msg'):
-        self.logger = logging.getLogger(logger_name)
+    def __init__(self, logger_name='msg', lock=False):
+        super().__init__(logger_name, lock)
         self.path = None
         # create a different file for each topic
         self.files_read = {}
@@ -37,17 +55,19 @@ class MsgLocal(MsgBase):
         except MsgException:
             raise
         except Exception as e:  # TODO refine
-            raise MsgException(str(e))
+            raise MsgException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
 
     def disconnect(self):
-        for f in self.files_read.values():
+        for topic, f in self.files_read.items():
             try:
                 f.close()
+                self.files_read[topic] = None
             except Exception:  # TODO refine
                 pass
-        for f in self.files_write.values():
+        for topic, f in self.files_write.items():
             try:
                 f.close()
+                self.files_write[topic] = None
             except Exception:  # TODO refine
                 pass
 
@@ -60,12 +80,13 @@ class MsgLocal(MsgBase):
         :return: None or raises and exception
         """
         try:
-            if topic not in self.files_write:
-                self.files_write[topic] = open(self.path + topic, "a+")
-            yaml.safe_dump({key: msg}, self.files_write[topic], default_flow_style=True, width=20000)
-            self.files_write[topic].flush()
+            with self.lock:
+                if topic not in self.files_write:
+                    self.files_write[topic] = open(self.path + topic, "a+")
+                yaml.safe_dump({key: msg}, self.files_write[topic], default_flow_style=True, width=20000)
+                self.files_write[topic].flush()
         except Exception as e:  # TODO refine
-            raise MsgException(str(e))
+            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
 
     def read(self, topic, blocks=True):
         """
@@ -81,24 +102,25 @@ class MsgLocal(MsgBase):
                 topic_list = (topic, )
             while True:
                 for single_topic in topic_list:
-                    if single_topic not in self.files_read:
-                        self.files_read[single_topic] = open(self.path + single_topic, "a+")
+                    with self.lock:
+                        if single_topic not in self.files_read:
+                            self.files_read[single_topic] = open(self.path + single_topic, "a+")
+                            self.buffer[single_topic] = ""
+                        self.buffer[single_topic] += self.files_read[single_topic].readline()
+                        if not self.buffer[single_topic].endswith("\n"):
+                            continue
+                        msg_dict = yaml.load(self.buffer[single_topic])
                         self.buffer[single_topic] = ""
-                    self.buffer[single_topic] += self.files_read[single_topic].readline()
-                    if not self.buffer[single_topic].endswith("\n"):
-                        continue
-                    msg_dict = yaml.load(self.buffer[single_topic])
-                    self.buffer[single_topic] = ""
-                    assert len(msg_dict) == 1
-                    for k, v in msg_dict.items():
-                        return single_topic, k, v
+                        assert len(msg_dict) == 1
+                        for k, v in msg_dict.items():
+                            return single_topic, k, v
                 if not blocks:
                     return None
                 sleep(2)
         except Exception as e:  # TODO refine
-            raise MsgException(str(e))
+            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
 
-    async def aioread(self, topic, loop):
+    async def aioread(self, topic, loop=None, callback=None, aiocallback=None, **kwargs):
         """
         Asyncio read from one or several topics. It blocks
         :param topic: can be str: single topic; or str list: several topics
@@ -109,9 +131,25 @@ class MsgLocal(MsgBase):
             while True:
                 msg = self.read(topic, blocks=False)
                 if msg:
-                    return msg
+                    if callback:
+                        callback(*msg, **kwargs)
+                    elif aiocallback:
+                        await aiocallback(*msg, **kwargs)
+                    else:
+                        return msg
                 await asyncio.sleep(2, loop=loop)
         except MsgException:
             raise
         except Exception as e:  # TODO refine
-            raise MsgException(str(e))
+            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    async def aiowrite(self, topic, key, msg, loop=None):
+        """
+        Asyncio write. It blocks
+        :param topic: str
+        :param key: str
+        :param msg: message, can be str or yaml
+        :param loop: asyncio loop
+        :return: nothing if ok or raises an exception
+        """
+        return self.write(topic, key, msg)