From f9b4b3dfbcc355bbea219b33862afe81b146c1f8 Mon Sep 17 00:00:00 2001 From: Gulsum Atici Date: Mon, 9 Jan 2023 23:19:18 +0300 Subject: [PATCH] Feature 10950 Replace pycrypto with pycryptodome Remove the pycrypto library and change encrypt and decrypt methods to work with pycryptodome. Move encryption methods from N2VC to common. Change-Id: I12a5f6138664ab6ebb7100c82523e91750f05f14 Signed-off-by: Gulsum Atici --- osm_common/dbbase.py | 289 +++- osm_common/dbmemory.py | 2 +- osm_common/dbmongo.py | 2 +- osm_common/fsmongo.py | 6 +- osm_common/tests/test_dbbase.py | 1261 ++++++++++++++++- .../use_pycryptpdome-8ef05275b779994b.yaml | 21 + requirements.in | 5 +- requirements.txt | 8 +- 8 files changed, 1555 insertions(+), 39 deletions(-) create mode 100644 releasenotes/notes/use_pycryptpdome-8ef05275b779994b.yaml diff --git a/osm_common/dbbase.py b/osm_common/dbbase.py index 4021805..6b3a89a 100644 --- a/osm_common/dbbase.py +++ b/osm_common/dbbase.py @@ -14,21 +14,27 @@ # implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio from base64 import b64decode, b64encode from copy import deepcopy from http import HTTPStatus import logging import re from threading import Lock +import typing + from Crypto.Cipher import AES +from motor.motor_asyncio import AsyncIOMotorClient from osm_common.common_utils import FakeLock import yaml __author__ = "Alfonso Tierno " +DB_NAME = "osm" + + class DbException(Exception): def __init__(self, message, http_code=HTTPStatus.NOT_FOUND): self.http_code = http_code @@ -36,7 +42,7 @@ class DbException(Exception): class DbBase(object): - def __init__(self, logger_name="db", lock=False): + def __init__(self, encoding_type="ascii", logger_name="db", lock=False): """ Constructor of dbBase :param logger_name: logging name @@ -47,6 +53,8 @@ class DbBase(object): """ self.logger = logging.getLogger(logger_name) self.secret_key = None # 32 bytes length array used for encrypt/decrypt + self.encrypt_mode = AES.MODE_ECB + self.encoding_type = encoding_type if not lock: self.lock = FakeLock() elif lock is True: @@ -263,51 +271,120 @@ class DbBase(object): """ pass - def encrypt(self, value, schema_version=None, salt=None): - """ - Encrypt a value - :param value: value to be encrypted. It is string/unicode - :param schema_version: used for version control. If None or '1.0' no encryption is done. - If '1.1' symmetric AES encryption is done - :param salt: optional salt to be used. Must be str - :return: Encrypted content of value + @staticmethod + def pad_data(value: str) -> str: + if not isinstance(value, str): + raise DbException( + f"Incorrect data type: type({value}), string is expected." + ) + return value + ("\0" * ((16 - len(value)) % 16)) + + @staticmethod + def unpad_data(value: str) -> str: + if not isinstance(value, str): + raise DbException( + f"Incorrect data type: type({value}), string is expected." + ) + return value.rstrip("\0") + + def _encrypt_value(self, value: str, schema_version: str, salt: str): + """Encrypt a value. + + Args: + value (str): value to be encrypted. It is string/unicode + schema_version (str): used for version control. If None or '1.0' no encryption is done. + If '1.1' symmetric AES encryption is done + salt (str): optional salt to be used. Must be str + + Returns: + Encrypted content of value (str) + """ - self.get_secret_key() if not self.secret_key or not schema_version or schema_version == "1.0": return value + else: + # Secret key as bytes secret_key = self._join_secret_key(salt) - cipher = AES.new(secret_key) - padded_private_msg = value + ("\0" * ((16 - len(value)) % 16)) - encrypted_msg = cipher.encrypt(padded_private_msg) + cipher = AES.new(secret_key, self.encrypt_mode) + # Padded data as string + padded_private_msg = self.pad_data(value) + # Padded data as bytes + padded_private_msg_bytes = padded_private_msg.encode(self.encoding_type) + # Encrypt padded data + encrypted_msg = cipher.encrypt(padded_private_msg_bytes) + # Base64 encoded encrypted data encoded_encrypted_msg = b64encode(encrypted_msg) - return encoded_encrypted_msg.decode("ascii") + # Converting to string + return encoded_encrypted_msg.decode(self.encoding_type) + + def encrypt(self, value: str, schema_version: str = None, salt: str = None) -> str: + """Encrypt a value. + + Args: + value (str): value to be encrypted. It is string/unicode + schema_version (str): used for version control. If None or '1.0' no encryption is done. + If '1.1' symmetric AES encryption is done + salt (str): optional salt to be used. Must be str + + Returns: + Encrypted content of value (str) - def decrypt(self, value, schema_version=None, salt=None): - """ - Decrypt an encrypted value - :param value: value to be decrypted. It is a base64 string - :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done. - If '1.1' symmetric AES encryption has been done - :param salt: optional salt to be used - :return: Plain content of value """ self.get_secret_key() + return self._encrypt_value(value, schema_version, salt) + + def _decrypt_value(self, value: str, schema_version: str, salt: str) -> str: + """Decrypt an encrypted value. + Args: + + value (str): value to be decrypted. It is a base64 string + schema_version (str): used for known encryption method used. + If None or '1.0' no encryption has been done. + If '1.1' symmetric AES encryption has been done + salt (str): optional salt to be used + + Returns: + Plain content of value (str) + + """ if not self.secret_key or not schema_version or schema_version == "1.0": return value + else: secret_key = self._join_secret_key(salt) + # Decoding encrypted data, output bytes encrypted_msg = b64decode(value) - cipher = AES.new(secret_key) + cipher = AES.new(secret_key, self.encrypt_mode) + # Decrypted data, output bytes decrypted_msg = cipher.decrypt(encrypted_msg) try: - unpadded_private_msg = decrypted_msg.decode().rstrip("\0") + # Converting to string + private_msg = decrypted_msg.decode(self.encoding_type) except UnicodeDecodeError: raise DbException( "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?", http_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) - return unpadded_private_msg + # Unpadded data as string + return self.unpad_data(private_msg) + + def decrypt(self, value: str, schema_version: str = None, salt: str = None) -> str: + """Decrypt an encrypted value. + Args: + + value (str): value to be decrypted. It is a base64 string + schema_version (str): used for known encryption method used. + If None or '1.0' no encryption has been done. + If '1.1' symmetric AES encryption has been done + salt (str): optional salt to be used + + Returns: + Plain content of value (str) + + """ + self.get_secret_key() + return self._decrypt_value(value, schema_version, salt) def encrypt_decrypt_fields( self, item, action, fields=None, flags=None, schema_version=None, salt=None @@ -593,3 +670,163 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None): def deep_update(dict_to_change, dict_reference): """Maintained for backward compatibility. Use deep_update_rfc7396 instead""" return deep_update_rfc7396(dict_to_change, dict_reference) + + +class Encryption(DbBase): + def __init__(self, uri, config, encoding_type="ascii", loop=None, logger_name="db"): + """Constructor. + + Args: + uri (str): Connection string to connect to the database. + config (dict): Additional database info + encoding_type (str): ascii, utf-8 etc. + loop (object): Asyncio Loop + logger_name (str): Logger name + + """ + self.loop = loop or asyncio.get_event_loop() + self._secret_key = None # 32 bytes length array used for encrypt/decrypt + self.encrypt_mode = AES.MODE_ECB + super(Encryption, self).__init__( + encoding_type=encoding_type, logger_name=logger_name + ) + self._client = AsyncIOMotorClient(uri) + self._config = config + + @property + def secret_key(self): + return self._secret_key + + @secret_key.setter + def secret_key(self, value): + self._secret_key = value + + @property + def _database(self): + return self._client[DB_NAME] + + @property + def _admin_collection(self): + return self._database["admin"] + + @property + def database_key(self): + return self._config.get("database_commonkey") + + async def decrypt_fields( + self, + item: dict, + fields: typing.List[str], + schema_version: str = None, + salt: str = None, + ) -> None: + """Decrypt fields from a dictionary. Follows the same logic as in osm_common. + + Args: + + item (dict): Dictionary with the keys to be decrypted + fields (list): List of keys to decrypt + schema version (str): Schema version. (i.e. 1.11) + salt (str): Salt for the decryption + + """ + flags = re.I + + async def process(_item): + if isinstance(_item, list): + for elem in _item: + await process(elem) + elif isinstance(_item, dict): + for key, val in _item.items(): + if isinstance(val, str): + if any(re.search(f, key, flags) for f in fields): + _item[key] = await self.decrypt(val, schema_version, salt) + else: + await process(val) + + await process(item) + + async def encrypt( + self, value: str, schema_version: str = None, salt: str = None + ) -> str: + """Encrypt a value. + + Args: + value (str): value to be encrypted. It is string/unicode + schema_version (str): used for version control. If None or '1.0' no encryption is done. + If '1.1' symmetric AES encryption is done + salt (str): optional salt to be used. Must be str + + Returns: + Encrypted content of value (str) + + """ + await self.get_secret_key() + return self._encrypt_value(value, schema_version, salt) + + async def decrypt( + self, value: str, schema_version: str = None, salt: str = None + ) -> str: + """Decrypt an encrypted value. + Args: + + value (str): value to be decrypted. It is a base64 string + schema_version (str): used for known encryption method used. + If None or '1.0' no encryption has been done. + If '1.1' symmetric AES encryption has been done + salt (str): optional salt to be used + + Returns: + Plain content of value (str) + + """ + await self.get_secret_key() + return self._decrypt_value(value, schema_version, salt) + + def _join_secret_key(self, update_key: typing.Any) -> bytes: + """Join key with secret key. + + Args: + + update_key (str or bytes): str or bytes with the to update + + Returns: + + Joined key (bytes) + """ + return self._join_keys(update_key, self.secret_key) + + def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes: + """Join key with secret_key. + + Args: + + key (str or bytes): str or bytes of the key to update + secret_key (bytes): bytes of the secret key + + Returns: + + Joined key (bytes) + """ + if isinstance(key, str): + update_key_bytes = key.encode(self.encoding_type) + else: + update_key_bytes = key + new_secret_key = bytearray(secret_key) if secret_key else bytearray(32) + for i, b in enumerate(update_key_bytes): + new_secret_key[i % 32] ^= b + return bytes(new_secret_key) + + async def get_secret_key(self): + """Get secret key using the database key and the serial key in the DB. + The key is populated in the property self.secret_key. + """ + if self.secret_key: + return + secret_key = None + if self.database_key: + secret_key = self._join_keys(self.database_key, None) + version_data = await self._admin_collection.find_one({"_id": "version"}) + if version_data and version_data.get("serial"): + secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key) + self._secret_key = secret_key diff --git a/osm_common/dbmemory.py b/osm_common/dbmemory.py index ad52135..272f6d6 100644 --- a/osm_common/dbmemory.py +++ b/osm_common/dbmemory.py @@ -29,7 +29,7 @@ __author__ = "Alfonso Tierno " class DbMemory(DbBase): def __init__(self, logger_name="db", lock=False): - super().__init__(logger_name, lock) + super().__init__(logger_name=logger_name, lock=lock) self.db = {} def db_connect(self, config): diff --git a/osm_common/dbmongo.py b/osm_common/dbmongo.py index f64949d..f5c4d30 100644 --- a/osm_common/dbmongo.py +++ b/osm_common/dbmongo.py @@ -65,7 +65,7 @@ class DbMongo(DbBase): conn_timout = 10 def __init__(self, logger_name="db", lock=False): - super().__init__(logger_name, lock) + super().__init__(logger_name=logger_name, lock=lock) self.client = None self.db = None self.database_key = None diff --git a/osm_common/fsmongo.py b/osm_common/fsmongo.py index 727410e..f99267f 100644 --- a/osm_common/fsmongo.py +++ b/osm_common/fsmongo.py @@ -570,9 +570,7 @@ class FsMongo(FsBase): self.__update_local_fs(from_path=from_path) def _update_mongo_fs(self, from_path): - os_path = self.path + from_path - # Obtain list of files and dirs in filesystem members = [] for root, dirs, files in os.walk(os_path): @@ -615,7 +613,6 @@ class FsMongo(FsBase): remote_files.pop(rel_filename, None) if last_modified_date >= upload_date: - stream = None fh = None try: @@ -646,13 +643,12 @@ class FsMongo(FsBase): if stream: stream.close() - # delete files that are not any more in local fs + # delete files that are not anymore in local fs for remote_file in remote_files.values(): for file in remote_file: self.fs.delete(file._id) def _get_mongo_files(self, from_path=None): - file_dict = {} file_cursor = self.fs.find(no_cursor_timeout=True, sort=[("uploadDate", -1)]) for file in file_cursor: diff --git a/osm_common/tests/test_dbbase.py b/osm_common/tests/test_dbbase.py index eabf5e0..050abdb 100644 --- a/osm_common/tests/test_dbbase.py +++ b/osm_common/tests/test_dbbase.py @@ -16,16 +16,47 @@ # For those usages not covered by the Apache License, Version 2.0 please # contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com ## - +import asyncio +import copy +from copy import deepcopy import http from http import HTTPStatus +import logging from os import urandom import unittest +from unittest.mock import MagicMock, Mock, patch -from osm_common.dbbase import DbBase, DbException, deep_update +from Crypto.Cipher import AES +from osm_common.dbbase import DbBase, DbException, deep_update, Encryption import pytest +# Variables used in TestBaseEncryption and TestAsyncEncryption +salt = "1afd5d1a-4a7e-4d9c-8c65-251290183106" +value = "private key txt" +padded_value = b"private key txt\0" +padded_encoded_value = b"private key txt\x00" +encoding_type = "ascii" +encyrpt_mode = AES.MODE_ECB +secret_key = b"\xeev\xc2\xb8\xb2#;Ek\xd0\xb5['\x04\xed\x1f\xb9?\xc5Ig\x80\xd5\x8d\x8aT\xd7\xf8Q\xe2u!" +encyrpted_value = "ZW5jcnlwdGVkIGRhdGE=" +encyrpted_bytes = b"ZW5jcnlwdGVkIGRhdGE=" +data_to_b4_encode = b"encrypted data" +b64_decoded = b"decrypted data" +schema_version = "1.1" +joined_key = b"\x9d\x17\xaf\xc8\xdeF\x1b.\x0e\xa9\xb5['\x04\xed\x1f\xb9?\xc5Ig\x80\xd5\x8d\x8aT\xd7\xf8Q\xe2u!" +serial_bytes = b"\xf8\x96Z\x1c:}\xb5\xdf\x94\x8d\x0f\x807\xe6)\x8f\xf5!\xee}\xc2\xfa\xb3\t\xb9\xe4\r7\x19\x08\xa5b" +base64_decoded_serial = b"g\xbe\xdb" +decrypted_val1 = "BiV9YZEuSRAudqvz7Gs+bg==" +decrypted_val2 = "q4LwnFdoryzbZJM5mCAnpA==" +item = { + "secret": "mysecret", + "cacert": "mycacert", + "path": "/var", + "ip": "192.168.12.23", +} + + def exception_message(message): return "database exception " + message @@ -195,6 +226,1232 @@ class TestEncryption(unittest.TestCase): ) +class AsyncMock(MagicMock): + async def __call__(self, *args, **kwargs): + args = deepcopy(args) + kwargs = deepcopy(kwargs) + return super(AsyncMock, self).__call__(*args, **kwargs) + + +class CopyingMock(MagicMock): + def __call__(self, *args, **kwargs): + args = deepcopy(args) + kwargs = deepcopy(kwargs) + return super(CopyingMock, self).__call__(*args, **kwargs) + + +def check_if_assert_not_called(mocks: list): + for mocking in mocks: + mocking.assert_not_called() + + +class TestBaseEncryption(unittest.TestCase): + @patch("logging.getLogger", autospec=True) + def setUp(self, mock_logger): + mock_logger = logging.getLogger() + mock_logger.disabled = True + self.db_base = DbBase() + self.mock_cipher = CopyingMock() + self.db_base.encoding_type = encoding_type + self.db_base.encrypt_mode = encyrpt_mode + self.db_base.secret_key = secret_key + self.mock_padded_msg = CopyingMock() + + def test_pad_data_len_not_multiplication_of_16(self): + data = "hello word hello hello word hello word" + data_len = len(data) + expected_len = 48 + padded = self.db_base.pad_data(data) + self.assertEqual(len(padded), expected_len) + self.assertTrue("\0" * (expected_len - data_len) in padded) + + def test_pad_data_len_multiplication_of_16(self): + data = "hello word!!!!!!" + padded = self.db_base.pad_data(data) + self.assertEqual(padded, data) + self.assertFalse("\0" in padded) + + def test_pad_data_empty_string(self): + data = "" + expected_len = 0 + padded = self.db_base.pad_data(data) + self.assertEqual(len(padded), expected_len) + self.assertFalse("\0" in padded) + + def test_pad_data_not_string(self): + data = None + with self.assertRaises(Exception) as err: + self.db_base.pad_data(data) + self.assertEqual( + str(err.exception), + "database exception Incorrect data type: type(None), string is expected.", + ) + + def test_unpad_data_null_char_at_right(self): + null_padded_data = "hell0word\0\0" + expected_length = len(null_padded_data) - 2 + unpadded = self.db_base.unpad_data(null_padded_data) + self.assertEqual(len(unpadded), expected_length) + self.assertFalse("\0" in unpadded) + self.assertTrue("0" in unpadded) + + def test_unpad_data_null_char_is_not_rightest(self): + null_padded_data = "hell0word\r\t\0\n" + expected_length = len(null_padded_data) + unpadded = self.db_base.unpad_data(null_padded_data) + self.assertEqual(len(unpadded), expected_length) + self.assertTrue("\0" in unpadded) + + def test_unpad_data_with_spaces_at_right(self): + null_padded_data = " hell0word\0 " + expected_length = len(null_padded_data) + unpadded = self.db_base.unpad_data(null_padded_data) + self.assertEqual(len(unpadded), expected_length) + self.assertTrue("\0" in unpadded) + + def test_unpad_data_empty_string(self): + data = "" + unpadded = self.db_base.unpad_data(data) + self.assertEqual(unpadded, "") + self.assertFalse("\0" in unpadded) + + def test_unpad_data_not_string(self): + data = None + with self.assertRaises(Exception) as err: + self.db_base.unpad_data(data) + self.assertEqual( + str(err.exception), + "database exception Incorrect data type: type(None), string is expected.", + ) + + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_schema_version_1_0_none_secret_key_none_salt( + self, mock_pad_data, mock_join_secret_key + ): + """schema_version 1.0, secret_key is None and salt is None.""" + schema_version = "1.0" + salt = None + self.db_base.secret_key = None + result = self.db_base._encrypt_value(value, schema_version, salt) + self.assertEqual(result, value) + check_if_assert_not_called([mock_pad_data, mock_join_secret_key]) + + @patch("osm_common.dbbase.b64encode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_schema_version_1_1_with_secret_key_exists_with_salt( + self, + mock_pad_data, + mock_join_secret_key, + mock_aes, + mock_b64_encode, + ): + """schema_version 1.1, secret_key exists, salt exists.""" + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.encrypt.return_value = data_to_b4_encode + self.mock_padded_msg.return_value = padded_value + mock_pad_data.return_value = self.mock_padded_msg + self.mock_padded_msg.encode.return_value = padded_encoded_value + + mock_b64_encode.return_value = encyrpted_bytes + + result = self.db_base._encrypt_value(value, schema_version, salt) + + self.assertTrue(isinstance(result, str)) + self.assertEqual(result, encyrpted_value) + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_pad_data.assert_called_once_with(value) + mock_b64_encode.assert_called_once_with(data_to_b4_encode) + self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value) + self.mock_padded_msg.encode.assert_called_with(encoding_type) + + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_schema_version_1_0_secret_key_not_exists( + self, mock_pad_data, mock_join_secret_key + ): + """schema_version 1.0, secret_key is None, salt exists.""" + schema_version = "1.0" + self.db_base.secret_key = None + result = self.db_base._encrypt_value(value, schema_version, salt) + self.assertEqual(result, value) + check_if_assert_not_called([mock_pad_data, mock_join_secret_key]) + + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_schema_version_1_1_secret_key_not_exists( + self, mock_pad_data, mock_join_secret_key + ): + """schema_version 1.1, secret_key is None, salt exists.""" + self.db_base.secret_key = None + result = self.db_base._encrypt_value(value, schema_version, salt) + self.assertEqual(result, value) + check_if_assert_not_called([mock_pad_data, mock_join_secret_key]) + + @patch("osm_common.dbbase.b64encode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_schema_version_1_1_secret_key_exists_without_salt( + self, + mock_pad_data, + mock_join_secret_key, + mock_aes, + mock_b64_encode, + ): + """schema_version 1.1, secret_key exists, salt is None.""" + salt = None + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.encrypt.return_value = data_to_b4_encode + + self.mock_padded_msg.return_value = padded_value + mock_pad_data.return_value = self.mock_padded_msg + self.mock_padded_msg.encode.return_value = padded_encoded_value + + mock_b64_encode.return_value = encyrpted_bytes + + result = self.db_base._encrypt_value(value, schema_version, salt) + + self.assertEqual(result, encyrpted_value) + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_pad_data.assert_called_once_with(value) + mock_b64_encode.assert_called_once_with(data_to_b4_encode) + self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value) + self.mock_padded_msg.encode.assert_called_with(encoding_type) + + @patch("osm_common.dbbase.b64encode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_invalid_encrpt_mode( + self, + mock_pad_data, + mock_join_secret_key, + mock_aes, + mock_b64_encode, + ): + """encrypt_mode is invalid.""" + mock_aes.new.side_effect = Exception("Invalid ciphering mode.") + self.db_base.encrypt_mode = "AES.MODE_XXX" + + with self.assertRaises(Exception) as err: + self.db_base._encrypt_value(value, schema_version, salt) + + self.assertEqual(str(err.exception), "Invalid ciphering mode.") + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], "AES.MODE_XXX") + check_if_assert_not_called([mock_pad_data, mock_b64_encode]) + + @patch("osm_common.dbbase.b64encode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_schema_version_1_1_secret_key_exists_value_none( + self, + mock_pad_data, + mock_join_secret_key, + mock_aes, + mock_b64_encode, + ): + """schema_version 1.1, secret_key exists, value is None.""" + value = None + mock_aes.new.return_value = self.mock_cipher + mock_pad_data.side_effect = DbException( + "Incorrect data type: type(None), string is expected." + ) + + with self.assertRaises(Exception) as err: + self.db_base._encrypt_value(value, schema_version, salt) + self.assertEqual( + str(err.exception), + "database exception Incorrect data type: type(None), string is expected.", + ) + + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_pad_data.assert_called_once_with(value) + check_if_assert_not_called( + [mock_b64_encode, self.mock_cipher.encrypt, mock_b64_encode] + ) + + @patch("osm_common.dbbase.b64encode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_join_secret_key_raises( + self, + mock_pad_data, + mock_join_secret_key, + mock_aes, + mock_b64_encode, + ): + """Method join_secret_key raises DbException.""" + salt = b"3434o34-3wewrwr-222424-2242dwew" + + mock_join_secret_key.side_effect = DbException("Unexpected type") + + mock_aes.new.return_value = self.mock_cipher + + with self.assertRaises(Exception) as err: + self.db_base._encrypt_value(value, schema_version, salt) + + self.assertEqual(str(err.exception), "database exception Unexpected type") + check_if_assert_not_called( + [mock_pad_data, mock_aes.new, mock_b64_encode, self.mock_cipher.encrypt] + ) + mock_join_secret_key.assert_called_once_with(salt) + + @patch("osm_common.dbbase.b64encode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_schema_version_1_1_secret_key_exists_b64_encode_raises( + self, + mock_pad_data, + mock_join_secret_key, + mock_aes, + mock_b64_encode, + ): + """schema_version 1.1, secret_key exists, b64encode raises TypeError.""" + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.encrypt.return_value = "encrypted data" + + self.mock_padded_msg.return_value = padded_value + mock_pad_data.return_value = self.mock_padded_msg + self.mock_padded_msg.encode.return_value = padded_encoded_value + + mock_b64_encode.side_effect = TypeError( + "A bytes-like object is required, not 'str'" + ) + + with self.assertRaises(Exception) as error: + self.db_base._encrypt_value(value, schema_version, salt) + self.assertEqual( + str(error.exception), "A bytes-like object is required, not 'str'" + ) + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_pad_data.assert_called_once_with(value) + self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value) + self.mock_padded_msg.encode.assert_called_with(encoding_type) + mock_b64_encode.assert_called_once_with("encrypted data") + + @patch("osm_common.dbbase.b64encode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "pad_data") + def test__encrypt_value_cipher_encrypt_raises( + self, + mock_pad_data, + mock_join_secret_key, + mock_aes, + mock_b64_encode, + ): + """AES encrypt method raises Exception.""" + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.encrypt.side_effect = Exception("Invalid data type.") + + self.mock_padded_msg.return_value = padded_value + mock_pad_data.return_value = self.mock_padded_msg + self.mock_padded_msg.encode.return_value = padded_encoded_value + + with self.assertRaises(Exception) as error: + self.db_base._encrypt_value(value, schema_version, salt) + + self.assertEqual(str(error.exception), "Invalid data type.") + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_pad_data.assert_called_once_with(value) + self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value) + self.mock_padded_msg.encode.assert_called_with(encoding_type) + mock_b64_encode.assert_not_called() + + @patch.object(DbBase, "get_secret_key") + @patch.object(DbBase, "_encrypt_value") + def test_encrypt_without_schema_version_without_salt( + self, mock_encrypt_value, mock_get_secret_key + ): + """schema and salt is None.""" + mock_encrypt_value.return_value = encyrpted_value + result = self.db_base.encrypt(value) + mock_encrypt_value.assert_called_once_with(value, None, None) + mock_get_secret_key.assert_called_once() + self.assertEqual(result, encyrpted_value) + + @patch.object(DbBase, "get_secret_key") + @patch.object(DbBase, "_encrypt_value") + def test_encrypt_with_schema_version_with_salt( + self, mock_encrypt_value, mock_get_secret_key + ): + """schema version exists, salt is None.""" + mock_encrypt_value.return_value = encyrpted_value + result = self.db_base.encrypt(value, schema_version, salt) + mock_encrypt_value.assert_called_once_with(value, schema_version, salt) + mock_get_secret_key.assert_called_once() + self.assertEqual(result, encyrpted_value) + + @patch.object(DbBase, "get_secret_key") + @patch.object(DbBase, "_encrypt_value") + def test_encrypt_get_secret_key_raises( + self, mock_encrypt_value, mock_get_secret_key + ): + """get_secret_key method raises DbException.""" + mock_get_secret_key.side_effect = DbException("KeyError") + with self.assertRaises(Exception) as error: + self.db_base.encrypt(value) + self.assertEqual(str(error.exception), "database exception KeyError") + mock_encrypt_value.assert_not_called() + mock_get_secret_key.assert_called_once() + + @patch.object(DbBase, "get_secret_key") + @patch.object(DbBase, "_encrypt_value") + def test_encrypt_encrypt_raises(self, mock_encrypt_value, mock_get_secret_key): + """_encrypt method raises DbException.""" + mock_encrypt_value.side_effect = DbException( + "Incorrect data type: type(None), string is expected." + ) + with self.assertRaises(Exception) as error: + self.db_base.encrypt(value, schema_version, salt) + self.assertEqual( + str(error.exception), + "database exception Incorrect data type: type(None), string is expected.", + ) + mock_encrypt_value.assert_called_once_with(value, schema_version, salt) + mock_get_secret_key.assert_called_once() + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_schema_version_1_1_secret_key_exists_without_salt( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """schema_version 1.1, secret_key exists, salt is None.""" + salt = None + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.decrypt.return_value = padded_encoded_value + + mock_b64_decode.return_value = b64_decoded + + mock_unpad_data.return_value = value + + result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + self.assertEqual(result, value) + + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_unpad_data.assert_called_once_with("private key txt\0") + mock_b64_decode.assert_called_once_with(encyrpted_value) + self.mock_cipher.decrypt.assert_called_once_with(b64_decoded) + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_schema_version_1_1_secret_key_exists_with_salt( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """schema_version 1.1, secret_key exists, salt is None.""" + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.decrypt.return_value = padded_encoded_value + + mock_b64_decode.return_value = b64_decoded + + mock_unpad_data.return_value = value + + result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + self.assertEqual(result, value) + + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_unpad_data.assert_called_once_with("private key txt\0") + mock_b64_decode.assert_called_once_with(encyrpted_value) + self.mock_cipher.decrypt.assert_called_once_with(b64_decoded) + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_schema_version_1_1_without_secret_key( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """schema_version 1.1, secret_key is None, salt exists.""" + self.db_base.secret_key = None + + result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + + self.assertEqual(result, encyrpted_value) + check_if_assert_not_called( + [ + mock_join_secret_key, + mock_aes.new, + mock_unpad_data, + mock_b64_decode, + self.mock_cipher.decrypt, + ] + ) + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_schema_version_1_0_with_secret_key( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """schema_version 1.0, secret_key exists, salt exists.""" + schema_version = "1.0" + result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + + self.assertEqual(result, encyrpted_value) + check_if_assert_not_called( + [ + mock_join_secret_key, + mock_aes.new, + mock_unpad_data, + mock_b64_decode, + self.mock_cipher.decrypt, + ] + ) + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_join_secret_key_raises( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """_join_secret_key raises TypeError.""" + salt = object() + mock_join_secret_key.side_effect = TypeError("'type' object is not iterable") + + with self.assertRaises(Exception) as error: + self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + self.assertEqual(str(error.exception), "'type' object is not iterable") + + mock_join_secret_key.assert_called_once_with(salt) + check_if_assert_not_called( + [mock_aes.new, mock_unpad_data, mock_b64_decode, self.mock_cipher.decrypt] + ) + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_b64decode_raises( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """b64decode raises TypeError.""" + mock_b64_decode.side_effect = TypeError( + "A str-like object is required, not 'bytes'" + ) + with self.assertRaises(Exception) as error: + self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + self.assertEqual( + str(error.exception), "A str-like object is required, not 'bytes'" + ) + + mock_b64_decode.assert_called_once_with(encyrpted_value) + mock_join_secret_key.assert_called_once_with(salt) + check_if_assert_not_called( + [mock_aes.new, self.mock_cipher.decrypt, mock_unpad_data] + ) + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_invalid_encrypt_mode( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """Invalid AES encrypt mode.""" + mock_aes.new.side_effect = Exception("Invalid ciphering mode.") + self.db_base.encrypt_mode = "AES.MODE_XXX" + + mock_b64_decode.return_value = b64_decoded + with self.assertRaises(Exception) as error: + self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + + self.assertEqual(str(error.exception), "Invalid ciphering mode.") + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], "AES.MODE_XXX") + mock_b64_decode.assert_called_once_with(encyrpted_value) + check_if_assert_not_called([mock_unpad_data, self.mock_cipher.decrypt]) + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_cipher_decrypt_raises( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """AES decrypt raises Exception.""" + mock_b64_decode.return_value = b64_decoded + + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.decrypt.side_effect = Exception("Invalid data type.") + + with self.assertRaises(Exception) as error: + self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + self.assertEqual(str(error.exception), "Invalid data type.") + + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_b64_decode.assert_called_once_with(encyrpted_value) + self.mock_cipher.decrypt.assert_called_once_with(b64_decoded) + mock_unpad_data.assert_not_called() + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_decode_raises( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """Decode raises UnicodeDecodeError.""" + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.decrypt.return_value = b"\xd0\x000091" + + mock_b64_decode.return_value = b64_decoded + + mock_unpad_data.return_value = value + with self.assertRaises(Exception) as error: + self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + self.assertEqual( + str(error.exception), + "database exception Cannot decrypt information. Are you using same COMMONKEY in all OSM components?", + ) + self.assertEqual(type(error.exception), DbException) + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_b64_decode.assert_called_once_with(encyrpted_value) + self.mock_cipher.decrypt.assert_called_once_with(b64_decoded) + mock_unpad_data.assert_not_called() + + @patch("osm_common.dbbase.b64decode") + @patch("osm_common.dbbase.AES") + @patch.object(DbBase, "_join_secret_key") + @patch.object(DbBase, "unpad_data") + def test__decrypt_value_unpad_data_raises( + self, + mock_unpad_data, + mock_join_secret_key, + mock_aes, + mock_b64_decode, + ): + """Method unpad_data raises error.""" + mock_decrypted_message = MagicMock() + mock_decrypted_message.decode.return_value = None + mock_aes.new.return_value = self.mock_cipher + self.mock_cipher.decrypt.return_value = mock_decrypted_message + mock_unpad_data.side_effect = DbException( + "Incorrect data type: type(None), string is expected." + ) + mock_b64_decode.return_value = b64_decoded + + with self.assertRaises(Exception) as error: + self.db_base._decrypt_value(encyrpted_value, schema_version, salt) + self.assertEqual( + str(error.exception), + "database exception Incorrect data type: type(None), string is expected.", + ) + self.assertEqual(type(error.exception), DbException) + mock_join_secret_key.assert_called_once_with(salt) + _call_mock_aes_new = mock_aes.new.call_args_list[0].args + self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB) + mock_b64_decode.assert_called_once_with(encyrpted_value) + self.mock_cipher.decrypt.assert_called_once_with(b64_decoded) + mock_decrypted_message.decode.assert_called_once_with( + self.db_base.encoding_type + ) + mock_unpad_data.assert_called_once_with(None) + + @patch.object(DbBase, "get_secret_key") + @patch.object(DbBase, "_decrypt_value") + def test_decrypt_without_schema_version_without_salt( + self, mock_decrypt_value, mock_get_secret_key + ): + """schema_version is None, salt is None.""" + mock_decrypt_value.return_value = encyrpted_value + result = self.db_base.decrypt(value) + mock_decrypt_value.assert_called_once_with(value, None, None) + mock_get_secret_key.assert_called_once() + self.assertEqual(result, encyrpted_value) + + @patch.object(DbBase, "get_secret_key") + @patch.object(DbBase, "_decrypt_value") + def test_decrypt_with_schema_version_with_salt( + self, mock_decrypt_value, mock_get_secret_key + ): + """schema_version and salt exist.""" + mock_decrypt_value.return_value = encyrpted_value + result = self.db_base.decrypt(value, schema_version, salt) + mock_decrypt_value.assert_called_once_with(value, schema_version, salt) + mock_get_secret_key.assert_called_once() + self.assertEqual(result, encyrpted_value) + + @patch.object(DbBase, "get_secret_key") + @patch.object(DbBase, "_decrypt_value") + def test_decrypt_get_secret_key_raises( + self, mock_decrypt_value, mock_get_secret_key + ): + """Method get_secret_key raises KeyError.""" + mock_get_secret_key.side_effect = DbException("KeyError") + with self.assertRaises(Exception) as error: + self.db_base.decrypt(value) + self.assertEqual(str(error.exception), "database exception KeyError") + mock_decrypt_value.assert_not_called() + mock_get_secret_key.assert_called_once() + + @patch.object(DbBase, "get_secret_key") + @patch.object(DbBase, "_decrypt_value") + def test_decrypt_decrypt_value_raises( + self, mock_decrypt_value, mock_get_secret_key + ): + """Method _decrypt raises error.""" + mock_decrypt_value.side_effect = DbException( + "Incorrect data type: type(None), string is expected." + ) + with self.assertRaises(Exception) as error: + self.db_base.decrypt(value, schema_version, salt) + self.assertEqual( + str(error.exception), + "database exception Incorrect data type: type(None), string is expected.", + ) + mock_decrypt_value.assert_called_once_with(value, schema_version, salt) + mock_get_secret_key.assert_called_once() + + def test_encrypt_decrypt_with_schema_version_1_1_with_salt(self): + """Encrypt and decrypt with schema version 1.1, salt exists.""" + encrypted_msg = self.db_base.encrypt(value, schema_version, salt) + decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt) + self.assertEqual(value, decrypted_msg) + + def test_encrypt_decrypt_with_schema_version_1_0_with_salt(self): + """Encrypt and decrypt with schema version 1.0, salt exists.""" + schema_version = "1.0" + encrypted_msg = self.db_base.encrypt(value, schema_version, salt) + decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt) + self.assertEqual(value, decrypted_msg) + + def test_encrypt_decrypt_with_schema_version_1_1_without_salt(self): + """Encrypt and decrypt with schema version 1.1 and without salt.""" + salt = None + encrypted_msg = self.db_base.encrypt(value, schema_version, salt) + decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt) + self.assertEqual(value, decrypted_msg) + + +class TestAsyncEncryption(unittest.TestCase): + @patch("logging.getLogger", autospec=True) + def setUp(self, mock_logger): + mock_logger = logging.getLogger() + mock_logger.disabled = True + self.loop = asyncio.get_event_loop() + self.encryption = Encryption(uri="uri", config={}) + self.encryption.encoding_type = encoding_type + self.encryption.encrypt_mode = encyrpt_mode + self.encryption._secret_key = secret_key + self.admin_collection = Mock() + self.admin_collection.find_one = AsyncMock() + self.encryption._client = { + "osm": { + "admin": self.admin_collection, + } + } + + @patch.object(Encryption, "decrypt", new_callable=AsyncMock) + def test_decrypt_fields_with_item_with_fields(self, mock_decrypt): + """item and fields exist.""" + mock_decrypt.side_effect = [decrypted_val1, decrypted_val2] + input_item = copy.deepcopy(item) + expected_item = { + "secret": decrypted_val1, + "cacert": decrypted_val2, + "path": "/var", + "ip": "192.168.12.23", + } + fields = ["secret", "cacert"] + self.loop.run_until_complete( + self.encryption.decrypt_fields(input_item, fields, schema_version, salt) + ) + self.assertEqual(input_item, expected_item) + _call_mock_decrypt = mock_decrypt.call_args_list + self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt)) + self.assertEqual(_call_mock_decrypt[1].args, ("mycacert", "1.1", salt)) + + @patch.object(Encryption, "decrypt", new_callable=AsyncMock) + def test_decrypt_fields_empty_item_with_fields(self, mock_decrypt): + """item is empty and fields exists.""" + input_item = {} + fields = ["secret", "cacert"] + self.loop.run_until_complete( + self.encryption.decrypt_fields(input_item, fields, schema_version, salt) + ) + self.assertEqual(input_item, {}) + mock_decrypt.assert_not_called() + + @patch.object(Encryption, "decrypt", new_callable=AsyncMock) + def test_decrypt_fields_with_item_without_fields(self, mock_decrypt): + """item exists and fields is empty.""" + input_item = copy.deepcopy(item) + fields = [] + self.loop.run_until_complete( + self.encryption.decrypt_fields(input_item, fields, schema_version, salt) + ) + self.assertEqual(input_item, item) + mock_decrypt.assert_not_called() + + @patch.object(Encryption, "decrypt", new_callable=AsyncMock) + def test_decrypt_fields_with_item_with_single_field(self, mock_decrypt): + """item exists and field has single value.""" + mock_decrypt.return_value = decrypted_val1 + fields = ["secret"] + input_item = copy.deepcopy(item) + expected_item = { + "secret": decrypted_val1, + "cacert": "mycacert", + "path": "/var", + "ip": "192.168.12.23", + } + self.loop.run_until_complete( + self.encryption.decrypt_fields(input_item, fields, schema_version, salt) + ) + self.assertEqual(input_item, expected_item) + _call_mock_decrypt = mock_decrypt.call_args_list + self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt)) + + @patch.object(Encryption, "decrypt", new_callable=AsyncMock) + def test_decrypt_fields_with_item_with_field_none_salt_1_0_schema_version( + self, mock_decrypt + ): + """item exists and field has single value, salt is None, schema version is 1.0.""" + schema_version = "1.0" + salt = None + mock_decrypt.return_value = "mysecret" + input_item = copy.deepcopy(item) + fields = ["secret"] + self.loop.run_until_complete( + self.encryption.decrypt_fields(input_item, fields, schema_version, salt) + ) + self.assertEqual(input_item, item) + _call_mock_decrypt = mock_decrypt.call_args_list + self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.0", None)) + + @patch.object(Encryption, "decrypt", new_callable=AsyncMock) + def test_decrypt_fields_decrypt_raises(self, mock_decrypt): + """Method decrypt raises error.""" + mock_decrypt.side_effect = DbException( + "Incorrect data type: type(None), string is expected." + ) + fields = ["secret"] + input_item = copy.deepcopy(item) + with self.assertRaises(Exception) as error: + self.loop.run_until_complete( + self.encryption.decrypt_fields(input_item, fields, schema_version, salt) + ) + self.assertEqual( + str(error.exception), + "database exception Incorrect data type: type(None), string is expected.", + ) + self.assertEqual(input_item, item) + _call_mock_decrypt = mock_decrypt.call_args_list + self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt)) + + @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock) + @patch.object(Encryption, "_encrypt_value") + def test_encrypt(self, mock_encrypt_value, mock_get_secret_key): + """Method decrypt raises error.""" + mock_encrypt_value.return_value = encyrpted_value + result = self.loop.run_until_complete( + self.encryption.encrypt(value, schema_version, salt) + ) + self.assertEqual(result, encyrpted_value) + mock_get_secret_key.assert_called_once() + mock_encrypt_value.assert_called_once_with(value, schema_version, salt) + + @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock) + @patch.object(Encryption, "_encrypt_value") + def test_encrypt_get_secret_key_raises( + self, mock_encrypt_value, mock_get_secret_key + ): + """Method get_secret_key raises error.""" + mock_get_secret_key.side_effect = DbException("Unexpected type.") + with self.assertRaises(Exception) as error: + self.loop.run_until_complete( + self.encryption.encrypt(value, schema_version, salt) + ) + self.assertEqual(str(error.exception), "database exception Unexpected type.") + mock_get_secret_key.assert_called_once() + mock_encrypt_value.assert_not_called() + + @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock) + @patch.object(Encryption, "_encrypt_value") + def test_encrypt_get_encrypt_raises(self, mock_encrypt_value, mock_get_secret_key): + """Method _encrypt raises error.""" + mock_encrypt_value.side_effect = TypeError( + "A bytes-like object is required, not 'str'" + ) + with self.assertRaises(Exception) as error: + self.loop.run_until_complete( + self.encryption.encrypt(value, schema_version, salt) + ) + self.assertEqual( + str(error.exception), "A bytes-like object is required, not 'str'" + ) + mock_get_secret_key.assert_called_once() + mock_encrypt_value.assert_called_once_with(value, schema_version, salt) + + @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock) + @patch.object(Encryption, "_decrypt_value") + def test_decrypt(self, mock_decrypt_value, mock_get_secret_key): + """Decrypted successfully.""" + mock_decrypt_value.return_value = value + result = self.loop.run_until_complete( + self.encryption.decrypt(encyrpted_value, schema_version, salt) + ) + self.assertEqual(result, value) + mock_get_secret_key.assert_called_once() + mock_decrypt_value.assert_called_once_with( + encyrpted_value, schema_version, salt + ) + + @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock) + @patch.object(Encryption, "_decrypt_value") + def test_decrypt_get_secret_key_raises( + self, mock_decrypt_value, mock_get_secret_key + ): + """Method get_secret_key raises error.""" + mock_get_secret_key.side_effect = DbException("Unexpected type.") + with self.assertRaises(Exception) as error: + self.loop.run_until_complete( + self.encryption.decrypt(encyrpted_value, schema_version, salt) + ) + self.assertEqual(str(error.exception), "database exception Unexpected type.") + mock_get_secret_key.assert_called_once() + mock_decrypt_value.assert_not_called() + + @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock) + @patch.object(Encryption, "_decrypt_value") + def test_decrypt_decrypt_value_raises( + self, mock_decrypt_value, mock_get_secret_key + ): + """Method get_secret_key raises error.""" + mock_decrypt_value.side_effect = TypeError( + "A bytes-like object is required, not 'str'" + ) + with self.assertRaises(Exception) as error: + self.loop.run_until_complete( + self.encryption.decrypt(encyrpted_value, schema_version, salt) + ) + self.assertEqual( + str(error.exception), "A bytes-like object is required, not 'str'" + ) + mock_get_secret_key.assert_called_once() + mock_decrypt_value.assert_called_once_with( + encyrpted_value, schema_version, salt + ) + + def test_join_keys_string_key(self): + """key is string.""" + string_key = "sample key" + result = self.encryption._join_keys(string_key, secret_key) + self.assertEqual(result, joined_key) + self.assertTrue(isinstance(result, bytes)) + + def test_join_keys_bytes_key(self): + """key is bytes.""" + bytes_key = b"sample key" + result = self.encryption._join_keys(bytes_key, secret_key) + self.assertEqual(result, joined_key) + self.assertTrue(isinstance(result, bytes)) + self.assertEqual(len(result.decode("unicode_escape")), 32) + + def test_join_keys_int_key(self): + """key is int.""" + int_key = 923 + with self.assertRaises(Exception) as error: + self.encryption._join_keys(int_key, None) + self.assertEqual(str(error.exception), "'int' object is not iterable") + + def test_join_keys_none_secret_key(self): + """key is as bytes and secret key is None.""" + bytes_key = b"sample key" + result = self.encryption._join_keys(bytes_key, None) + self.assertEqual( + result, + b"sample key\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ) + self.assertTrue(isinstance(result, bytes)) + self.assertEqual(len(result.decode("unicode_escape")), 32) + + def test_join_keys_none_key_none_secret_key(self): + """key is None and secret key is None.""" + with self.assertRaises(Exception) as error: + self.encryption._join_keys(None, None) + self.assertEqual(str(error.exception), "'NoneType' object is not iterable") + + def test_join_keys_none_key(self): + """key is None and secret key exists.""" + with self.assertRaises(Exception) as error: + self.encryption._join_keys(None, secret_key) + self.assertEqual(str(error.exception), "'NoneType' object is not iterable") + + @patch.object(Encryption, "_join_keys") + def test_join_secret_key_string_sample_key(self, mock_join_keys): + """key is None and secret key exists as string.""" + update_key = "sample key" + mock_join_keys.return_value = joined_key + result = self.encryption._join_secret_key(update_key) + self.assertEqual(result, joined_key) + self.assertTrue(isinstance(result, bytes)) + mock_join_keys.assert_called_once_with(update_key, secret_key) + + @patch.object(Encryption, "_join_keys") + def test_join_secret_key_byte_sample_key(self, mock_join_keys): + """key is None and secret key exists as bytes.""" + update_key = b"sample key" + mock_join_keys.return_value = joined_key + result = self.encryption._join_secret_key(update_key) + self.assertEqual(result, joined_key) + self.assertTrue(isinstance(result, bytes)) + mock_join_keys.assert_called_once_with(update_key, secret_key) + + @patch.object(Encryption, "_join_keys") + def test_join_secret_key_join_keys_raises(self, mock_join_keys): + """Method _join_secret_key raises.""" + update_key = 3434 + mock_join_keys.side_effect = TypeError("'int' object is not iterable") + with self.assertRaises(Exception) as error: + self.encryption._join_secret_key(update_key) + self.assertEqual(str(error.exception), "'int' object is not iterable") + mock_join_keys.assert_called_once_with(update_key, secret_key) + + @patch.object(Encryption, "_join_keys") + def test_get_secret_key_exists(self, mock_join_keys): + """secret_key exists.""" + self.encryption._secret_key = secret_key + self.loop.run_until_complete(self.encryption.get_secret_key()) + self.assertEqual(self.encryption.secret_key, secret_key) + mock_join_keys.assert_not_called() + + @patch.object(Encryption, "_join_keys") + @patch("osm_common.dbbase.b64decode") + def test_get_secret_key_not_exist_database_key_exist( + self, mock_b64decode, mock_join_keys + ): + """secret_key does not exist, database key exists.""" + self.encryption._secret_key = None + self.encryption._admin_collection.find_one.return_value = None + self.encryption._config = {"database_commonkey": "osm_new_key"} + mock_join_keys.return_value = joined_key + self.loop.run_until_complete(self.encryption.get_secret_key()) + self.assertEqual(self.encryption.secret_key, joined_key) + self.assertEqual(mock_join_keys.call_count, 1) + mock_b64decode.assert_not_called() + + @patch.object(Encryption, "_join_keys") + @patch("osm_common.dbbase.b64decode") + def test_get_secret_key_not_exist_with_database_key_version_data_exist_without_serial( + self, mock_b64decode, mock_join_keys + ): + """secret_key does not exist, database key exists.""" + self.encryption._secret_key = None + self.encryption._admin_collection.find_one.return_value = {"version": "1.0"} + self.encryption._config = {"database_commonkey": "osm_new_key"} + mock_join_keys.return_value = joined_key + self.loop.run_until_complete(self.encryption.get_secret_key()) + self.assertEqual(self.encryption.secret_key, joined_key) + self.assertEqual(mock_join_keys.call_count, 1) + mock_b64decode.assert_not_called() + self.encryption._admin_collection.find_one.assert_called_once_with( + {"_id": "version"} + ) + _call_mock_join_keys = mock_join_keys.call_args_list + self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None)) + + @patch.object(Encryption, "_join_keys") + @patch("osm_common.dbbase.b64decode") + def test_get_secret_key_not_exist_with_database_key_version_data_exist_with_serial( + self, mock_b64decode, mock_join_keys + ): + """secret_key does not exist, database key exists, version and serial exist + in admin collection.""" + self.encryption._secret_key = None + self.encryption._admin_collection.find_one.return_value = { + "version": "1.0", + "serial": serial_bytes, + } + self.encryption._config = {"database_commonkey": "osm_new_key"} + mock_join_keys.side_effect = [secret_key, joined_key] + mock_b64decode.return_value = base64_decoded_serial + self.loop.run_until_complete(self.encryption.get_secret_key()) + self.assertEqual(self.encryption.secret_key, joined_key) + self.assertEqual(mock_join_keys.call_count, 2) + mock_b64decode.assert_called_once_with(serial_bytes) + self.encryption._admin_collection.find_one.assert_called_once_with( + {"_id": "version"} + ) + _call_mock_join_keys = mock_join_keys.call_args_list + self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None)) + self.assertEqual( + _call_mock_join_keys[1].args, (base64_decoded_serial, secret_key) + ) + + @patch.object(Encryption, "_join_keys") + @patch("osm_common.dbbase.b64decode") + def test_get_secret_key_join_keys_raises(self, mock_b64decode, mock_join_keys): + """Method _join_keys raises.""" + self.encryption._secret_key = None + self.encryption._admin_collection.find_one.return_value = { + "version": "1.0", + "serial": serial_bytes, + } + self.encryption._config = {"database_commonkey": "osm_new_key"} + mock_join_keys.side_effect = DbException("Invalid data type.") + with self.assertRaises(Exception) as error: + self.loop.run_until_complete(self.encryption.get_secret_key()) + self.assertEqual(str(error.exception), "database exception Invalid data type.") + self.assertEqual(mock_join_keys.call_count, 1) + check_if_assert_not_called( + [mock_b64decode, self.encryption._admin_collection.find_one] + ) + _call_mock_join_keys = mock_join_keys.call_args_list + self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None)) + + @patch.object(Encryption, "_join_keys") + @patch("osm_common.dbbase.b64decode") + def test_get_secret_key_b64decode_raises(self, mock_b64decode, mock_join_keys): + """Method b64decode raises.""" + self.encryption._secret_key = None + self.encryption._admin_collection.find_one.return_value = { + "version": "1.0", + "serial": base64_decoded_serial, + } + self.encryption._config = {"database_commonkey": "osm_new_key"} + mock_join_keys.return_value = secret_key + mock_b64decode.side_effect = TypeError( + "A bytes-like object is required, not 'str'" + ) + with self.assertRaises(Exception) as error: + self.loop.run_until_complete(self.encryption.get_secret_key()) + self.assertEqual( + str(error.exception), "A bytes-like object is required, not 'str'" + ) + self.assertEqual(self.encryption.secret_key, None) + self.assertEqual(mock_join_keys.call_count, 1) + mock_b64decode.assert_called_once_with(base64_decoded_serial) + self.encryption._admin_collection.find_one.assert_called_once_with( + {"_id": "version"} + ) + _call_mock_join_keys = mock_join_keys.call_args_list + self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None)) + + @patch.object(Encryption, "_join_keys") + @patch("osm_common.dbbase.b64decode") + def test_get_secret_key_admin_collection_find_one_raises( + self, mock_b64decode, mock_join_keys + ): + """admin_collection find_one raises.""" + self.encryption._secret_key = None + self.encryption._admin_collection.find_one.side_effect = DbException( + "Connection failed." + ) + self.encryption._config = {"database_commonkey": "osm_new_key"} + mock_join_keys.return_value = secret_key + with self.assertRaises(Exception) as error: + self.loop.run_until_complete(self.encryption.get_secret_key()) + self.assertEqual(str(error.exception), "database exception Connection failed.") + self.assertEqual(self.encryption.secret_key, None) + self.assertEqual(mock_join_keys.call_count, 1) + mock_b64decode.assert_not_called() + self.encryption._admin_collection.find_one.assert_called_once_with( + {"_id": "version"} + ) + _call_mock_join_keys = mock_join_keys.call_args_list + self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None)) + + def test_encrypt_decrypt_with_schema_version_1_1_with_salt(self): + """Encrypt and decrypt with schema version 1.1, salt exists.""" + encrypted_msg = self.loop.run_until_complete( + self.encryption.encrypt(value, schema_version, salt) + ) + decrypted_msg = self.loop.run_until_complete( + self.encryption.decrypt(encrypted_msg, schema_version, salt) + ) + self.assertEqual(value, decrypted_msg) + + def test_encrypt_decrypt_with_schema_version_1_0_with_salt(self): + """Encrypt and decrypt with schema version 1.0, salt exists.""" + schema_version = "1.0" + encrypted_msg = self.loop.run_until_complete( + self.encryption.encrypt(value, schema_version, salt) + ) + decrypted_msg = self.loop.run_until_complete( + self.encryption.decrypt(encrypted_msg, schema_version, salt) + ) + self.assertEqual(value, decrypted_msg) + + def test_encrypt_decrypt_with_schema_version_1_1_without_salt(self): + """Encrypt and decrypt with schema version 1.1, without salt.""" + salt = None + with self.assertRaises(Exception) as error: + self.loop.run_until_complete( + self.encryption.encrypt(value, schema_version, salt) + ) + self.assertEqual(str(error.exception), "'NoneType' object is not iterable") + + class TestDeepUpdate(unittest.TestCase): def test_update_dict(self): # Original, patch, expected result diff --git a/releasenotes/notes/use_pycryptpdome-8ef05275b779994b.yaml b/releasenotes/notes/use_pycryptpdome-8ef05275b779994b.yaml new file mode 100644 index 0000000..94d5437 --- /dev/null +++ b/releasenotes/notes/use_pycryptpdome-8ef05275b779994b.yaml @@ -0,0 +1,21 @@ +####################################################################################### +# Copyright ETSI Contributors and Others. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. +####################################################################################### +--- +features: + - | + Feature 10950: Remove the pycrypto library and change encrypt and decrypt methods to work with pycryptodome. + Move encryption methods from N2VC to common. diff --git a/requirements.in b/requirements.in index b8e0f2e..21033b8 100644 --- a/requirements.in +++ b/requirements.in @@ -16,5 +16,6 @@ pymongo<4 aiokafka pyyaml==5.4.1 -pycrypto -dataclasses \ No newline at end of file +pycryptodome +dataclasses +motor==1.3.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 5995b6a..fd903ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,11 +22,15 @@ dataclasses==0.6 # via -r requirements.in kafka-python==2.0.2 # via aiokafka +motor==1.3.1 + # via -r requirements.in packaging==23.0 # via aiokafka -pycrypto==2.6.1 +pycryptodome==3.17 # via -r requirements.in pymongo==3.13.0 - # via -r requirements.in + # via + # -r requirements.in + # motor pyyaml==5.4.1 # via -r requirements.in -- 2.25.1