Make common methods threading safe. pytest enhancements
Change-Id: Iaacf38c9bb9c31fc521cbde48acd0d6a9cb9a56d
Signed-off-by: tierno <alfonso.tiernosepulveda@telefonica.com>
diff --git a/osm_common/__init__.py b/osm_common/__init__.py
index 1c60968..81567f8 100644
--- a/osm_common/__init__.py
+++ b/osm_common/__init__.py
@@ -15,6 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-version = '0.1.11'
+version = '0.1.12'
# TODO add package version filling commit id with 0's; e.g.: '5.0.0.post11+00000000.dirty-1'
-date_version = '2018-10-23'
+date_version = '2018-11-05'
diff --git a/osm_common/common_utils.py b/osm_common/common_utils.py
new file mode 100644
index 0000000..4cb5857
--- /dev/null
+++ b/osm_common/common_utils.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+
+
+class FakeLock:
+ """Implements a fake lock that can be called with the "with" statement or acquire, release methods"""
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ def acquire(self):
+ pass
+
+ def release(self):
+ pass
diff --git a/osm_common/dbbase.py b/osm_common/dbbase.py
index 5fef9ee..81586ec 100644
--- a/osm_common/dbbase.py
+++ b/osm_common/dbbase.py
@@ -21,6 +21,8 @@
from copy import deepcopy
from Crypto.Cipher import AES
from base64 import b64decode, b64encode
+from osm_common.common_utils import FakeLock
+from threading import Lock
__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
@@ -34,14 +36,26 @@
class DbBase(object):
- def __init__(self, logger_name='db'):
+ def __init__(self, logger_name='db', lock=False):
"""
- Constructor od dbBase
+ Constructor of dbBase
:param logger_name: logging name
+ :param lock: Used to protect simultaneous access to the same instance class by several threads:
+ False, None: Do not protect, this object will only be accessed by one thread
+ True: This object needs to be protected by several threads accessing.
+ Lock object. Use thi Lock for the threads access protection
"""
self.logger = logging.getLogger(logger_name)
self.master_password = None
self.secret_key = None
+ if not lock:
+ self.lock = FakeLock()
+ elif lock is True:
+ self.lock = Lock()
+ elif isinstance(lock, Lock):
+ self.lock = lock
+ else:
+ raise ValueError("lock parameter must be a Lock classclass or boolean")
def db_connect(self, config, target_version=None):
"""
diff --git a/osm_common/dbmemory.py b/osm_common/dbmemory.py
index bae68e2..b796e81 100644
--- a/osm_common/dbmemory.py
+++ b/osm_common/dbmemory.py
@@ -26,8 +26,8 @@
class DbMemory(DbBase):
- def __init__(self, logger_name='db'):
- super().__init__(logger_name)
+ def __init__(self, logger_name='db', lock=False):
+ super().__init__(logger_name, lock)
self.db = {}
def db_connect(self, config):
@@ -63,8 +63,9 @@
"""
try:
result = []
- for _, row in self._find(table, self._format_filter(q_filter)):
- result.append(deepcopy(row))
+ with self.lock:
+ for _, row in self._find(table, self._format_filter(q_filter)):
+ result.append(deepcopy(row))
return result
except DbException:
raise
@@ -84,13 +85,14 @@
"""
try:
result = None
- for _, row in self._find(table, self._format_filter(q_filter)):
- if not fail_on_more:
- return deepcopy(row)
- if result:
- raise DbException("Found more than one entry with filter='{}'".format(q_filter),
- HTTPStatus.CONFLICT.value)
- result = row
+ with self.lock:
+ for _, row in self._find(table, self._format_filter(q_filter)):
+ if not fail_on_more:
+ return deepcopy(row)
+ if result:
+ raise DbException("Found more than one entry with filter='{}'".format(q_filter),
+ HTTPStatus.CONFLICT.value)
+ result = row
if not result and fail_on_empty:
raise DbException("Not found entry with filter='{}'".format(q_filter), HTTPStatus.NOT_FOUND)
return deepcopy(result)
@@ -106,8 +108,9 @@
"""
try:
id_list = []
- for i, _ in self._find(table, self._format_filter(q_filter)):
- id_list.append(i)
+ with self.lock:
+ for i, _ in self._find(table, self._format_filter(q_filter)):
+ id_list.append(i)
deleted = len(id_list)
for i in reversed(id_list):
del self.db[table][i]
@@ -127,13 +130,14 @@
:return: Dict with the number of entries deleted
"""
try:
- for i, _ in self._find(table, self._format_filter(q_filter)):
- break
- else:
- if fail_on_empty:
- raise DbException("Not found entry with filter='{}'".format(q_filter), HTTPStatus.NOT_FOUND)
- return None
- del self.db[table][i]
+ with self.lock:
+ for i, _ in self._find(table, self._format_filter(q_filter)):
+ break
+ else:
+ if fail_on_empty:
+ raise DbException("Not found entry with filter='{}'".format(q_filter), HTTPStatus.NOT_FOUND)
+ return None
+ del self.db[table][i]
return {"deleted": 1}
except Exception as e: # TODO refine
raise DbException(str(e))
@@ -149,13 +153,14 @@
:return: Dict with the number of entries replaced
"""
try:
- for i, _ in self._find(table, self._format_filter({"_id": _id})):
- break
- else:
- if fail_on_empty:
- raise DbException("Not found entry with _id='{}'".format(_id), HTTPStatus.NOT_FOUND)
- return None
- self.db[table][i] = deepcopy(indata)
+ with self.lock:
+ for i, _ in self._find(table, self._format_filter({"_id": _id})):
+ break
+ else:
+ if fail_on_empty:
+ raise DbException("Not found entry with _id='{}'".format(_id), HTTPStatus.NOT_FOUND)
+ return None
+ self.db[table][i] = deepcopy(indata)
return {"updated": 1}
except DbException:
raise
@@ -174,9 +179,10 @@
if not id:
id = str(uuid4())
indata["_id"] = id
- if table not in self.db:
- self.db[table] = []
- self.db[table].append(deepcopy(indata))
+ with self.lock:
+ if table not in self.db:
+ self.db[table] = []
+ self.db[table].append(deepcopy(indata))
return id
except Exception as e: # TODO refine
raise DbException(str(e))
diff --git a/osm_common/dbmongo.py b/osm_common/dbmongo.py
index 2e94a5a..8140521 100644
--- a/osm_common/dbmongo.py
+++ b/osm_common/dbmongo.py
@@ -62,8 +62,8 @@
conn_initial_timout = 120
conn_timout = 10
- def __init__(self, logger_name='db'):
- super().__init__(logger_name)
+ def __init__(self, logger_name='db', lock=False):
+ super().__init__(logger_name, lock)
self.client = None
self.db = None
@@ -204,9 +204,10 @@
"""
try:
result = []
- collection = self.db[table]
- db_filter = self._format_filter(q_filter)
- rows = collection.find(db_filter)
+ with self.lock:
+ collection = self.db[table]
+ db_filter = self._format_filter(q_filter)
+ rows = collection.find(db_filter)
for row in rows:
result.append(row)
return result
@@ -228,10 +229,11 @@
"""
try:
db_filter = self._format_filter(q_filter)
- collection = self.db[table]
- if not (fail_on_empty and fail_on_more):
- return collection.find_one(db_filter)
- rows = collection.find(db_filter)
+ with self.lock:
+ collection = self.db[table]
+ if not (fail_on_empty and fail_on_more):
+ return collection.find_one(db_filter)
+ rows = collection.find(db_filter)
if rows.count() == 0:
if fail_on_empty:
raise DbException("Not found any {} with filter='{}'".format(table[:-1], q_filter),
@@ -253,8 +255,9 @@
:return: Dict with the number of entries deleted
"""
try:
- collection = self.db[table]
- rows = collection.delete_many(self._format_filter(q_filter))
+ with self.lock:
+ collection = self.db[table]
+ rows = collection.delete_many(self._format_filter(q_filter))
return {"deleted": rows.deleted_count}
except DbException:
raise
@@ -271,8 +274,9 @@
:return: Dict with the number of entries deleted
"""
try:
- collection = self.db[table]
- rows = collection.delete_one(self._format_filter(q_filter))
+ with self.lock:
+ collection = self.db[table]
+ rows = collection.delete_one(self._format_filter(q_filter))
if rows.deleted_count == 0:
if fail_on_empty:
raise DbException("Not found any {} with filter='{}'".format(table[:-1], q_filter),
@@ -290,8 +294,9 @@
:return: database id of the inserted element. Raises a DbException on error
"""
try:
- collection = self.db[table]
- data = collection.insert_one(indata)
+ with self.lock:
+ collection = self.db[table]
+ data = collection.insert_one(indata)
return data.inserted_id
except Exception as e: # TODO refine
raise DbException(e)
@@ -307,8 +312,9 @@
:return: Dict with the number of entries modified. None if no matching is found.
"""
try:
- collection = self.db[table]
- rows = collection.update_one(self._format_filter(q_filter), {"$set": update_dict})
+ with self.lock:
+ collection = self.db[table]
+ rows = collection.update_one(self._format_filter(q_filter), {"$set": update_dict})
if rows.matched_count == 0:
if fail_on_empty:
raise DbException("Not found any {} with filter='{}'".format(table[:-1], q_filter),
@@ -327,8 +333,9 @@
:return: Dict with the number of entries modified
"""
try:
- collection = self.db[table]
- rows = collection.update_many(self._format_filter(q_filter), {"$set": update_dict})
+ with self.lock:
+ collection = self.db[table]
+ rows = collection.update_many(self._format_filter(q_filter), {"$set": update_dict})
return {"modified": rows.modified_count}
except Exception as e: # TODO refine
raise DbException(e)
@@ -345,8 +352,9 @@
"""
try:
db_filter = {"_id": _id}
- collection = self.db[table]
- rows = collection.replace_one(db_filter, indata)
+ with self.lock:
+ collection = self.db[table]
+ rows = collection.replace_one(db_filter, indata)
if rows.matched_count == 0:
if fail_on_empty:
raise DbException("Not found any {} with _id='{}'".format(table[:-1], _id), HTTPStatus.NOT_FOUND)
diff --git a/osm_common/fsbase.py b/osm_common/fsbase.py
index 6f82cd3..b941c21 100644
--- a/osm_common/fsbase.py
+++ b/osm_common/fsbase.py
@@ -16,7 +16,10 @@
# limitations under the License.
+import logging
from http import HTTPStatus
+from osm_common.common_utils import FakeLock
+from threading import Lock
__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
@@ -28,8 +31,24 @@
class FsBase(object):
- def __init__(self):
- pass
+ def __init__(self, logger_name='fs', lock=False):
+ """
+ Constructor of FsBase
+ :param logger_name: logging name
+ :param lock: Used to protect simultaneous access to the same instance class by several threads:
+ False, None: Do not protect, this object will only be accessed by one thread
+ True: This object needs to be protected by several threads accessing.
+ Lock object. Use thi Lock for the threads access protection
+ """
+ self.logger = logging.getLogger(logger_name)
+ if not lock:
+ self.lock = FakeLock()
+ elif lock is True:
+ self.lock = Lock()
+ elif isinstance(lock, Lock):
+ self.lock = lock
+ else:
+ raise ValueError("lock parameter must be a Lock class or boolean")
def get_params(self):
return {}
diff --git a/osm_common/fslocal.py b/osm_common/fslocal.py
index a027558..61600ec 100644
--- a/osm_common/fslocal.py
+++ b/osm_common/fslocal.py
@@ -27,8 +27,8 @@
class FsLocal(FsBase):
- def __init__(self, logger_name='fs'):
- self.logger = logging.getLogger(logger_name)
+ def __init__(self, logger_name='fs', lock=False):
+ super().__init__(logger_name, lock)
self.path = None
def get_params(self):
diff --git a/osm_common/msgbase.py b/osm_common/msgbase.py
index 0a15dae..92b24e0 100644
--- a/osm_common/msgbase.py
+++ b/osm_common/msgbase.py
@@ -15,7 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from http import HTTPStatus
+from osm_common.common_utils import FakeLock
+from threading import Lock
__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
@@ -40,8 +43,24 @@
Base class for all msgXXXX classes
"""
- def __init__(self):
- pass
+ def __init__(self, logger_name='msg', lock=False):
+ """
+ Constructor of FsBase
+ :param logger_name: logging name
+ :param lock: Used to protect simultaneous access to the same instance class by several threads:
+ False, None: Do not protect, this object will only be accessed by one thread
+ True: This object needs to be protected by several threads accessing.
+ Lock object. Use thi Lock for the threads access protection
+ """
+ self.logger = logging.getLogger(logger_name)
+ if not lock:
+ self.lock = FakeLock()
+ elif lock is True:
+ self.lock = Lock()
+ elif isinstance(lock, Lock):
+ self.lock = lock
+ else:
+ raise ValueError("lock parameter must be a Lock class or boolean")
def connect(self, config):
pass
diff --git a/osm_common/msgkafka.py b/osm_common/msgkafka.py
index 767fff6..b782685 100644
--- a/osm_common/msgkafka.py
+++ b/osm_common/msgkafka.py
@@ -26,8 +26,8 @@
class MsgKafka(MsgBase):
- def __init__(self, logger_name='msg'):
- self.logger = logging.getLogger(logger_name)
+ def __init__(self, logger_name='msg', lock=False):
+ super().__init__(logger_name, lock)
self.host = None
self.port = None
self.consumer = None
diff --git a/osm_common/msglocal.py b/osm_common/msglocal.py
index b0abb89..1e8e089 100644
--- a/osm_common/msglocal.py
+++ b/osm_common/msglocal.py
@@ -35,8 +35,8 @@
class MsgLocal(MsgBase):
- def __init__(self, logger_name='msg'):
- self.logger = logging.getLogger(logger_name)
+ def __init__(self, logger_name='msg', lock=False):
+ super().__init__(logger_name, lock)
self.path = None
# create a different file for each topic
self.files_read = {}
@@ -58,14 +58,16 @@
raise MsgException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
def disconnect(self):
- for f in self.files_read.values():
+ for topic, f in self.files_read.items():
try:
f.close()
+ self.files_read[topic] = None
except Exception: # TODO refine
pass
- for f in self.files_write.values():
+ for topic, f in self.files_write.items():
try:
f.close()
+ self.files_write[topic] = None
except Exception: # TODO refine
pass
@@ -78,10 +80,11 @@
:return: None or raises and exception
"""
try:
- if topic not in self.files_write:
- self.files_write[topic] = open(self.path + topic, "a+")
- yaml.safe_dump({key: msg}, self.files_write[topic], default_flow_style=True, width=20000)
- self.files_write[topic].flush()
+ with self.lock:
+ if topic not in self.files_write:
+ self.files_write[topic] = open(self.path + topic, "a+")
+ yaml.safe_dump({key: msg}, self.files_write[topic], default_flow_style=True, width=20000)
+ self.files_write[topic].flush()
except Exception as e: # TODO refine
raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
@@ -99,17 +102,18 @@
topic_list = (topic, )
while True:
for single_topic in topic_list:
- if single_topic not in self.files_read:
- self.files_read[single_topic] = open(self.path + single_topic, "a+")
+ with self.lock:
+ if single_topic not in self.files_read:
+ self.files_read[single_topic] = open(self.path + single_topic, "a+")
+ self.buffer[single_topic] = ""
+ self.buffer[single_topic] += self.files_read[single_topic].readline()
+ if not self.buffer[single_topic].endswith("\n"):
+ continue
+ msg_dict = yaml.load(self.buffer[single_topic])
self.buffer[single_topic] = ""
- self.buffer[single_topic] += self.files_read[single_topic].readline()
- if not self.buffer[single_topic].endswith("\n"):
- continue
- msg_dict = yaml.load(self.buffer[single_topic])
- self.buffer[single_topic] = ""
- assert len(msg_dict) == 1
- for k, v in msg_dict.items():
- return single_topic, k, v
+ assert len(msg_dict) == 1
+ for k, v in msg_dict.items():
+ return single_topic, k, v
if not blocks:
return None
sleep(2)
diff --git a/osm_common/tests/test_dbmemory.py b/osm_common/tests/test_dbmemory.py
index 3e59e94..37c2c83 100644
--- a/osm_common/tests/test_dbmemory.py
+++ b/osm_common/tests/test_dbmemory.py
@@ -9,15 +9,15 @@
__author__ = 'Eduardo Sousa <eduardosousa@av.it.pt>'
-@pytest.fixture
-def db_memory():
- db = DbMemory()
+@pytest.fixture(scope="function", params=[True, False])
+def db_memory(request):
+ db = DbMemory(lock=request.param)
return db
-@pytest.fixture
-def db_memory_with_data():
- db = DbMemory()
+@pytest.fixture(scope="function", params=[True, False])
+def db_memory_with_data(request):
+ db = DbMemory(lock=request.param)
db.create("test", {"_id": 1, "data": 1})
db.create("test", {"_id": 2, "data": 2})
@@ -30,16 +30,16 @@
return 'database exception '
-def get_one_exception_message(filter):
- return "database exception Not found entry with filter='{}'".format(filter)
+def get_one_exception_message(db_filter):
+ return "database exception Not found entry with filter='{}'".format(db_filter)
-def get_one_multiple_exception_message(filter):
- return "database exception Found more than one entry with filter='{}'".format(filter)
+def get_one_multiple_exception_message(db_filter):
+ return "database exception Found more than one entry with filter='{}'".format(db_filter)
-def del_one_exception_message(filter):
- return "database exception Not found entry with filter='{}'".format(filter)
+def del_one_exception_message(db_filter):
+ return "database exception Not found entry with filter='{}'".format(db_filter)
def replace_exception_message(value):
@@ -72,17 +72,17 @@
db_memory.db_disconnect()
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {}),
("test", {"_id": 1}),
("test", {"data": 1}),
("test", {"_id": 1, "data": 1})])
-def test_get_list_with_empty_db(db_memory, table, filter):
- result = db_memory.get_list(table, filter)
+def test_get_list_with_empty_db(db_memory, table, db_filter):
+ result = db_memory.get_list(table, db_filter)
assert len(result) == 0
-@pytest.mark.parametrize("table, filter, expected_data", [
+@pytest.mark.parametrize("table, db_filter, expected_data", [
("test", {}, [{"_id": 1, "data": 1}, {"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
("test", {"_id": 1}, [{"_id": 1, "data": 1}]),
("test", {"data": 1}, [{"_id": 1, "data": 1}]),
@@ -97,8 +97,8 @@
("test_table", {"_id": 1}, []),
("test_table", {"data": 1}, []),
("test_table", {"_id": 1, "data": 1}, [])])
-def test_get_list_with_non_empty_db(db_memory_with_data, table, filter, expected_data):
- result = db_memory_with_data.get_list(table, filter)
+def test_get_list_with_non_empty_db(db_memory_with_data, table, db_filter, expected_data):
+ result = db_memory_with_data.get_list(table, db_filter)
assert len(result) == len(expected_data)
for data in expected_data:
assert data in result
@@ -106,15 +106,15 @@
def test_get_list_exception(db_memory_with_data):
table = 'test'
- filter = {}
+ db_filter = {}
db_memory_with_data._find = MagicMock(side_effect=Exception())
with pytest.raises(DbException) as excinfo:
- db_memory_with_data.get_list(table, filter)
+ db_memory_with_data.get_list(table, db_filter)
assert str(excinfo.value) == empty_exception_message()
assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-@pytest.mark.parametrize("table, filter, expected_data", [
+@pytest.mark.parametrize("table, db_filter, expected_data", [
("test", {"_id": 1}, {"_id": 1, "data": 1}),
("test", {"_id": 2}, {"_id": 2, "data": 2}),
("test", {"_id": 3}, {"_id": 3, "data": 3}),
@@ -124,8 +124,8 @@
("test", {"_id": 1, "data": 1}, {"_id": 1, "data": 1}),
("test", {"_id": 2, "data": 2}, {"_id": 2, "data": 2}),
("test", {"_id": 3, "data": 3}, {"_id": 3, "data": 3})])
-def test_get_one(db_memory_with_data, table, filter, expected_data):
- result = db_memory_with_data.get_one(table, filter)
+def test_get_one(db_memory_with_data, table, db_filter, expected_data):
+ result = db_memory_with_data.get_one(table, db_filter)
assert result == expected_data
assert len(db_memory_with_data.db) == 1
assert table in db_memory_with_data.db
@@ -133,10 +133,10 @@
assert result in db_memory_with_data.db[table]
-@pytest.mark.parametrize("table, filter, expected_data", [
+@pytest.mark.parametrize("table, db_filter, expected_data", [
("test", {}, {"_id": 1, "data": 1})])
-def test_get_one_with_multiple_results(db_memory_with_data, table, filter, expected_data):
- result = db_memory_with_data.get_one(table, filter, fail_on_more=False)
+def test_get_one_with_multiple_results(db_memory_with_data, table, db_filter, expected_data):
+ result = db_memory_with_data.get_one(table, db_filter, fail_on_more=False)
assert result == expected_data
assert len(db_memory_with_data.db) == 1
assert table in db_memory_with_data.db
@@ -146,83 +146,83 @@
def test_get_one_with_multiple_results_exception(db_memory_with_data):
table = "test"
- filter = {}
+ db_filter = {}
with pytest.raises(DbException) as excinfo:
- db_memory_with_data.get_one(table, filter)
- assert str(excinfo.value) == (empty_exception_message() + get_one_multiple_exception_message(filter))
+ db_memory_with_data.get_one(table, db_filter)
+ assert str(excinfo.value) == (empty_exception_message() + get_one_multiple_exception_message(db_filter))
# assert excinfo.value.http_code == http.HTTPStatus.CONFLICT
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {"_id": 4}),
("test", {"data": 4}),
("test", {"_id": 4, "data": 4}),
("test_table", {"_id": 4}),
("test_table", {"data": 4}),
("test_table", {"_id": 4, "data": 4})])
-def test_get_one_with_non_empty_db_exception(db_memory_with_data, table, filter):
+def test_get_one_with_non_empty_db_exception(db_memory_with_data, table, db_filter):
with pytest.raises(DbException) as excinfo:
- db_memory_with_data.get_one(table, filter)
- assert str(excinfo.value) == (empty_exception_message() + get_one_exception_message(filter))
+ db_memory_with_data.get_one(table, db_filter)
+ assert str(excinfo.value) == (empty_exception_message() + get_one_exception_message(db_filter))
assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {"_id": 4}),
("test", {"data": 4}),
("test", {"_id": 4, "data": 4}),
("test_table", {"_id": 4}),
("test_table", {"data": 4}),
("test_table", {"_id": 4, "data": 4})])
-def test_get_one_with_non_empty_db_none(db_memory_with_data, table, filter):
- result = db_memory_with_data.get_one(table, filter, fail_on_empty=False)
+def test_get_one_with_non_empty_db_none(db_memory_with_data, table, db_filter):
+ result = db_memory_with_data.get_one(table, db_filter, fail_on_empty=False)
assert result is None
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {"_id": 4}),
("test", {"data": 4}),
("test", {"_id": 4, "data": 4}),
("test_table", {"_id": 4}),
("test_table", {"data": 4}),
("test_table", {"_id": 4, "data": 4})])
-def test_get_one_with_empty_db_exception(db_memory, table, filter):
+def test_get_one_with_empty_db_exception(db_memory, table, db_filter):
with pytest.raises(DbException) as excinfo:
- db_memory.get_one(table, filter)
- assert str(excinfo.value) == (empty_exception_message() + get_one_exception_message(filter))
+ db_memory.get_one(table, db_filter)
+ assert str(excinfo.value) == (empty_exception_message() + get_one_exception_message(db_filter))
assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {"_id": 4}),
("test", {"data": 4}),
("test", {"_id": 4, "data": 4}),
("test_table", {"_id": 4}),
("test_table", {"data": 4}),
("test_table", {"_id": 4, "data": 4})])
-def test_get_one_with_empty_db_none(db_memory, table, filter):
- result = db_memory.get_one(table, filter, fail_on_empty=False)
+def test_get_one_with_empty_db_none(db_memory, table, db_filter):
+ result = db_memory.get_one(table, db_filter, fail_on_empty=False)
assert result is None
def test_get_one_generic_exception(db_memory_with_data):
table = 'test'
- filter = {}
+ db_filter = {}
db_memory_with_data._find = MagicMock(side_effect=Exception())
with pytest.raises(DbException) as excinfo:
- db_memory_with_data.get_one(table, filter)
+ db_memory_with_data.get_one(table, db_filter)
assert str(excinfo.value) == empty_exception_message()
assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-@pytest.mark.parametrize("table, filter, expected_data", [
+@pytest.mark.parametrize("table, db_filter, expected_data", [
("test", {}, []),
("test", {"_id": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
("test", {"_id": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}]),
("test", {"_id": 1, "data": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
("test", {"_id": 2, "data": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}])])
-def test_del_list_with_non_empty_db(db_memory_with_data, table, filter, expected_data):
- result = db_memory_with_data.del_list(table, filter)
+def test_del_list_with_non_empty_db(db_memory_with_data, table, db_filter, expected_data):
+ result = db_memory_with_data.del_list(table, db_filter)
assert result["deleted"] == (3 - len(expected_data))
assert len(db_memory_with_data.db) == 1
assert table in db_memory_with_data.db
@@ -231,7 +231,7 @@
assert data in db_memory_with_data.db[table]
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {}),
("test", {"_id": 1}),
("test", {"_id": 2}),
@@ -239,22 +239,22 @@
("test", {"data": 2}),
("test", {"_id": 1, "data": 1}),
("test", {"_id": 2, "data": 2})])
-def test_del_list_with_empty_db(db_memory, table, filter):
- result = db_memory.del_list(table, filter)
+def test_del_list_with_empty_db(db_memory, table, db_filter):
+ result = db_memory.del_list(table, db_filter)
assert result['deleted'] == 0
def test_del_list_generic_exception(db_memory_with_data):
table = 'test'
- filter = {}
+ db_filter = {}
db_memory_with_data._find = MagicMock(side_effect=Exception())
with pytest.raises(DbException) as excinfo:
- db_memory_with_data.del_list(table, filter)
+ db_memory_with_data.del_list(table, db_filter)
assert str(excinfo.value) == empty_exception_message()
assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-@pytest.mark.parametrize("table, filter, data", [
+@pytest.mark.parametrize("table, db_filter, data", [
("test", {}, {"_id": 1, "data": 1}),
("test", {"_id": 1}, {"_id": 1, "data": 1}),
("test", {"data": 1}, {"_id": 1, "data": 1}),
@@ -262,8 +262,8 @@
("test", {"_id": 2}, {"_id": 2, "data": 2}),
("test", {"data": 2}, {"_id": 2, "data": 2}),
("test", {"_id": 2, "data": 2}, {"_id": 2, "data": 2})])
-def test_del_one(db_memory_with_data, table, filter, data):
- result = db_memory_with_data.del_one(table, filter)
+def test_del_one(db_memory_with_data, table, db_filter, data):
+ result = db_memory_with_data.del_one(table, db_filter)
assert result == {"deleted": 1}
assert len(db_memory_with_data.db) == 1
assert table in db_memory_with_data.db
@@ -271,7 +271,7 @@
assert data not in db_memory_with_data.db[table]
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {}),
("test", {"_id": 1}),
("test", {"_id": 2}),
@@ -286,14 +286,14 @@
("test_table", {"data": 2}),
("test_table", {"_id": 1, "data": 1}),
("test_table", {"_id": 2, "data": 2})])
-def test_del_one_with_empty_db_exception(db_memory, table, filter):
+def test_del_one_with_empty_db_exception(db_memory, table, db_filter):
with pytest.raises(DbException) as excinfo:
- db_memory.del_one(table, filter)
- assert str(excinfo.value) == (empty_exception_message() + del_one_exception_message(filter))
+ db_memory.del_one(table, db_filter)
+ assert str(excinfo.value) == (empty_exception_message() + del_one_exception_message(db_filter))
assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {}),
("test", {"_id": 1}),
("test", {"_id": 2}),
@@ -308,12 +308,12 @@
("test_table", {"data": 2}),
("test_table", {"_id": 1, "data": 1}),
("test_table", {"_id": 2, "data": 2})])
-def test_del_one_with_empty_db_none(db_memory, table, filter):
- result = db_memory.del_one(table, filter, fail_on_empty=False)
+def test_del_one_with_empty_db_none(db_memory, table, db_filter):
+ result = db_memory.del_one(table, db_filter, fail_on_empty=False)
assert result is None
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {"_id": 4}),
("test", {"_id": 5}),
("test", {"data": 4}),
@@ -327,14 +327,14 @@
("test_table", {"data": 2}),
("test_table", {"_id": 1, "data": 1}),
("test_table", {"_id": 2, "data": 2})])
-def test_del_one_with_non_empty_db_exception(db_memory_with_data, table, filter):
+def test_del_one_with_non_empty_db_exception(db_memory_with_data, table, db_filter):
with pytest.raises(DbException) as excinfo:
- db_memory_with_data.del_one(table, filter)
- assert str(excinfo.value) == (empty_exception_message() + del_one_exception_message(filter))
+ db_memory_with_data.del_one(table, db_filter)
+ assert str(excinfo.value) == (empty_exception_message() + del_one_exception_message(db_filter))
assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-@pytest.mark.parametrize("table, filter", [
+@pytest.mark.parametrize("table, db_filter", [
("test", {"_id": 4}),
("test", {"_id": 5}),
("test", {"data": 4}),
@@ -348,8 +348,8 @@
("test_table", {"data": 2}),
("test_table", {"_id": 1, "data": 1}),
("test_table", {"_id": 2, "data": 2})])
-def test_del_one_with_non_empty_db_none(db_memory_with_data, table, filter):
- result = db_memory_with_data.del_one(table, filter, fail_on_empty=False)
+def test_del_one_with_non_empty_db_none(db_memory_with_data, table, db_filter):
+ result = db_memory_with_data.del_one(table, db_filter, fail_on_empty=False)
assert result is None
@@ -358,10 +358,10 @@
(False)])
def test_del_one_generic_exception(db_memory_with_data, fail_on_empty):
table = 'test'
- filter = {}
+ db_filter = {}
db_memory_with_data._find = MagicMock(side_effect=Exception())
with pytest.raises(DbException) as excinfo:
- db_memory_with_data.del_one(table, filter, fail_on_empty=fail_on_empty)
+ db_memory_with_data.del_one(table, db_filter, fail_on_empty=fail_on_empty)
assert str(excinfo.value) == empty_exception_message()
assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
diff --git a/osm_common/tests/test_fslocal.py b/osm_common/tests/test_fslocal.py
index 3a2bbb4..86b0491 100644
--- a/osm_common/tests/test_fslocal.py
+++ b/osm_common/tests/test_fslocal.py
@@ -22,9 +22,9 @@
return '/#tweeter/'
-@pytest.fixture
-def fs_local():
- fs = FsLocal()
+@pytest.fixture(scope="function", params=[True, False])
+def fs_local(request):
+ fs = FsLocal(lock=request.param)
fs.fs_connect({'path': valid_path()})
return fs
diff --git a/osm_common/tests/test_msglocal.py b/osm_common/tests/test_msglocal.py
index 93bd54d..f2b63cc 100644
--- a/osm_common/tests/test_msglocal.py
+++ b/osm_common/tests/test_msglocal.py
@@ -24,19 +24,19 @@
return '/#tweeter/'
-@pytest.fixture
-def msg_local():
- msg = MsgLocal()
+@pytest.fixture(scope="function", params=[True, False])
+def msg_local(request):
+ msg = MsgLocal(lock=request.param)
yield msg
+ msg.disconnect()
if msg.path and msg.path != invalid_path() and msg.path != valid_path():
- msg.disconnect()
shutil.rmtree(msg.path)
-@pytest.fixture
-def msg_local_config():
- msg = MsgLocal()
+@pytest.fixture(scope="function", params=[True, False])
+def msg_local_config(request):
+ msg = MsgLocal(lock=request.param)
msg.connect({"path": valid_path() + str(uuid.uuid4())})
yield msg
@@ -45,9 +45,9 @@
shutil.rmtree(msg.path)
-@pytest.fixture
-def msg_local_with_data():
- msg = MsgLocal()
+@pytest.fixture(scope="function", params=[True, False])
+def msg_local_with_data(request):
+ msg = MsgLocal(lock=request.param)
msg.connect({"path": valid_path() + str(uuid.uuid4())})
msg.write("topic1", "key1", "msg1")
@@ -117,41 +117,49 @@
def test_disconnect(msg_local_config):
+ files_read = msg_local_config.files_read.copy()
+ files_write = msg_local_config.files_write.copy()
msg_local_config.disconnect()
- for f in msg_local_config.files_read.values():
+ for f in files_read.values():
assert f.closed
- for f in msg_local_config.files_write.values():
+ for f in files_write.values():
assert f.closed
def test_disconnect_with_read(msg_local_config):
msg_local_config.read('topic1', blocks=False)
msg_local_config.read('topic2', blocks=False)
+ files_read = msg_local_config.files_read.copy()
+ files_write = msg_local_config.files_write.copy()
msg_local_config.disconnect()
- for f in msg_local_config.files_read.values():
+ for f in files_read.values():
assert f.closed
- for f in msg_local_config.files_write.values():
+ for f in files_write.values():
assert f.closed
def test_disconnect_with_write(msg_local_with_data):
+ files_read = msg_local_with_data.files_read.copy()
+ files_write = msg_local_with_data.files_write.copy()
msg_local_with_data.disconnect()
- for f in msg_local_with_data.files_read.values():
+ for f in files_read.values():
assert f.closed
- for f in msg_local_with_data.files_write.values():
+ for f in files_write.values():
assert f.closed
def test_disconnect_with_read_and_write(msg_local_with_data):
msg_local_with_data.read('topic1', blocks=False)
msg_local_with_data.read('topic2', blocks=False)
-
+ files_read = msg_local_with_data.files_read.copy()
+ files_write = msg_local_with_data.files_write.copy()
+
msg_local_with_data.disconnect()
- for f in msg_local_with_data.files_read.values():
+ for f in files_read.values():
assert f.closed
- for f in msg_local_with_data.files_write.values():
+ for f in files_write.values():
assert f.closed