improvements in dbmemory. Change yaml.load to safe_load 00/7900/4
author tierno <alfonso.tiernosepulveda@telefonica.com>
Mon, 2 Sep 2019 16:04:16 +0000 (16:04 +0000)
committer tierno <alfonso.tiernosepulveda@telefonica.com>
Tue, 17 Sep 2019 09:38:56 +0000 (09:38 +0000)
Change-Id: I577efa64a8c1503a084cb21b49ec7e3665b7b56f
Signed-off-by: tierno <alfonso.tiernosepulveda@telefonica.com>
osm_common/__init__.py
osm_common/dbmemory.py
osm_common/msgkafka.py
osm_common/msglocal.py
osm_common/tests/test_dbmemory.py
osm_common/tests/test_msglocal.py

index 5435050..bf77d1a 100644 (file)
@@ -15,5 +15,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-version = '6.0.2.post0'
+version = '6.0.2.post1'
 date_version = '2019-08-28'
index c2d9a11..c994640 100644 (file)
@@ -17,6 +17,7 @@
 
 import logging
 from osm_common.dbbase import DbException, DbBase
+from osm_common.dbmongo import deep_update
 from http import HTTPStatus
 from uuid import uuid4
 from copy import deepcopy
@@ -44,16 +45,102 @@ class DbMemory(DbBase):
 
     @staticmethod
     def _format_filter(q_filter):
-        return q_filter    # TODO
+        db_filter = {}
+        # split keys with ANYINDEX in this way:
+        # {"A.B.ANYINDEX.C.D.ANYINDEX.E": v }  -> {"A.B.ANYINDEX": {"C.D.ANYINDEX": {"E": v}}}
+        if q_filter:
+            for k, v in q_filter.items():
+                db_v = v
+                kleft, _, kright = k.rpartition(".ANYINDEX.")
+                while kleft:
+                    k = kleft + ".ANYINDEX"
+                    db_v = {kright: db_v}
+                    kleft, _, kright = k.rpartition(".ANYINDEX.")
+                deep_update(db_filter, {k: db_v})
+
+        return db_filter
 
     def _find(self, table, q_filter):
+
+        def recursive_find(key_list, key_next_index, content, operator, target):
+            if key_next_index == len(key_list) or content is None:
+                try:
+                    if operator == "eq":
+                        if isinstance(target, list) and not isinstance(content, list):
+                            return True if content in target else False
+                        return True if content == target else False
+                    elif operator in ("neq", "ne"):
+                        if isinstance(target, list) and not isinstance(content, list):
+                            return True if content not in target else False
+                        return True if content != target else False
+                    if operator == "gt":
+                        return content > target
+                    elif operator == "gte":
+                        return content >= target
+                    elif operator == "lt":
+                        return content < target
+                    elif operator == "lte":
+                        return content <= target
+                    elif operator == "cont":
+                        return content in target
+                    elif operator == "ncont":
+                        return content not in target
+                    else:
+                        raise DbException("Unknown filter operator '{}' in key '{}'".
+                                          format(operator, ".".join(key_list)), http_code=HTTPStatus.BAD_REQUEST)
+                except TypeError:
+                    return False
+
+            elif isinstance(content, dict):
+                return recursive_find(key_list, key_next_index+1, content.get(key_list[key_next_index]), operator,
+                                      target)
+            elif isinstance(content, list):
+                look_for_match = True  # when there is a match return immediately
+                if (target is None and operator not in ("neq", "ne")) or \
+                        (target is not None and operator in ("neq", "ne")):
+                    look_for_match = False  # when there is a match return immediately
+
+                for content_item in content:
+                    if key_list[key_next_index] == "ANYINDEX" and isinstance(v, dict):
+                        for k2, v2 in target.items():
+                            k_new_list = k2.split(".")
+                            new_operator = "eq"
+                            if k_new_list[-1] in ("eq", "ne", "gt", "gte", "lt", "lte", "cont", "ncont", "neq"):
+                                new_operator = k_new_list.pop()
+                            if not recursive_find(k_new_list, 0, content_item, new_operator, v2):
+                                match = False
+                                break
+                        else:
+                            match = True
+
+                    else:
+                        match = recursive_find(key_list, key_next_index, content_item, operator, target)
+                    if match == look_for_match:
+                        return match
+                if key_list[key_next_index].isdecimal() and int(key_list[key_next_index]) < len(content):
+                    match = recursive_find(key_list, key_next_index+1, content[int(key_list[key_next_index])],
+                                           operator, target)
+                    if match == look_for_match:
+                        return match
+                return not look_for_match
+            else:  # content is not dict, nor list neither None, so not found
+                if operator in ("neq", "ne"):
+                    return True if target is None else False
+                else:
+                    return True if target is None else False
+
         for i, row in enumerate(self.db.get(table, ())):
