From 1e9a329ca0085be33665e35d123394905bc46d74 Mon Sep 17 00:00:00 2001 From: tierno Date: Mon, 5 Nov 2018 18:18:45 +0100 Subject: [PATCH] Make common methods threading safe. pytest enhancements Change-Id: Iaacf38c9bb9c31fc521cbde48acd0d6a9cb9a56d Signed-off-by: tierno --- osm_common/__init__.py | 4 +- osm_common/common_utils.py | 34 +++++++ osm_common/dbbase.py | 18 +++- osm_common/dbmemory.py | 66 +++++++------- osm_common/dbmongo.py | 50 ++++++----- osm_common/fsbase.py | 23 ++++- osm_common/fslocal.py | 4 +- osm_common/msgbase.py | 23 ++++- osm_common/msgkafka.py | 4 +- osm_common/msglocal.py | 40 +++++---- osm_common/tests/test_dbmemory.py | 144 +++++++++++++++--------------- osm_common/tests/test_fslocal.py | 6 +- osm_common/tests/test_msglocal.py | 46 ++++++---- 13 files changed, 287 insertions(+), 175 deletions(-) create mode 100644 osm_common/common_utils.py diff --git a/osm_common/__init__.py b/osm_common/__init__.py index 1c60968..81567f8 100644 --- a/osm_common/__init__.py +++ b/osm_common/__init__.py @@ -15,6 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -version = '0.1.11' +version = '0.1.12' # TODO add package version filling commit id with 0's; e.g.: '5.0.0.post11+00000000.dirty-1' -date_version = '2018-10-23' +date_version = '2018-11-05' diff --git a/osm_common/common_utils.py b/osm_common/common_utils.py new file mode 100644 index 0000000..4cb5857 --- /dev/null +++ b/osm_common/common_utils.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +# Copyright 2018 Telefonica S.A. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +__author__ = "Alfonso Tierno " + + +class FakeLock: + """Implements a fake lock that can be called with the "with" statement or acquire, release methods""" + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def acquire(self): + pass + + def release(self): + pass diff --git a/osm_common/dbbase.py b/osm_common/dbbase.py index 5fef9ee..81586ec 100644 --- a/osm_common/dbbase.py +++ b/osm_common/dbbase.py @@ -21,6 +21,8 @@ from http import HTTPStatus from copy import deepcopy from Crypto.Cipher import AES from base64 import b64decode, b64encode +from osm_common.common_utils import FakeLock +from threading import Lock __author__ = "Alfonso Tierno " @@ -34,14 +36,26 @@ class DbException(Exception): class DbBase(object): - def __init__(self, logger_name='db'): + def __init__(self, logger_name='db', lock=False): """ - Constructor od dbBase + Constructor of dbBase :param logger_name: logging name + :param lock: Used to protect simultaneous access to the same instance class by several threads: + False, None: Do not protect, this object will only be accessed by one thread + True: This object needs to be protected by several threads accessing. + Lock object. Use thi Lock for the threads access protection """ self.logger = logging.getLogger(logger_name) self.master_password = None self.secret_key = None + if not lock: + self.lock = FakeLock() + elif lock is True: + self.lock = Lock() + elif isinstance(lock, Lock): + self.lock = lock + else: + raise ValueError("lock parameter must be a Lock classclass or boolean") def db_connect(self, config, target_version=None): """ diff --git a/osm_common/dbmemory.py b/osm_common/dbmemory.py index bae68e2..b796e81 100644 --- a/osm_common/dbmemory.py +++ b/osm_common/dbmemory.py @@ -26,8 +26,8 @@ __author__ = "Alfonso Tierno " class DbMemory(DbBase): - def __init__(self, logger_name='db'): - super().__init__(logger_name) + def __init__(self, logger_name='db', lock=False): + super().__init__(logger_name, lock) self.db = {} def db_connect(self, config): @@ -63,8 +63,9 @@ class DbMemory(DbBase): """ try: result = [] - for _, row in self._find(table, self._format_filter(q_filter)): - result.append(deepcopy(row)) + with self.lock: + for _, row in self._find(table, self._format_filter(q_filter)): + result.append(deepcopy(row)) return result except DbException: raise @@ -84,13 +85,14 @@ class DbMemory(DbBase): """ try: result = None - for _, row in self._find(table, self._format_filter(q_filter)): - if not fail_on_more: - return deepcopy(row) - if result: - raise DbException("Found more than one entry with filter='{}'".format(q_filter), - HTTPStatus.CONFLICT.value) - result = row + with self.lock: + for _, row in self._find(table, self._format_filter(q_filter)): + if not fail_on_more: + return deepcopy(row) + if result: + raise DbException("Found more than one entry with filter='{}'".format(q_filter), + HTTPStatus.CONFLICT.value) + result = row if not result and fail_on_empty: raise DbException("Not found entry with filter='{}'".format(q_filter), HTTPStatus.NOT_FOUND) return deepcopy(result) @@ -106,8 +108,9 @@ class DbMemory(DbBase): """ try: id_list = [] - for i, _ in self._find(table, self._format_filter(q_filter)): - id_list.append(i) + with self.lock: + for i, _ in self._find(table, self._format_filter(q_filter)): + id_list.append(i) deleted = len(id_list) for i in reversed(id_list): del self.db[table][i] @@ -127,13 +130,14 @@ class DbMemory(DbBase): :return: Dict with the number of entries deleted """ try: - for i, _ in self._find(table, self._format_filter(q_filter)): - break - else: - if fail_on_empty: - raise DbException("Not found entry with filter='{}'".format(q_filter), HTTPStatus.NOT_FOUND) - return None - del self.db[table][i] + with self.lock: + for i, _ in self._find(table, self._format_filter(q_filter)): + break + else: + if fail_on_empty: + raise DbException("Not found entry with filter='{}'".format(q_filter), HTTPStatus.NOT_FOUND) + return None + del self.db[table][i] return {"deleted": 1} except Exception as e: # TODO refine raise DbException(str(e)) @@ -149,13 +153,14 @@ class DbMemory(DbBase): :return: Dict with the number of entries replaced """ try: - for i, _ in self._find(table, self._format_filter({"_id": _id})): - break - else: - if fail_on_empty: - raise DbException("Not found entry with _id='{}'".format(_id), HTTPStatus.NOT_FOUND) - return None - self.db[table][i] = deepcopy(indata) + with self.lock: + for i, _ in self._find(table, self._format_filter({"_id": _id})): + break + else: + if fail_on_empty: + raise DbException("Not found entry with _id='{}'".format(_id), HTTPStatus.NOT_FOUND) + return None + self.db[table][i] = deepcopy(indata) return {"updated": 1} except DbException: raise @@ -174,9 +179,10 @@ class DbMemory(DbBase): if not id: id = str(uuid4()) indata["_id"] = id - if table not in self.db: - self.db[table] = [] - self.db[table].append(deepcopy(indata)) + with self.lock: + if table not in self.db: + self.db[table] = [] + self.db[table].append(deepcopy(indata)) return id except Exception as e: # TODO refine raise DbException(str(e)) diff --git a/osm_common/dbmongo.py b/osm_common/dbmongo.py index 2e94a5a..8140521 100644 --- a/osm_common/dbmongo.py +++ b/osm_common/dbmongo.py @@ -62,8 +62,8 @@ class DbMongo(DbBase): conn_initial_timout = 120 conn_timout = 10 - def __init__(self, logger_name='db'): - super().__init__(logger_name) + def __init__(self, logger_name='db', lock=False): + super().__init__(logger_name, lock) self.client = None self.db = None @@ -204,9 +204,10 @@ class DbMongo(DbBase): """ try: result = [] - collection = self.db[table] - db_filter = self._format_filter(q_filter) - rows = collection.find(db_filter) + with self.lock: + collection = self.db[table] + db_filter = self._format_filter(q_filter) + rows = collection.find(db_filter) for row in rows: result.append(row) return result @@ -228,10 +229,11 @@ class DbMongo(DbBase): """ try: db_filter = self._format_filter(q_filter) - collection = self.db[table] - if not (fail_on_empty and fail_on_more): - return collection.find_one(db_filter) - rows = collection.find(db_filter) + with self.lock: + collection = self.db[table] + if not (fail_on_empty and fail_on_more): + return collection.find_one(db_filter) + rows = collection.find(db_filter) if rows.count() == 0: if fail_on_empty: raise DbException("Not found any {} with filter='{}'".format(table[:-1], q_filter), @@ -253,8 +255,9 @@ class DbMongo(DbBase): :return: Dict with the number of entries deleted """ try: - collection = self.db[table] - rows = collection.delete_many(self._format_filter(q_filter)) + with self.lock: + collection = self.db[table] + rows = collection.delete_many(self._format_filter(q_filter)) return {"deleted": rows.deleted_count} except DbException: raise @@ -271,8 +274,9 @@ class DbMongo(DbBase): :return: Dict with the number of entries deleted """ try: - collection = self.db[table] - rows = collection.delete_one(self._format_filter(q_filter)) + with self.lock: + collection = self.db[table] + rows = collection.delete_one(self._format_filter(q_filter)) if rows.deleted_count == 0: if fail_on_empty: raise DbException("Not found any {} with filter='{}'".format(table[:-1], q_filter), @@ -290,8 +294,9 @@ class DbMongo(DbBase): :return: database id of the inserted element. Raises a DbException on error """ try: - collection = self.db[table] - data = collection.insert_one(indata) + with self.lock: + collection = self.db[table] + data = collection.insert_one(indata) return data.inserted_id except Exception as e: # TODO refine raise DbException(e) @@ -307,8 +312,9 @@ class DbMongo(DbBase): :return: Dict with the number of entries modified. None if no matching is found. """ try: - collection = self.db[table] - rows = collection.update_one(self._format_filter(q_filter), {"$set": update_dict}) + with self.lock: + collection = self.db[table] + rows = collection.update_one(self._format_filter(q_filter), {"$set": update_dict}) if rows.matched_count == 0: if fail_on_empty: raise DbException("Not found any {} with filter='{}'".format(table[:-1], q_filter), @@ -327,8 +333,9 @@ class DbMongo(DbBase): :return: Dict with the number of entries modified """ try: - collection = self.db[table] - rows = collection.update_many(self._format_filter(q_filter), {"$set": update_dict}) + with self.lock: + collection = self.db[table] + rows = collection.update_many(self._format_filter(q_filter), {"$set": update_dict}) return {"modified": rows.modified_count} except Exception as e: # TODO refine raise DbException(e) @@ -345,8 +352,9 @@ class DbMongo(DbBase): """ try: db_filter = {"_id": _id} - collection = self.db[table] - rows = collection.replace_one(db_filter, indata) + with self.lock: + collection = self.db[table] + rows = collection.replace_one(db_filter, indata) if rows.matched_count == 0: if fail_on_empty: raise DbException("Not found any {} with _id='{}'".format(table[:-1], _id), HTTPStatus.NOT_FOUND) diff --git a/osm_common/fsbase.py b/osm_common/fsbase.py index 6f82cd3..b941c21 100644 --- a/osm_common/fsbase.py +++ b/osm_common/fsbase.py @@ -16,7 +16,10 @@ # limitations under the License. +import logging from http import HTTPStatus +from osm_common.common_utils import FakeLock +from threading import Lock __author__ = "Alfonso Tierno " @@ -28,8 +31,24 @@ class FsException(Exception): class FsBase(object): - def __init__(self): - pass + def __init__(self, logger_name='fs', lock=False): + """ + Constructor of FsBase + :param logger_name: logging name + :param lock: Used to protect simultaneous access to the same instance class by several threads: + False, None: Do not protect, this object will only be accessed by one thread + True: This object needs to be protected by several threads accessing. + Lock object. Use thi Lock for the threads access protection + """ + self.logger = logging.getLogger(logger_name) + if not lock: + self.lock = FakeLock() + elif lock is True: + self.lock = Lock() + elif isinstance(lock, Lock): + self.lock = lock + else: + raise ValueError("lock parameter must be a Lock class or boolean") def get_params(self): return {} diff --git a/osm_common/fslocal.py b/osm_common/fslocal.py index a027558..61600ec 100644 --- a/osm_common/fslocal.py +++ b/osm_common/fslocal.py @@ -27,8 +27,8 @@ __author__ = "Alfonso Tierno " class FsLocal(FsBase): - def __init__(self, logger_name='fs'): - self.logger = logging.getLogger(logger_name) + def __init__(self, logger_name='fs', lock=False): + super().__init__(logger_name, lock) self.path = None def get_params(self): diff --git a/osm_common/msgbase.py b/osm_common/msgbase.py index 0a15dae..92b24e0 100644 --- a/osm_common/msgbase.py +++ b/osm_common/msgbase.py @@ -15,7 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from http import HTTPStatus +from osm_common.common_utils import FakeLock +from threading import Lock __author__ = "Alfonso Tierno " @@ -40,8 +43,24 @@ class MsgBase(object): Base class for all msgXXXX classes """ - def __init__(self): - pass + def __init__(self, logger_name='msg', lock=False): + """ + Constructor of FsBase + :param logger_name: logging name + :param lock: Used to protect simultaneous access to the same instance class by several threads: + False, None: Do not protect, this object will only be accessed by one thread + True: This object needs to be protected by several threads accessing. + Lock object. Use thi Lock for the threads access protection + """ + self.logger = logging.getLogger(logger_name) + if not lock: + self.lock = FakeLock() + elif lock is True: + self.lock = Lock() + elif isinstance(lock, Lock): + self.lock = lock + else: + raise ValueError("lock parameter must be a Lock class or boolean") def connect(self, config): pass diff --git a/osm_common/msgkafka.py b/osm_common/msgkafka.py index 767fff6..b782685 100644 --- a/osm_common/msgkafka.py +++ b/osm_common/msgkafka.py @@ -26,8 +26,8 @@ __author__ = "Alfonso Tierno , " \ class MsgKafka(MsgBase): - def __init__(self, logger_name='msg'): - self.logger = logging.getLogger(logger_name) + def __init__(self, logger_name='msg', lock=False): + super().__init__(logger_name, lock) self.host = None self.port = None self.consumer = None diff --git a/osm_common/msglocal.py b/osm_common/msglocal.py index b0abb89..1e8e089 100644 --- a/osm_common/msglocal.py +++ b/osm_common/msglocal.py @@ -35,8 +35,8 @@ One text line per message is used in yaml format. class MsgLocal(MsgBase): - def __init__(self, logger_name='msg'): - self.logger = logging.getLogger(logger_name) + def __init__(self, logger_name='msg', lock=False): + super().__init__(logger_name, lock) self.path = None # create a different file for each topic self.files_read = {} @@ -58,14 +58,16 @@ class MsgLocal(MsgBase): raise MsgException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR) def disconnect(self): - for f in self.files_read.values(): + for topic, f in self.files_read.items(): try: f.close() + self.files_read[topic] = None except Exception: # TODO refine pass - for f in self.files_write.values(): + for topic, f in self.files_write.items(): try: f.close() + self.files_write[topic] = None except Exception: # TODO refine pass @@ -78,10 +80,11 @@ class MsgLocal(MsgBase): :return: None or raises and exception """ try: - if topic not in self.files_write: - self.files_write[topic] = open(self.path + topic, "a+") - yaml.safe_dump({key: msg}, self.files_write[topic], default_flow_style=True, width=20000) - self.files_write[topic].flush() + with self.lock: + if topic not in self.files_write: + self.files_write[topic] = open(self.path + topic, "a+") + yaml.safe_dump({key: msg}, self.files_write[topic], default_flow_style=True, width=20000) + self.files_write[topic].flush() except Exception as e: # TODO refine raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR) @@ -99,17 +102,18 @@ class MsgLocal(MsgBase): topic_list = (topic, ) while True: for single_topic in topic_list: - if single_topic not in self.files_read: - self.files_read[single_topic] = open(self.path + single_topic, "a+") + with self.lock: + if single_topic not in self.files_read: + self.files_read[single_topic] = open(self.path + single_topic, "a+") + self.buffer[single_topic] = "" + self.buffer[single_topic] += self.files_read[single_topic].readline() + if not self.buffer[single_topic].endswith("\n"): + continue + msg_dict = yaml.load(self.buffer[single_topic]) self.buffer[single_topic] = "" - self.buffer[single_topic] += self.files_read[single_topic].readline() - if not self.buffer[single_topic].endswith("\n"): - continue - msg_dict = yaml.load(self.buffer[single_topic]) - self.buffer[single_topic] = "" - assert len(msg_dict) == 1 - for k, v in msg_dict.items(): - return single_topic, k, v + assert len(msg_dict) == 1 + for k, v in msg_dict.items(): + return single_topic, k, v if not blocks: return None sleep(2) diff --git a/osm_common/tests/test_dbmemory.py b/osm_common/tests/test_dbmemory.py index 3e59e94..37c2c83 100644 --- a/osm_common/tests/test_dbmemory.py +++ b/osm_common/tests/test_dbmemory.py @@ -9,15 +9,15 @@ from osm_common.dbmemory import DbMemory __author__ = 'Eduardo Sousa ' -@pytest.fixture -def db_memory(): - db = DbMemory() +@pytest.fixture(scope="function", params=[True, False]) +def db_memory(request): + db = DbMemory(lock=request.param) return db -@pytest.fixture -def db_memory_with_data(): - db = DbMemory() +@pytest.fixture(scope="function", params=[True, False]) +def db_memory_with_data(request): + db = DbMemory(lock=request.param) db.create("test", {"_id": 1, "data": 1}) db.create("test", {"_id": 2, "data": 2}) @@ -30,16 +30,16 @@ def empty_exception_message(): return 'database exception ' -def get_one_exception_message(filter): - return "database exception Not found entry with filter='{}'".format(filter) +def get_one_exception_message(db_filter): + return "database exception Not found entry with filter='{}'".format(db_filter) -def get_one_multiple_exception_message(filter): - return "database exception Found more than one entry with filter='{}'".format(filter) +def get_one_multiple_exception_message(db_filter): + return "database exception Found more than one entry with filter='{}'".format(db_filter) -def del_one_exception_message(filter): - return "database exception Not found entry with filter='{}'".format(filter) +def del_one_exception_message(db_filter): + return "database exception Not found entry with filter='{}'".format(db_filter) def replace_exception_message(value): @@ -72,17 +72,17 @@ def test_db_disconnect(db_memory): db_memory.db_disconnect() -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {}), ("test", {"_id": 1}), ("test", {"data": 1}), ("test", {"_id": 1, "data": 1})]) -def test_get_list_with_empty_db(db_memory, table, filter): - result = db_memory.get_list(table, filter) +def test_get_list_with_empty_db(db_memory, table, db_filter): + result = db_memory.get_list(table, db_filter) assert len(result) == 0 -@pytest.mark.parametrize("table, filter, expected_data", [ +@pytest.mark.parametrize("table, db_filter, expected_data", [ ("test", {}, [{"_id": 1, "data": 1}, {"_id": 2, "data": 2}, {"_id": 3, "data": 3}]), ("test", {"_id": 1}, [{"_id": 1, "data": 1}]), ("test", {"data": 1}, [{"_id": 1, "data": 1}]), @@ -97,8 +97,8 @@ def test_get_list_with_empty_db(db_memory, table, filter): ("test_table", {"_id": 1}, []), ("test_table", {"data": 1}, []), ("test_table", {"_id": 1, "data": 1}, [])]) -def test_get_list_with_non_empty_db(db_memory_with_data, table, filter, expected_data): - result = db_memory_with_data.get_list(table, filter) +def test_get_list_with_non_empty_db(db_memory_with_data, table, db_filter, expected_data): + result = db_memory_with_data.get_list(table, db_filter) assert len(result) == len(expected_data) for data in expected_data: assert data in result @@ -106,15 +106,15 @@ def test_get_list_with_non_empty_db(db_memory_with_data, table, filter, expected def test_get_list_exception(db_memory_with_data): table = 'test' - filter = {} + db_filter = {} db_memory_with_data._find = MagicMock(side_effect=Exception()) with pytest.raises(DbException) as excinfo: - db_memory_with_data.get_list(table, filter) + db_memory_with_data.get_list(table, db_filter) assert str(excinfo.value) == empty_exception_message() assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter, expected_data", [ +@pytest.mark.parametrize("table, db_filter, expected_data", [ ("test", {"_id": 1}, {"_id": 1, "data": 1}), ("test", {"_id": 2}, {"_id": 2, "data": 2}), ("test", {"_id": 3}, {"_id": 3, "data": 3}), @@ -124,8 +124,8 @@ def test_get_list_exception(db_memory_with_data): ("test", {"_id": 1, "data": 1}, {"_id": 1, "data": 1}), ("test", {"_id": 2, "data": 2}, {"_id": 2, "data": 2}), ("test", {"_id": 3, "data": 3}, {"_id": 3, "data": 3})]) -def test_get_one(db_memory_with_data, table, filter, expected_data): - result = db_memory_with_data.get_one(table, filter) +def test_get_one(db_memory_with_data, table, db_filter, expected_data): + result = db_memory_with_data.get_one(table, db_filter) assert result == expected_data assert len(db_memory_with_data.db) == 1 assert table in db_memory_with_data.db @@ -133,10 +133,10 @@ def test_get_one(db_memory_with_data, table, filter, expected_data): assert result in db_memory_with_data.db[table] -@pytest.mark.parametrize("table, filter, expected_data", [ +@pytest.mark.parametrize("table, db_filter, expected_data", [ ("test", {}, {"_id": 1, "data": 1})]) -def test_get_one_with_multiple_results(db_memory_with_data, table, filter, expected_data): - result = db_memory_with_data.get_one(table, filter, fail_on_more=False) +def test_get_one_with_multiple_results(db_memory_with_data, table, db_filter, expected_data): + result = db_memory_with_data.get_one(table, db_filter, fail_on_more=False) assert result == expected_data assert len(db_memory_with_data.db) == 1 assert table in db_memory_with_data.db @@ -146,83 +146,83 @@ def test_get_one_with_multiple_results(db_memory_with_data, table, filter, expec def test_get_one_with_multiple_results_exception(db_memory_with_data): table = "test" - filter = {} + db_filter = {} with pytest.raises(DbException) as excinfo: - db_memory_with_data.get_one(table, filter) - assert str(excinfo.value) == (empty_exception_message() + get_one_multiple_exception_message(filter)) + db_memory_with_data.get_one(table, db_filter) + assert str(excinfo.value) == (empty_exception_message() + get_one_multiple_exception_message(db_filter)) # assert excinfo.value.http_code == http.HTTPStatus.CONFLICT -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {"_id": 4}), ("test", {"data": 4}), ("test", {"_id": 4, "data": 4}), ("test_table", {"_id": 4}), ("test_table", {"data": 4}), ("test_table", {"_id": 4, "data": 4})]) -def test_get_one_with_non_empty_db_exception(db_memory_with_data, table, filter): +def test_get_one_with_non_empty_db_exception(db_memory_with_data, table, db_filter): with pytest.raises(DbException) as excinfo: - db_memory_with_data.get_one(table, filter) - assert str(excinfo.value) == (empty_exception_message() + get_one_exception_message(filter)) + db_memory_with_data.get_one(table, db_filter) + assert str(excinfo.value) == (empty_exception_message() + get_one_exception_message(db_filter)) assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {"_id": 4}), ("test", {"data": 4}), ("test", {"_id": 4, "data": 4}), ("test_table", {"_id": 4}), ("test_table", {"data": 4}), ("test_table", {"_id": 4, "data": 4})]) -def test_get_one_with_non_empty_db_none(db_memory_with_data, table, filter): - result = db_memory_with_data.get_one(table, filter, fail_on_empty=False) +def test_get_one_with_non_empty_db_none(db_memory_with_data, table, db_filter): + result = db_memory_with_data.get_one(table, db_filter, fail_on_empty=False) assert result is None -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {"_id": 4}), ("test", {"data": 4}), ("test", {"_id": 4, "data": 4}), ("test_table", {"_id": 4}), ("test_table", {"data": 4}), ("test_table", {"_id": 4, "data": 4})]) -def test_get_one_with_empty_db_exception(db_memory, table, filter): +def test_get_one_with_empty_db_exception(db_memory, table, db_filter): with pytest.raises(DbException) as excinfo: - db_memory.get_one(table, filter) - assert str(excinfo.value) == (empty_exception_message() + get_one_exception_message(filter)) + db_memory.get_one(table, db_filter) + assert str(excinfo.value) == (empty_exception_message() + get_one_exception_message(db_filter)) assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {"_id": 4}), ("test", {"data": 4}), ("test", {"_id": 4, "data": 4}), ("test_table", {"_id": 4}), ("test_table", {"data": 4}), ("test_table", {"_id": 4, "data": 4})]) -def test_get_one_with_empty_db_none(db_memory, table, filter): - result = db_memory.get_one(table, filter, fail_on_empty=False) +def test_get_one_with_empty_db_none(db_memory, table, db_filter): + result = db_memory.get_one(table, db_filter, fail_on_empty=False) assert result is None def test_get_one_generic_exception(db_memory_with_data): table = 'test' - filter = {} + db_filter = {} db_memory_with_data._find = MagicMock(side_effect=Exception()) with pytest.raises(DbException) as excinfo: - db_memory_with_data.get_one(table, filter) + db_memory_with_data.get_one(table, db_filter) assert str(excinfo.value) == empty_exception_message() assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter, expected_data", [ +@pytest.mark.parametrize("table, db_filter, expected_data", [ ("test", {}, []), ("test", {"_id": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]), ("test", {"_id": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}]), ("test", {"_id": 1, "data": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]), ("test", {"_id": 2, "data": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}])]) -def test_del_list_with_non_empty_db(db_memory_with_data, table, filter, expected_data): - result = db_memory_with_data.del_list(table, filter) +def test_del_list_with_non_empty_db(db_memory_with_data, table, db_filter, expected_data): + result = db_memory_with_data.del_list(table, db_filter) assert result["deleted"] == (3 - len(expected_data)) assert len(db_memory_with_data.db) == 1 assert table in db_memory_with_data.db @@ -231,7 +231,7 @@ def test_del_list_with_non_empty_db(db_memory_with_data, table, filter, expected assert data in db_memory_with_data.db[table] -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {}), ("test", {"_id": 1}), ("test", {"_id": 2}), @@ -239,22 +239,22 @@ def test_del_list_with_non_empty_db(db_memory_with_data, table, filter, expected ("test", {"data": 2}), ("test", {"_id": 1, "data": 1}), ("test", {"_id": 2, "data": 2})]) -def test_del_list_with_empty_db(db_memory, table, filter): - result = db_memory.del_list(table, filter) +def test_del_list_with_empty_db(db_memory, table, db_filter): + result = db_memory.del_list(table, db_filter) assert result['deleted'] == 0 def test_del_list_generic_exception(db_memory_with_data): table = 'test' - filter = {} + db_filter = {} db_memory_with_data._find = MagicMock(side_effect=Exception()) with pytest.raises(DbException) as excinfo: - db_memory_with_data.del_list(table, filter) + db_memory_with_data.del_list(table, db_filter) assert str(excinfo.value) == empty_exception_message() assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter, data", [ +@pytest.mark.parametrize("table, db_filter, data", [ ("test", {}, {"_id": 1, "data": 1}), ("test", {"_id": 1}, {"_id": 1, "data": 1}), ("test", {"data": 1}, {"_id": 1, "data": 1}), @@ -262,8 +262,8 @@ def test_del_list_generic_exception(db_memory_with_data): ("test", {"_id": 2}, {"_id": 2, "data": 2}), ("test", {"data": 2}, {"_id": 2, "data": 2}), ("test", {"_id": 2, "data": 2}, {"_id": 2, "data": 2})]) -def test_del_one(db_memory_with_data, table, filter, data): - result = db_memory_with_data.del_one(table, filter) +def test_del_one(db_memory_with_data, table, db_filter, data): + result = db_memory_with_data.del_one(table, db_filter) assert result == {"deleted": 1} assert len(db_memory_with_data.db) == 1 assert table in db_memory_with_data.db @@ -271,7 +271,7 @@ def test_del_one(db_memory_with_data, table, filter, data): assert data not in db_memory_with_data.db[table] -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {}), ("test", {"_id": 1}), ("test", {"_id": 2}), @@ -286,14 +286,14 @@ def test_del_one(db_memory_with_data, table, filter, data): ("test_table", {"data": 2}), ("test_table", {"_id": 1, "data": 1}), ("test_table", {"_id": 2, "data": 2})]) -def test_del_one_with_empty_db_exception(db_memory, table, filter): +def test_del_one_with_empty_db_exception(db_memory, table, db_filter): with pytest.raises(DbException) as excinfo: - db_memory.del_one(table, filter) - assert str(excinfo.value) == (empty_exception_message() + del_one_exception_message(filter)) + db_memory.del_one(table, db_filter) + assert str(excinfo.value) == (empty_exception_message() + del_one_exception_message(db_filter)) assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {}), ("test", {"_id": 1}), ("test", {"_id": 2}), @@ -308,12 +308,12 @@ def test_del_one_with_empty_db_exception(db_memory, table, filter): ("test_table", {"data": 2}), ("test_table", {"_id": 1, "data": 1}), ("test_table", {"_id": 2, "data": 2})]) -def test_del_one_with_empty_db_none(db_memory, table, filter): - result = db_memory.del_one(table, filter, fail_on_empty=False) +def test_del_one_with_empty_db_none(db_memory, table, db_filter): + result = db_memory.del_one(table, db_filter, fail_on_empty=False) assert result is None -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {"_id": 4}), ("test", {"_id": 5}), ("test", {"data": 4}), @@ -327,14 +327,14 @@ def test_del_one_with_empty_db_none(db_memory, table, filter): ("test_table", {"data": 2}), ("test_table", {"_id": 1, "data": 1}), ("test_table", {"_id": 2, "data": 2})]) -def test_del_one_with_non_empty_db_exception(db_memory_with_data, table, filter): +def test_del_one_with_non_empty_db_exception(db_memory_with_data, table, db_filter): with pytest.raises(DbException) as excinfo: - db_memory_with_data.del_one(table, filter) - assert str(excinfo.value) == (empty_exception_message() + del_one_exception_message(filter)) + db_memory_with_data.del_one(table, db_filter) + assert str(excinfo.value) == (empty_exception_message() + del_one_exception_message(db_filter)) assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter", [ +@pytest.mark.parametrize("table, db_filter", [ ("test", {"_id": 4}), ("test", {"_id": 5}), ("test", {"data": 4}), @@ -348,8 +348,8 @@ def test_del_one_with_non_empty_db_exception(db_memory_with_data, table, filter) ("test_table", {"data": 2}), ("test_table", {"_id": 1, "data": 1}), ("test_table", {"_id": 2, "data": 2})]) -def test_del_one_with_non_empty_db_none(db_memory_with_data, table, filter): - result = db_memory_with_data.del_one(table, filter, fail_on_empty=False) +def test_del_one_with_non_empty_db_none(db_memory_with_data, table, db_filter): + result = db_memory_with_data.del_one(table, db_filter, fail_on_empty=False) assert result is None @@ -358,10 +358,10 @@ def test_del_one_with_non_empty_db_none(db_memory_with_data, table, filter): (False)]) def test_del_one_generic_exception(db_memory_with_data, fail_on_empty): table = 'test' - filter = {} + db_filter = {} db_memory_with_data._find = MagicMock(side_effect=Exception()) with pytest.raises(DbException) as excinfo: - db_memory_with_data.del_one(table, filter, fail_on_empty=fail_on_empty) + db_memory_with_data.del_one(table, db_filter, fail_on_empty=fail_on_empty) assert str(excinfo.value) == empty_exception_message() assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND diff --git a/osm_common/tests/test_fslocal.py b/osm_common/tests/test_fslocal.py index 3a2bbb4..86b0491 100644 --- a/osm_common/tests/test_fslocal.py +++ b/osm_common/tests/test_fslocal.py @@ -22,9 +22,9 @@ def invalid_path(): return '/#tweeter/' -@pytest.fixture -def fs_local(): - fs = FsLocal() +@pytest.fixture(scope="function", params=[True, False]) +def fs_local(request): + fs = FsLocal(lock=request.param) fs.fs_connect({'path': valid_path()}) return fs diff --git a/osm_common/tests/test_msglocal.py b/osm_common/tests/test_msglocal.py index 93bd54d..f2b63cc 100644 --- a/osm_common/tests/test_msglocal.py +++ b/osm_common/tests/test_msglocal.py @@ -24,19 +24,19 @@ def invalid_path(): return '/#tweeter/' -@pytest.fixture -def msg_local(): - msg = MsgLocal() +@pytest.fixture(scope="function", params=[True, False]) +def msg_local(request): + msg = MsgLocal(lock=request.param) yield msg + msg.disconnect() if msg.path and msg.path != invalid_path() and msg.path != valid_path(): - msg.disconnect() shutil.rmtree(msg.path) -@pytest.fixture -def msg_local_config(): - msg = MsgLocal() +@pytest.fixture(scope="function", params=[True, False]) +def msg_local_config(request): + msg = MsgLocal(lock=request.param) msg.connect({"path": valid_path() + str(uuid.uuid4())}) yield msg @@ -45,9 +45,9 @@ def msg_local_config(): shutil.rmtree(msg.path) -@pytest.fixture -def msg_local_with_data(): - msg = MsgLocal() +@pytest.fixture(scope="function", params=[True, False]) +def msg_local_with_data(request): + msg = MsgLocal(lock=request.param) msg.connect({"path": valid_path() + str(uuid.uuid4())}) msg.write("topic1", "key1", "msg1") @@ -117,41 +117,49 @@ def test_connect_with_exception(msg_local, config): def test_disconnect(msg_local_config): + files_read = msg_local_config.files_read.copy() + files_write = msg_local_config.files_write.copy() msg_local_config.disconnect() - for f in msg_local_config.files_read.values(): + for f in files_read.values(): assert f.closed - for f in msg_local_config.files_write.values(): + for f in files_write.values(): assert f.closed def test_disconnect_with_read(msg_local_config): msg_local_config.read('topic1', blocks=False) msg_local_config.read('topic2', blocks=False) + files_read = msg_local_config.files_read.copy() + files_write = msg_local_config.files_write.copy() msg_local_config.disconnect() - for f in msg_local_config.files_read.values(): + for f in files_read.values(): assert f.closed - for f in msg_local_config.files_write.values(): + for f in files_write.values(): assert f.closed def test_disconnect_with_write(msg_local_with_data): + files_read = msg_local_with_data.files_read.copy() + files_write = msg_local_with_data.files_write.copy() msg_local_with_data.disconnect() - for f in msg_local_with_data.files_read.values(): + for f in files_read.values(): assert f.closed - for f in msg_local_with_data.files_write.values(): + for f in files_write.values(): assert f.closed def test_disconnect_with_read_and_write(msg_local_with_data): msg_local_with_data.read('topic1', blocks=False) msg_local_with_data.read('topic2', blocks=False) - + files_read = msg_local_with_data.files_read.copy() + files_write = msg_local_with_data.files_write.copy() + msg_local_with_data.disconnect() - for f in msg_local_with_data.files_read.values(): + for f in files_read.values(): assert f.closed - for f in msg_local_with_data.files_write.values(): + for f in files_write.values(): assert f.closed -- 2.17.1