From 136f29577fd83028369c2c4fc4c60f738e0d26d3 Mon Sep 17 00:00:00 2001 From: tierno Date: Fri, 19 Oct 2018 13:01:03 +0200 Subject: [PATCH] Bug 559 adding encrypt/decrypt methods Fixing pytest and unittest. Adding to devops stage test Change-Id: Idbeaa82dec736c4a8b2d2a26bd39aeecbc49b901 Signed-off-by: tierno --- devops-stages/stage-test.sh | 3 +- osm_common/__init__.py | 6 +- osm_common/dbbase.py | 76 +++++++++++++++++----- osm_common/dbmemory.py | 2 + osm_common/dbmongo.py | 18 +++++- osm_common/msgbase.py | 9 +-- osm_common/msglocal.py | 9 +-- osm_common/tests/test_dbbase.py | 60 +++++++++++++++++- osm_common/tests/test_dbmemory.py | 101 +++++++++++++----------------- tox.ini | 17 +++-- 10 files changed, 210 insertions(+), 91 deletions(-) diff --git a/devops-stages/stage-test.sh b/devops-stages/stage-test.sh index b866c8e..5dcbd36 100755 --- a/devops-stages/stage-test.sh +++ b/devops-stages/stage-test.sh @@ -14,4 +14,5 @@ # limitations under the License. tox -e flake8 - +tox -e unittest +tox -e pytest diff --git a/osm_common/__init__.py b/osm_common/__init__.py index f288cb6..eb221c0 100644 --- a/osm_common/__init__.py +++ b/osm_common/__init__.py @@ -15,5 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -version = '0.1.9' -date_version = '2018-10-09' +version = '0.1.10' +# TODO add package version filling commit id with 0's; e.g.: '5.0.0.post11+00000000.dirty-1' +date_version = '2018-10-22' + diff --git a/osm_common/dbbase.py b/osm_common/dbbase.py index cdb1644..3959383 100644 --- a/osm_common/dbbase.py +++ b/osm_common/dbbase.py @@ -19,6 +19,8 @@ import yaml import logging from http import HTTPStatus from copy import deepcopy +from Crypto.Cipher import AES +from base64 import b64decode, b64encode __author__ = "Alfonso Tierno " @@ -40,14 +42,16 @@ class DbBase(object): """ self.logger = logging.getLogger(logger_name) self.master_password = master_password + self.secret_key = None - def db_connect(self, config): + def db_connect(self, config, target_version=None): """ Connect to database :param config: Configuration of database + :param target_version: if provided it checks if database contains required version, raising exception otherwise. :return: None or raises DbException on error """ - pass + raise DbException("Method 'db_connect' not implemented") def db_disconnect(self): """ @@ -141,27 +145,71 @@ class DbBase(object): """ raise DbException("Method 'replace' not implemented") - def encrypt(self, value, salt=None): + @staticmethod + def _join_passwords(passwd_byte, passwd_str): + """ + Modifies passwd_byte with the xor of passwd_str. Used for adding salt, join passwords, etc + :param passwd_byte: original password in bytes, 32 byte length + :param passwd_str: string salt to be added + :return: modified password in bytes + """ + if not passwd_str: + return passwd_byte + secret_key = bytearray(passwd_byte) + for i, b in enumerate(passwd_str.encode()): + secret_key[i % 32] ^= b + return bytes(secret_key) + + def set_secret_key(self, secret_key): + """ + Set internal secret key used for encryption + :param secret_key: byte array length 32 with the secret_key + :return: None + """ + assert (len(secret_key) == 32) + self.secret_key = self._join_passwords(secret_key, self.master_password) + + def encrypt(self, value, schema_version=None, salt=None): """ Encrypt a value - :param value: value to be encrypted - :param salt: optional salt to be used + :param value: value to be encrypted. It is string/unicode + :param schema_version: used for version control. If None or '1.0' no encryption is done. + If '1.1' symmetric AES encryption is done + :param salt: optional salt to be used. Must be str :return: Encrypted content of value """ - # for the moment return same value. until all modules call this method - return value - # raise DbException("Method 'encrypt' not implemented") - - def decrypt(self, value, salt=None): + if not schema_version or schema_version == '1.0': + return value + else: + if not self.secret_key: + raise DbException("Cannot encrypt. Missing secret_key", http_code=HTTPStatus.INTERNAL_SERVER_ERROR) + secret_key = self._join_passwords(self.secret_key, salt) + cipher = AES.new(secret_key) + padded_private_msg = value + ('\0' * ((16-len(value)) % 16)) + encrypted_msg = cipher.encrypt(padded_private_msg) + encoded_encrypted_msg = b64encode(encrypted_msg) + return encoded_encrypted_msg.decode("ascii") + + def decrypt(self, value, schema_version=None, salt=None): """ Decrypt an encrypted value - :param value: value to be decrypted + :param value: value to be decrypted. It is a base64 string + :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done. + If '1.1' symmetric AES encryption has been done :param salt: optional salt to be used :return: Plain content of value """ - # for the moment return same value. until all modules call this method - return value - # raise DbException("Method 'decrypt' not implemented") + if not schema_version or schema_version == '1.0': + return value + else: + if not self.secret_key: + raise DbException("Cannot decrypt. Missing secret_key", http_code=HTTPStatus.INTERNAL_SERVER_ERROR) + secret_key = self._join_passwords(self.secret_key, salt) + encrypted_msg = b64decode(value) + cipher = AES.new(secret_key) + decrypted_msg = cipher.decrypt(encrypted_msg) + unpadded_private_msg = decrypted_msg.decode().rstrip('\0') + return unpadded_private_msg def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None): diff --git a/osm_common/dbmemory.py b/osm_common/dbmemory.py index e37d97b..0e0c42c 100644 --- a/osm_common/dbmemory.py +++ b/osm_common/dbmemory.py @@ -156,6 +156,8 @@ class DbMemory(DbBase): return None self.db[table][i] = deepcopy(indata) return {"updated": 1} + except DbException: + raise except Exception as e: # TODO refine raise DbException(str(e)) diff --git a/osm_common/dbmongo.py b/osm_common/dbmongo.py index 9b5bc57..af63d5b 100644 --- a/osm_common/dbmongo.py +++ b/osm_common/dbmongo.py @@ -22,6 +22,7 @@ from osm_common.dbbase import DbException, DbBase from http import HTTPStatus from time import time, sleep from copy import deepcopy +from base64 import b64decode __author__ = "Alfonso Tierno " @@ -66,10 +67,11 @@ class DbMongo(DbBase): self.client = None self.db = None - def db_connect(self, config): + def db_connect(self, config, target_version=None): """ Connect to database :param config: Configuration of database + :param target_version: if provided it checks if database contains required version, raising exception otherwise. :return: None or raises DbException on error """ try: @@ -83,7 +85,19 @@ class DbMongo(DbBase): now = time() while True: try: - self.db.users.find_one({"username": "admin"}) + version_data = self.get_one("admin", {"_id": "version"}, fail_on_empty=False, fail_on_more=True) + # check database status is ok + if version_data and version_data.get("status") != 'ENABLED': + raise DbException("Wrong database status '{}'".format(version_data.get("status")), + http_code=HTTPStatus.INTERNAL_SERVER_ERROR) + # check version + db_version = None if not version_data else version_data.get("version") + if target_version and target_version != db_version: + raise DbException("Invalid database version {}. Expected {}".format(db_version, target_version)) + # get serial + if version_data and version_data.get("serial"): + self.set_secret_key(b64decode(version_data["serial"])) + self.logger.info("Connected to database {} version {}".format(config["name"], db_version)) return except errors.ConnectionFailure as e: if time() - now >= self.conn_initial_timout: diff --git a/osm_common/msgbase.py b/osm_common/msgbase.py index 6e6dec0..5ba8f71 100644 --- a/osm_common/msgbase.py +++ b/osm_common/msgbase.py @@ -50,13 +50,14 @@ class MsgBase(object): pass def write(self, topic, key, msg): - raise MsgException("Method 'write' not implemented") + raise MsgException("Method 'write' not implemented", http_code=HTTPStatus.INTERNAL_SERVER_ERROR) def read(self, topic): - raise MsgException("Method 'read' not implemented") + raise MsgException("Method 'read' not implemented", http_code=HTTPStatus.INTERNAL_SERVER_ERROR) async def aiowrite(self, topic, key, msg, loop=None): - raise MsgException("Method 'aiowrite' not implemented") + raise MsgException("Method 'aiowrite' not implemented", http_code=HTTPStatus.INTERNAL_SERVER_ERROR) async def aioread(self, topic, loop=None, callback=None, aiocallback=None, **kwargs): - raise MsgException("Method 'aioread' not implemented") + raise MsgException("Method 'aioread' not implemented", http_code=HTTPStatus.INTERNAL_SERVER_ERROR) + diff --git a/osm_common/msglocal.py b/osm_common/msglocal.py index f731e74..247de7b 100644 --- a/osm_common/msglocal.py +++ b/osm_common/msglocal.py @@ -21,6 +21,7 @@ import yaml import asyncio from osm_common.msgbase import MsgBase, MsgException from time import sleep +from http import HTTPStatus __author__ = "Alfonso Tierno " @@ -54,7 +55,7 @@ class MsgLocal(MsgBase): except MsgException: raise except Exception as e: # TODO refine - raise MsgException(str(e)) + raise MsgException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR) def disconnect(self): for f in self.files_read.values(): @@ -82,7 +83,7 @@ class MsgLocal(MsgBase): yaml.safe_dump({key: msg}, self.files_write[topic], default_flow_style=True, width=20000) self.files_write[topic].flush() except Exception as e: # TODO refine - raise MsgException(str(e)) + raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR) def read(self, topic, blocks=True): """ @@ -113,7 +114,7 @@ class MsgLocal(MsgBase): return None sleep(2) except Exception as e: # TODO refine - raise MsgException(str(e)) + raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR) async def aioread(self, topic, loop): """ @@ -131,7 +132,7 @@ class MsgLocal(MsgBase): except MsgException: raise except Exception as e: # TODO refine - raise MsgException(str(e)) + raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR) async def aiowrite(self, topic, key, msg, loop=None): """ diff --git a/osm_common/tests/test_dbbase.py b/osm_common/tests/test_dbbase.py index 2da4506..64bfb3e 100644 --- a/osm_common/tests/test_dbbase.py +++ b/osm_common/tests/test_dbbase.py @@ -2,6 +2,7 @@ import http import pytest import unittest from osm_common.dbbase import DbBase, DbException, deep_update +from os import urandom def exception_message(message): @@ -20,7 +21,9 @@ def test_constructor(): def test_db_connect(db_base): - db_base.db_connect(None) + with pytest.raises(DbException) as excinfo: + db_base.db_connect(None) + assert str(excinfo.value).startswith(exception_message("Method 'db_connect' not implemented")) def test_db_disconnect(db_base): @@ -62,6 +65,57 @@ def test_del_one(db_base): assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND +class TestEncryption(unittest.TestCase): + def setUp(self): + master_password = "Setting a long master password with numbers 123 and capitals AGHBNHD and symbols %&8)!'" + db_base1 = DbBase(master_password=master_password) + db_base2 = DbBase() + # set self.secret_key obtained when connect + db_base1.secret_key = DbBase._join_passwords(urandom(32), db_base1.master_password) + db_base2.secret_key = DbBase._join_passwords(urandom(32), db_base2.master_password) + self.db_base = [db_base1, db_base2] + + def test_encrypt_decrypt(self): + TEST = ( + ("plain text 1 ! ", None), + ("plain text 2 with salt ! ", "1afd5d1a-4a7e-4d9c-8c65-251290183106"), + ("plain text 3 with usalt ! ", u"1afd5d1a-4a7e-4d9c-8c65-251290183106"), + (u"plain unicode 4 ! ", None), + (u"plain unicode 5 with salt ! ", "1a000d1a-4a7e-4d9c-8c65-251290183106"), + (u"plain unicode 6 with usalt ! ", u"1abcdd1a-4a7e-4d9c-8c65-251290183106"), + ) + for db_base in self.db_base: + for value, salt in TEST: + # no encryption + encrypted = db_base.encrypt(value, schema_version='1.0', salt=salt) + self.assertEqual(encrypted, value, "value '{}' has been encrypted".format(value)) + decrypted = db_base.decrypt(encrypted, schema_version='1.0', salt=salt) + self.assertEqual(decrypted, value, "value '{}' has been decrypted".format(value)) + + # encrypt/decrypt + encrypted = db_base.encrypt(value, schema_version='1.1', salt=salt) + self.assertNotEqual(encrypted, value, "value '{}' has not been encrypted".format(value)) + self.assertIsInstance(encrypted, str, "Encrypted is not ascii text") + decrypted = db_base.decrypt(encrypted, schema_version='1.1', salt=salt) + self.assertEqual(decrypted, value, "value is not equal after encryption/decryption") + + def test_encrypt_decrypt_salt(self): + value = "value to be encrypted!" + encrypted = [] + for db_base in self.db_base: + for salt in (None, "salt 1", "1afd5d1a-4a7e-4d9c-8c65-251290183106"): + # encrypt/decrypt + encrypted.append(db_base.encrypt(value, schema_version='1.1', salt=salt)) + self.assertNotEqual(encrypted[-1], value, "value '{}' has not been encrypted".format(value)) + self.assertIsInstance(encrypted[-1], str, "Encrypted is not ascii text") + decrypted = db_base.decrypt(encrypted[-1], schema_version='1.1', salt=salt) + self.assertEqual(decrypted, value, "value is not equal after encryption/decryption") + for i in range(0, len(encrypted)): + for j in range(i+1, len(encrypted)): + self.assertNotEqual(encrypted[i], encrypted[j], + "encryption with different salt contains different result") + + class TestDeepUpdate(unittest.TestCase): def test_update_dict(self): # Original, patch, expected result @@ -161,3 +215,7 @@ class TestDeepUpdate(unittest.TestCase): deep_update(t[0], t[1]) except DbException as e: print(e) + + +if __name__ == '__main__': + unittest.main() diff --git a/osm_common/tests/test_dbmemory.py b/osm_common/tests/test_dbmemory.py index c4d2874..3e59e94 100644 --- a/osm_common/tests/test_dbmemory.py +++ b/osm_common/tests/test_dbmemory.py @@ -42,8 +42,8 @@ def del_one_exception_message(filter): return "database exception Not found entry with filter='{}'".format(filter) -def replace_exception_message(filter): - return "database exception Not found entry with filter='{}'".format(filter) +def replace_exception_message(value): + return "database exception Not found entry with _id='{}'".format(value) def test_constructor(): @@ -366,17 +366,18 @@ def test_del_one_generic_exception(db_memory_with_data, fail_on_empty): assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter, indata", [ - ("test", {}, {"_id": 1, "data": 42}), - ("test", {}, {"_id": 3, "data": 42}), - ("test", {"_id": 1}, {"_id": 3, "data": 42}), - ("test", {"_id": 3}, {"_id": 3, "data": 42}), - ("test", {"data": 1}, {"_id": 3, "data": 42}), - ("test", {"data": 3}, {"_id": 3, "data": 42}), - ("test", {"_id": 1, "data": 1}, {"_id": 3, "data": 42}), - ("test", {"_id": 3, "data": 3}, {"_id": 3, "data": 42})]) -def test_replace(db_memory_with_data, table, filter, indata): - result = db_memory_with_data.replace(table, filter, indata) +@pytest.mark.parametrize("table, _id, indata", [ + ("test", 1, {"_id": 1, "data": 42}), + ("test", 1, {"_id": 1, "data": 42, "kk": 34}), + ("test", 1, {"_id": 1}), + ("test", 2, {"_id": 2, "data": 42}), + ("test", 2, {"_id": 2, "data": 42, "kk": 34}), + ("test", 2, {"_id": 2}), + ("test", 3, {"_id": 3, "data": 42}), + ("test", 3, {"_id": 3, "data": 42, "kk": 34}), + ("test", 3, {"_id": 3})]) +def test_replace(db_memory_with_data, table, _id, indata): + result = db_memory_with_data.replace(table, _id, indata) assert result == {"updated": 1} assert len(db_memory_with_data.db) == 1 assert table in db_memory_with_data.db @@ -384,61 +385,43 @@ def test_replace(db_memory_with_data, table, filter, indata): assert indata in db_memory_with_data.db[table] -@pytest.mark.parametrize("table, filter, indata", [ - ("test", {}, {'_id': 1, 'data': 1}), - ("test", {}, {'_id': 2, 'data': 1}), - ("test", {}, {'_id': 1, 'data': 2}), - ("test", {'_id': 1}, {'_id': 1, 'data': 1}), - ("test", {'_id': 1, 'data': 1}, {'_id': 1, 'data': 1}), - ("test_table", {}, {'_id': 1, 'data': 1}), - ("test_table", {}, {'_id': 2, 'data': 1}), - ("test_table", {}, {'_id': 1, 'data': 2}), - ("test_table", {'_id': 1}, {'_id': 1, 'data': 1}), - ("test_table", {'_id': 1, 'data': 1}, {'_id': 1, 'data': 1})]) -def test_replace_without_data_exception(db_memory, table, filter, indata): +@pytest.mark.parametrize("table, _id, indata", [ + ("test", 1, {"_id": 1, "data": 42}), + ("test", 2, {"_id": 2}), + ("test", 3, {"_id": 3})]) +def test_replace_without_data_exception(db_memory, table, _id, indata): with pytest.raises(DbException) as excinfo: - db_memory.replace(table, filter, indata, fail_on_empty=True) - assert str(excinfo.value) == (empty_exception_message() + replace_exception_message(filter)) + db_memory.replace(table, _id, indata, fail_on_empty=True) + assert str(excinfo.value) == (replace_exception_message(_id)) assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter, indata", [ - ("test", {}, {'_id': 1, 'data': 1}), - ("test", {}, {'_id': 2, 'data': 1}), - ("test", {}, {'_id': 1, 'data': 2}), - ("test", {'_id': 1}, {'_id': 1, 'data': 1}), - ("test", {'_id': 1, 'data': 1}, {'_id': 1, 'data': 1}), - ("test_table", {}, {'_id': 1, 'data': 1}), - ("test_table", {}, {'_id': 2, 'data': 1}), - ("test_table", {}, {'_id': 1, 'data': 2}), - ("test_table", {'_id': 1}, {'_id': 1, 'data': 1}), - ("test_table", {'_id': 1, 'data': 1}, {'_id': 1, 'data': 1})]) -def test_replace_without_data_none(db_memory, table, filter, indata): - result = db_memory.replace(table, filter, indata, fail_on_empty=False) +@pytest.mark.parametrize("table, _id, indata", [ + ("test", 1, {"_id": 1, "data": 42}), + ("test", 2, {"_id": 2}), + ("test", 3, {"_id": 3})]) +def test_replace_without_data_none(db_memory, table, _id, indata): + result = db_memory.replace(table, _id, indata, fail_on_empty=False) assert result is None -@pytest.mark.parametrize("table, filter, indata", [ - ("test_table", {}, {'_id': 1, 'data': 1}), - ("test_table", {}, {'_id': 2, 'data': 1}), - ("test_table", {}, {'_id': 1, 'data': 2}), - ("test_table", {'_id': 1}, {'_id': 1, 'data': 1}), - ("test_table", {'_id': 1, 'data': 1}, {'_id': 1, 'data': 1})]) -def test_replace_with_data_exception(db_memory_with_data, table, filter, indata): +@pytest.mark.parametrize("table, _id, indata", [ + ("test", 11, {"_id": 11, "data": 42}), + ("test", 12, {"_id": 12}), + ("test", 33, {"_id": 33})]) +def test_replace_with_data_exception(db_memory_with_data, table, _id, indata): with pytest.raises(DbException) as excinfo: - db_memory_with_data.replace(table, filter, indata, fail_on_empty=True) - assert str(excinfo.value) == (empty_exception_message() + replace_exception_message(filter)) + db_memory_with_data.replace(table, _id, indata, fail_on_empty=True) + assert str(excinfo.value) == (replace_exception_message(_id)) assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND -@pytest.mark.parametrize("table, filter, indata", [ - ("test_table", {}, {'_id': 1, 'data': 1}), - ("test_table", {}, {'_id': 2, 'data': 1}), - ("test_table", {}, {'_id': 1, 'data': 2}), - ("test_table", {'_id': 1}, {'_id': 1, 'data': 1}), - ("test_table", {'_id': 1, 'data': 1}, {'_id': 1, 'data': 1})]) -def test_replace_with_data_none(db_memory_with_data, table, filter, indata): - result = db_memory_with_data.replace(table, filter, indata, fail_on_empty=False) +@pytest.mark.parametrize("table, _id, indata", [ + ("test", 11, {"_id": 11, "data": 42}), + ("test", 12, {"_id": 12}), + ("test", 33, {"_id": 33})]) +def test_replace_with_data_none(db_memory_with_data, table, _id, indata): + result = db_memory_with_data.replace(table, _id, indata, fail_on_empty=False) assert result is None @@ -447,11 +430,11 @@ def test_replace_with_data_none(db_memory_with_data, table, filter, indata): False]) def test_replace_generic_exception(db_memory_with_data, fail_on_empty): table = 'test' - filter = {} + _id = {} indata = {'_id': 1, 'data': 1} db_memory_with_data._find = MagicMock(side_effect=Exception()) with pytest.raises(DbException) as excinfo: - db_memory_with_data.replace(table, filter, indata, fail_on_empty=fail_on_empty) + db_memory_with_data.replace(table, _id, indata, fail_on_empty=fail_on_empty) assert str(excinfo.value) == empty_exception_message() assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND diff --git a/tox.ini b/tox.ini index e3bc42c..1cf3d54 100644 --- a/tox.ini +++ b/tox.ini @@ -14,13 +14,15 @@ # limitations under the License. [tox] -envlist = py27,py3,flake8,pytest +envlist = py3 toxworkdir={homedir}/.tox -[testenv] +[testenv:pytest] basepython = python3 -deps = pytest pytest-asyncio -commands = pytest +deps = pytest + pytest-asyncio + pycrypto +commands = pytest osm_common [testenv:flake8] basepython = python3 @@ -28,6 +30,13 @@ deps = flake8 commands = flake8 osm_common/ setup.py --max-line-length 120 --exclude .svn,CVS,.gz,.git,__pycache__,.tox,local,temp --ignore W291,W293,E226 +[testenv:unittest] +basepython = python3 +deps = pycrypto + pytest +commands = python3 -m unittest osm_common.tests.test_dbbase + + [testenv:build] basepython = python3 deps = stdeb -- 2.25.1