-            match = True
-            if q_filter:
-                for k, v in q_filter.items():
-                    if k not in row or v != row[k]:
-                        match = False
-            if match:
+            q_filter = q_filter or {}
+            for k, v in q_filter.items():
+                k_list = k.split(".")
+                operator = "eq"
+                if k_list[-1] in ("eq", "ne", "gt", "gte", "lt", "lte", "cont", "ncont", "neq"):
+                    operator = k_list.pop()
+                match = recursive_find(k_list, 0, row, operator, v)
+                if not match:
+                    break
+            else:
+                # match
                 yield i, row
 
     def get_list(self, table, q_filter=None):
@@ -144,6 +231,59 @@ class DbMemory(DbBase):
         except Exception as e:  # TODO refine
             raise DbException(str(e))
 
+    def set_one(self, table, q_filter, update_dict, fail_on_empty=True, unset=None, pull=None, push=None):
+        """
+        Modifies an entry at database
+        :param table: collection or table
+        :param q_filter: Filter
+        :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
+        :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
+        it raises a DbException
+        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+                      ignored. If not exist, it is ignored
+        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
+                     if exist in the array is removed. If not exist, it is ignored
+        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
+                     is appended to the end of the array
+        :return: Dict with the number of entries modified. None if no matching is found.
+        """
+        try:
+            with self.lock:
+                for i, db_item in self._find(table, self._format_filter(q_filter)):
+                    break
+                else:
+                    if fail_on_empty:
+                        raise DbException("Not found entry with _id='{}'".format(q_filter), HTTPStatus.NOT_FOUND)
+                    return None
+                for k, v in update_dict.items():
+                    db_nested = db_item
+                    k_list = k.split(".")
+                    k_nested_prev = k_list[0]
+                    for k_nested in k_list[1:]:
+                        if isinstance(db_nested[k_nested_prev], dict):
+                            if k_nested not in db_nested[k_nested_prev]:
+                                db_nested[k_nested_prev][k_nested] = None
+                        elif isinstance(db_nested[k_nested_prev], list) and k_nested.isdigit():
+                            # extend list with Nones if index greater than list
+                            k_nested = int(k_nested)
+                            if k_nested >= len(db_nested[k_nested_prev]):
+                                db_nested[k_nested_prev] += [None] * (k_nested - len(db_nested[k_nested_prev]) + 1)
+                        elif db_nested[k_nested_prev] is None:
+                            db_nested[k_nested_prev] = {k_nested: None}
+                        else:  # number, string, boolean, ... or list but with not integer key
+                            raise DbException("Cannot set '{}' on existing '{}={}'".format(k, k_nested_prev,
+                                                                                           db_nested[k_nested_prev]))
+
+                        db_nested = db_nested[k_nested_prev]
+                        k_nested_prev = k_nested
+
+                    db_nested[k_nested_prev] = v
+                return {"updated": 1}
+        except DbException:
+            raise
+        except Exception as e:  # TODO refine
+            raise DbException(str(e))
+
     def replace(self, table, _id, indata, fail_on_empty=True):
         """
         Replace the content of an entry
@@ -189,6 +329,29 @@ class DbMemory(DbBase):
         except Exception as e:  # TODO refine
             raise DbException(str(e))
 
+    def create_list(self, table, indata_list):
+        """
+        Add a new entry at database
+        :param table: collection or table
+        :param indata_list: list content to be added
+        :return: database ids of the inserted element. Raises a DbException on error
+        """
+        try:
+            _ids = []
+            for indata in indata_list:
+                _id = indata.get("_id")
+                if not _id:
+                    _id = str(uuid4())
+                    indata["_id"] = _id
+                with self.lock:
+                    if table not in self.db:
+                        self.db[table] = []
+                    self.db[table].append(deepcopy(indata))
+                _ids.append(_id)
+            return _ids
+        except Exception as e:  # TODO refine
+            raise DbException(str(e))
+
 
 if __name__ == '__main__':
     # some test code
index bc9147d..1e22c9f 100644 (file)
@@ -136,11 +136,12 @@ class MsgKafka(MsgBase):
 
             async for message in self.consumer:
                 if callback:
-                    callback(message.topic, yaml.load(message.key), yaml.load(message.value), **kwargs)
+                    callback(message.topic, yaml.safe_load(message.key), yaml.safe_load(message.value), **kwargs)
                 elif aiocallback:
-                    await aiocallback(message.topic, yaml.load(message.key), yaml.load(message.value), **kwargs)
+                    await aiocallback(message.topic, yaml.safe_load(message.key), yaml.safe_load(message.value),
+                                      **kwargs)
                 else:
-                    return message.topic, yaml.load(message.key), yaml.load(message.value)
+                    return message.topic, yaml.safe_load(message.key), yaml.safe_load(message.value)
         except KafkaError as e:
             raise MsgException(str(e))
         finally:
index 965cb26..843b376 100644 (file)
@@ -112,7 +112,7 @@ class MsgLocal(MsgBase):
                         self.buffer[single_topic] += self.files_read[single_topic].readline()
                         if not self.buffer[single_topic].endswith("\n"):
                             continue
-                        msg_dict = yaml.load(self.buffer[single_topic])
+                        msg_dict = yaml.safe_load(self.buffer[single_topic])
                         self.buffer[single_topic] = ""
                         assert len(msg_dict) == 1
                         for k, v in msg_dict.items():
index 63f1a6b..e89560b 100644 (file)
@@ -20,6 +20,8 @@
 import http
 import logging
 import pytest
+import unittest
+from unittest.mock import Mock
 
 from unittest.mock import MagicMock
 from osm_common.dbbase import DbException
@@ -45,6 +47,23 @@ def db_memory_with_data(request):
     return db
 
 
+@pytest.fixture(scope="function")
+def db_memory_with_many_data(request):
+    db = DbMemory(lock=False)
+
+    db.create_list("test", [
+        {"_id": 1, "data": {"data2": {"data3": 1}}, "list": [{"a": 1}], "text": "sometext"},
+        {"_id": 2, "data": {"data2": {"data3": 2}}, "list": [{"a": 2}]},
+        {"_id": 3, "data": {"data2": {"data3": 3}}, "list": [{"a": 3}]},
+        {"_id": 4, "data": {"data2": {"data3": 4}}, "list": [{"a": 4}, {"a": 0}]},
+        {"_id": 5, "data": {"data2": {"data3": 5}}, "list": [{"a": 5}]},
+        {"_id": 6, "data": {"data2": {"data3": 6}}, "list": [{"0": {"a": 1}}]},
+        {"_id": 7, "data": {"data2": {"data3": 7}}, "0": {"a": 0}},
+        {"_id": 8, "list": [{"a": 3, "b": 0, "c": [{"a": 3, "b": 1}, {"a": 0, "b": "v"}]}, {"a": 0, "b": 1}]},
+    ])
+    return db
+
+
 def empty_exception_message():
     return 'database exception '
 
@@ -68,14 +87,14 @@ def replace_exception_message(value):
 def test_constructor():
     db = DbMemory()
     assert db.logger == logging.getLogger('db')
-    assert len(db.db) == 0
+    assert db.db == {}
 
 
 def test_constructor_with_logger():
     logger_name = 'db_local'
     db = DbMemory(logger_name=logger_name)
     assert db.logger == logging.getLogger(logger_name)
-    assert len(db.db) == 0
+    assert db.db == {}
 
 
 def test_db_connect():
@@ -84,7 +103,7 @@ def test_db_connect():
     db = DbMemory()
     db.db_connect(config)
     assert db.logger == logging.getLogger(logger_name)
-    assert len(db.db) == 0
+    assert db.db == {}
 
 
 def test_db_disconnect(db_memory):
@@ -152,6 +171,60 @@ def test_get_one(db_memory_with_data, table, db_filter, expected_data):
     assert result in db_memory_with_data.db[table]
 
 
+@pytest.mark.parametrize("db_filter, expected_ids", [
+    ({}, [1, 2, 3, 4, 5, 6, 7, 8]),
+    ({"_id": 1}, [1]),
+    ({"data.data2.data3": 2}, [2]),
+    ({"data.data2.data3.eq": 2}, [2]),
+    ({"data.data2.data3.neq": 2}, [1, 3, 4, 5, 6, 7, 8]),
+    ({"data.data2.data3": [2, 3]}, [2, 3]),
+    ({"data.data2.data3.gt": 4}, [5, 6, 7]),
+    ({"data.data2.data3.gte": 4}, [4, 5, 6, 7]),
+    ({"data.data2.data3.lt": 4}, [1, 2, 3]),
+    ({"data.data2.data3.lte": 4}, [1, 2, 3, 4]),
+    ({"data.data2.data3.lte": 4.5}, [1, 2, 3, 4]),
+    ({"data.data2.data3.gt": "text"}, []),
+    ({"text.eq": "sometext"}, [1]),
+    ({"text.neq": "sometext"}, [2, 3, 4, 5, 6, 7, 8]),
+    ({"text.eq": "somet"}, []),
+    ({"text.gte": "a"}, [1]),
+    ({"text.gte": "somet"}, [1]),
+    ({"text.gte": "sometext"}, [1]),
+    ({"text.lt": "somet"}, []),
+    ({"data.data2.data3": 2, "data.data2.data4": None}, [2]),
+    ({"data.data2.data3": 2, "data.data2.data4": 5}, []),
+    ({"data.data2.data3": 4}, [4]),
+    ({"data.data2.data3": [3, 4, "e"]}, [3, 4]),
+    ({"data.data2.data3": None}, [8]),
+    ({"data.data2": "4"}, []),
+    ({"list.0.a": 1}, [1, 6]),
+    ({"list.ANYINDEX.a": 1}, [1]),
+    ({"list.a": 3, "list.b": 1}, [8]),
+    ({"list.ANYINDEX.a": 3, "list.ANYINDEX.b": 1}, []),
+    ({"list.ANYINDEX.a": 3, "list.ANYINDEX.c.a": 3}, [8]),
+    ({"list.ANYINDEX.a": 3, "list.ANYINDEX.b": 0}, [8]),
+    ({"list.ANYINDEX.a": 3, "list.ANYINDEX.c.ANYINDEX.a": 0, "list.ANYINDEX.c.ANYINDEX.b": "v"}, [8]),
+    ({"list.ANYINDEX.a": 3, "list.ANYINDEX.c.ANYINDEX.a": 0, "list.ANYINDEX.c.ANYINDEX.b": 1}, []),
+    ({"list.c.b": 1}, [8]),
+    ({"list.c.b": None}, [1, 2, 3, 4, 5, 6, 7]),
+    # ({"data.data2.data3": 4}, []),
+    # ({"data.data2.data3": 4}, []),
+])
+def test_get_list(db_memory_with_many_data, db_filter, expected_ids):
+    result = db_memory_with_many_data.get_list("test", db_filter)
+    assert isinstance(result, list)
+    result_ids = [item["_id"] for item in result]
+    assert len(result) == len(expected_ids), "for db_filter={} result={} expected_ids={}".format(db_filter, result,
+                                                                                                 result_ids)
+    assert result_ids == expected_ids
+    for i in range(len(result)):
+        assert result[i] in db_memory_with_many_data.db["test"]
+
+    assert len(db_memory_with_many_data.db) == 1
+    assert "test" in db_memory_with_many_data.db
+    assert len(db_memory_with_many_data.db["test"]) == 8
+
+
 @pytest.mark.parametrize("table, db_filter, expected_data", [
     ("test", {}, {"_id": 1, "data": 1})])
 def test_get_one_with_multiple_results(db_memory_with_data, table, db_filter, expected_data):
@@ -236,8 +309,8 @@ def test_get_one_generic_exception(db_memory_with_data):
 
 @pytest.mark.parametrize("table, db_filter, expected_data", [
     ("test", {}, []),
-    ("test", {"_id": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]), 
-    ("test", {"_id": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}]), 
+    ("test", {"_id": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
+    ("test", {"_id": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}]),
     ("test", {"_id": 1, "data": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
     ("test", {"_id": 2, "data": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}])])
 def test_del_list_with_non_empty_db(db_memory_with_data, table, db_filter, expected_data):
@@ -577,3 +650,59 @@ def test_create_with_exception(db_memory):
         db_memory.create(table, data)
     assert str(excinfo.value) == empty_exception_message()
     assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize("db_content, update_dict, expected, message", [
+    ({"a": {"none": None}}, {"a.b.num": "v"}, {"a": {"none": None, "b": {"num": "v"}}}, "create dict"),
+    ({"a": {"none": None}}, {"a.none.num": "v"}, {"a": {"none": {"num": "v"}}}, "create dict over none"),
+    ({"a": {"b": {"num": 4}}}, {"a.b.num": "v"}, {"a": {"b": {"num": "v"}}}, "replace_number"),
+    ({"a": {"b": {"num": 4}}}, {"a.b.num.c.d": "v"}, None, "create dict over number should fail"),
+    ({"a": {"b": {"num": 4}}}, {"a.b": "v"}, {"a": {"b": "v"}}, "replace dict with a string"),
+    ({"a": {"b": {"num": 4}}}, {"a.b": None}, {"a": {"b": None}}, "replace dict with None"),
+    ({"a": [{"b": {"num": 4}}]}, {"a.b.num": "v"}, None, "create dict over list should fail"),
+    ({"a": [{"b": {"num": 4}}]}, {"a.0.b.num": "v"}, {"a": [{"b": {"num": "v"}}]}, "set list"),
+    ({"a": [{"b": {"num": 4}}]}, {"a.3.b.num": "v"},
+     {"a": [{"b": {"num": 4}}, None, None, {"b": {"num": "v"}}]}, "expand list"),
+    ({"a": [[4]]}, {"a.0.0": "v"}, {"a": [["v"]]}, "set nested list"),
+    ({"a": [[4]]}, {"a.0.2": "v"}, {"a": [[4, None, "v"]]}, "expand nested list"),
+    ({"a": [[4]]}, {"a.2.2": "v"}, {"a": [[4], None, {"2": "v"}]}, "expand list and add number key")])
+def test_set_one(db_memory, db_content, update_dict, expected, message):
+    db_memory._find = Mock(return_value=((0, db_content), ))
+    if expected is None:
+        with pytest.raises(DbException) as excinfo:
+            db_memory.set_one("table", {}, update_dict)
+        assert (excinfo.value.http_code == http.HTTPStatus.NOT_FOUND), message
+    else:
+        db_memory.set_one("table", {}, update_dict)
+        assert (db_content == expected), message
+
+
+class TestDbMemory(unittest.TestCase):
+    # TODO to delete. This is cover with pytest test_set_one.
+    def test_set_one(self):
+        test_set = (
+            # (database content, set-content, expected database content (None=fails), message)
+            ({"a": {"none": None}}, {"a.b.num": "v"}, {"a": {"none": None, "b": {"num": "v"}}}, "create dict"),
+            ({"a": {"none": None}}, {"a.none.num": "v"}, {"a": {"none": {"num": "v"}}}, "create dict over none"),
+            ({"a": {"b": {"num": 4}}}, {"a.b.num": "v"}, {"a": {"b": {"num": "v"}}}, "replace_number"),
+            ({"a": {"b": {"num": 4}}}, {"a.b.num.c.d": "v"}, None, "create dict over number should fail"),
+            ({"a": {"b": {"num": 4}}}, {"a.b": "v"}, {"a": {"b": "v"}}, "replace dict with a string"),
+            ({"a": {"b": {"num": 4}}}, {"a.b": None}, {"a": {"b": None}}, "replace dict with None"),
+
+            ({"a": [{"b": {"num": 4}}]}, {"a.b.num": "v"}, None, "create dict over list should fail"),
+            ({"a": [{"b": {"num": 4}}]}, {"a.0.b.num": "v"}, {"a": [{"b": {"num": "v"}}]}, "set list"),
+            ({"a": [{"b": {"num": 4}}]}, {"a.3.b.num": "v"},
+             {"a": [{"b": {"num": 4}}, None, None, {"b": {"num": "v"}}]}, "expand list"),
+            ({"a": [[4]]}, {"a.0.0": "v"}, {"a": [["v"]]}, "set nested list"),
+            ({"a": [[4]]}, {"a.0.2": "v"}, {"a": [[4, None, "v"]]}, "expand nested list"),
+            ({"a": [[4]]}, {"a.2.2": "v"}, {"a": [[4], None, {"2": "v"}]}, "expand list and add number key"),
+        )
+        db_men = DbMemory()
+        db_men._find = Mock()
+        for db_content, update_dict, expected, message in test_set:
+            db_men._find.return_value = ((0, db_content), )
+            if expected is None:
+                self.assertRaises(DbException, db_men.set_one, "table", {}, update_dict)
+            else:
+                db_men.set_one("table", {}, update_dict)
+                self.assertEqual(db_content, expected, message)
index 5c62639..41f6eb8 100644 (file)
@@ -200,7 +200,7 @@ def test_write(msg_local_config, topic, key, msg):
     assert os.path.exists(file_path)
 
     with open(file_path, 'r') as stream:
-        assert yaml.load(stream) == {key: msg if not isinstance(msg, tuple) else list(msg)}
+        assert yaml.safe_load(stream) == {key: msg if not isinstance(msg, tuple) else list(msg)}
 
 
 @pytest.mark.parametrize("topic, key, msg, times", [
@@ -225,7 +225,7 @@ def test_write_with_multiple_calls(msg_local_config, topic, key, msg, times):
     with open(file_path, 'r') as stream:
         for _ in range(times):
             data = stream.readline()
-            assert yaml.load(data) == {key: msg if not isinstance(msg, tuple) else list(msg)}
+            assert yaml.safe_load(data) == {key: msg if not isinstance(msg, tuple) else list(msg)}
 
 
 def test_write_exception(msg_local_config):
@@ -453,7 +453,7 @@ def test_aiowrite(msg_local_config, event_loop, topic, key, msg):
     assert os.path.exists(file_path)
 
     with open(file_path, 'r') as stream:
-        assert yaml.load(stream) == {key: msg if not isinstance(msg, tuple) else list(msg)}
+        assert yaml.safe_load(stream) == {key: msg if not isinstance(msg, tuple) else list(msg)}
 
 
 @pytest.mark.parametrize("topic, key, msg, times", [
@@ -477,7 +477,7 @@ def test_aiowrite_with_multiple_calls(msg_local_config, event_loop, topic, key,
     with open(file_path, 'r') as stream:
         for _ in range(times):
             data = stream.readline()
-            assert yaml.load(data) == {key: msg if not isinstance(msg, tuple) else list(msg)}
+            assert yaml.safe_load(data) == {key: msg if not isinstance(msg, tuple) else list(msg)}
 
 
 def test_aiowrite_exception(msg_local_config, event_loop):