# See the License for the specific language governing permissions and
# limitations under the License.
-import yaml
+from base64 import b64decode, b64encode
+from copy import deepcopy
+from http import HTTPStatus
import logging
import re
-from http import HTTPStatus
-from copy import deepcopy
+from threading import Lock
+import typing
+
+
from Crypto.Cipher import AES
-from base64 import b64decode, b64encode
+from motor.motor_asyncio import AsyncIOMotorClient
from osm_common.common_utils import FakeLock
-from threading import Lock
+import yaml
__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+DB_NAME = "osm"
+
+
class DbException(Exception):
def __init__(self, message, http_code=HTTPStatus.NOT_FOUND):
self.http_code = http_code
class DbBase(object):
- def __init__(self, logger_name="db", lock=False):
+ def __init__(self, encoding_type="ascii", logger_name="db", lock=False):
"""
Constructor of dbBase
:param logger_name: logging name
"""
self.logger = logging.getLogger(logger_name)
self.secret_key = None # 32 bytes length array used for encrypt/decrypt
+ self.encrypt_mode = AES.MODE_ECB
+ self.encoding_type = encoding_type
if not lock:
self.lock = FakeLock()
elif lock is True:
"""
pass
- def encrypt(self, value, schema_version=None, salt=None):
- """
- Encrypt a value
- :param value: value to be encrypted. It is string/unicode
- :param schema_version: used for version control. If None or '1.0' no encryption is done.
- If '1.1' symmetric AES encryption is done
- :param salt: optional salt to be used. Must be str
- :return: Encrypted content of value
+ @staticmethod
+ def pad_data(value: str) -> str:
+ if not isinstance(value, str):
+ raise DbException(
+ f"Incorrect data type: type({value}), string is expected."
+ )
+ return value + ("\0" * ((16 - len(value)) % 16))
+
+ @staticmethod
+ def unpad_data(value: str) -> str:
+ if not isinstance(value, str):
+ raise DbException(
+ f"Incorrect data type: type({value}), string is expected."
+ )
+ return value.rstrip("\0")
+
+ def _encrypt_value(self, value: str, schema_version: str, salt: str):
+ """Encrypt a value.
+
+ Args:
+ value (str): value to be encrypted. It is string/unicode
+ schema_version (str): used for version control. If None or '1.0' no encryption is done.
+ If '1.1' symmetric AES encryption is done
+ salt (str): optional salt to be used. Must be str
+
+ Returns:
+ Encrypted content of value (str)
+
"""
- self.get_secret_key()
if not self.secret_key or not schema_version or schema_version == "1.0":
return value
+
else:
+ # Secret key as bytes
secret_key = self._join_secret_key(salt)
- cipher = AES.new(secret_key)
- padded_private_msg = value + ("\0" * ((16 - len(value)) % 16))
- encrypted_msg = cipher.encrypt(padded_private_msg)
+ cipher = AES.new(secret_key, self.encrypt_mode)
+ # Padded data as string
+ padded_private_msg = self.pad_data(value)
+ # Padded data as bytes
+ padded_private_msg_bytes = padded_private_msg.encode(self.encoding_type)
+ # Encrypt padded data
+ encrypted_msg = cipher.encrypt(padded_private_msg_bytes)
+ # Base64 encoded encrypted data
encoded_encrypted_msg = b64encode(encrypted_msg)
- return encoded_encrypted_msg.decode("ascii")
+ # Converting to string
+ return encoded_encrypted_msg.decode(self.encoding_type)
+
+ def encrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
+ """Encrypt a value.
+
+ Args:
+ value (str): value to be encrypted. It is string/unicode
+ schema_version (str): used for version control. If None or '1.0' no encryption is done.
+ If '1.1' symmetric AES encryption is done
+ salt (str): optional salt to be used. Must be str
+
+ Returns:
+ Encrypted content of value (str)
- def decrypt(self, value, schema_version=None, salt=None):
- """
- Decrypt an encrypted value
- :param value: value to be decrypted. It is a base64 string
- :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
- If '1.1' symmetric AES encryption has been done
- :param salt: optional salt to be used
- :return: Plain content of value
"""
self.get_secret_key()
+ return self._encrypt_value(value, schema_version, salt)
+
+ def _decrypt_value(self, value: str, schema_version: str, salt: str) -> str:
+ """Decrypt an encrypted value.
+ Args:
+
+ value (str): value to be decrypted. It is a base64 string
+ schema_version (str): used for known encryption method used.
+ If None or '1.0' no encryption has been done.
+ If '1.1' symmetric AES encryption has been done
+ salt (str): optional salt to be used
+
+ Returns:
+ Plain content of value (str)
+
+ """
if not self.secret_key or not schema_version or schema_version == "1.0":
return value
+
else:
secret_key = self._join_secret_key(salt)
+ # Decoding encrypted data, output bytes
encrypted_msg = b64decode(value)
- cipher = AES.new(secret_key)
+ cipher = AES.new(secret_key, self.encrypt_mode)
+ # Decrypted data, output bytes
decrypted_msg = cipher.decrypt(encrypted_msg)
try:
- unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
+ # Converting to string
+ private_msg = decrypted_msg.decode(self.encoding_type)
except UnicodeDecodeError:
raise DbException(
"Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
- return unpadded_private_msg
+ # Unpadded data as string
+ return self.unpad_data(private_msg)
+
+ def decrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
+ """Decrypt an encrypted value.
+ Args:
+
+ value (str): value to be decrypted. It is a base64 string
+ schema_version (str): used for known encryption method used.
+ If None or '1.0' no encryption has been done.
+ If '1.1' symmetric AES encryption has been done
+ salt (str): optional salt to be used
+
+ Returns:
+ Plain content of value (str)
+
+ """
+ self.get_secret_key()
+ return self._decrypt_value(value, schema_version, salt)
def encrypt_decrypt_fields(
self, item, action, fields=None, flags=None, schema_version=None, salt=None
def deep_update(dict_to_change, dict_reference):
"""Maintained for backward compatibility. Use deep_update_rfc7396 instead"""
return deep_update_rfc7396(dict_to_change, dict_reference)
+
+
+class Encryption(DbBase):
+ def __init__(self, uri, config, encoding_type="ascii", logger_name="db"):
+ """Constructor.
+
+ Args:
+ uri (str): Connection string to connect to the database.
+ config (dict): Additional database info
+ encoding_type (str): ascii, utf-8 etc.
+ logger_name (str): Logger name
+
+ """
+ self._secret_key = None # 32 bytes length array used for encrypt/decrypt
+ self.encrypt_mode = AES.MODE_ECB
+ super(Encryption, self).__init__(
+ encoding_type=encoding_type, logger_name=logger_name
+ )
+ self._client = AsyncIOMotorClient(uri)
+ self._config = config
+
+ @property
+ def secret_key(self):
+ return self._secret_key
+
+ @secret_key.setter
+ def secret_key(self, value):
+ self._secret_key = value
+
+ @property
+ def _database(self):
+ return self._client[DB_NAME]
+
+ @property
+ def _admin_collection(self):
+ return self._database["admin"]
+
+ @property
+ def database_key(self):
+ return self._config.get("database_commonkey")
+
+ async def decrypt_fields(
+ self,
+ item: dict,
+ fields: typing.List[str],
+ schema_version: str = None,
+ salt: str = None,
+ ) -> None:
+ """Decrypt fields from a dictionary. Follows the same logic as in osm_common.
+
+ Args:
+
+ item (dict): Dictionary with the keys to be decrypted
+ fields (list): List of keys to decrypt
+ schema version (str): Schema version. (i.e. 1.11)
+ salt (str): Salt for the decryption
+
+ """
+ flags = re.I
+
+ async def process(_item):
+ if isinstance(_item, list):
+ for elem in _item:
+ await process(elem)
+ elif isinstance(_item, dict):
+ for key, val in _item.items():
+ if isinstance(val, str):
+ if any(re.search(f, key, flags) for f in fields):
+ _item[key] = await self.decrypt(val, schema_version, salt)
+ else:
+ await process(val)
+
+ await process(item)
+
+ async def encrypt(
+ self, value: str, schema_version: str = None, salt: str = None
+ ) -> str:
+ """Encrypt a value.
+
+ Args:
+ value (str): value to be encrypted. It is string/unicode
+ schema_version (str): used for version control. If None or '1.0' no encryption is done.
+ If '1.1' symmetric AES encryption is done
+ salt (str): optional salt to be used. Must be str
+
+ Returns:
+ Encrypted content of value (str)
+
+ """
+ await self.get_secret_key()
+ return self._encrypt_value(value, schema_version, salt)
+
+ async def decrypt(
+ self, value: str, schema_version: str = None, salt: str = None
+ ) -> str:
+ """Decrypt an encrypted value.
+ Args:
+
+ value (str): value to be decrypted. It is a base64 string
+ schema_version (str): used for known encryption method used.
+ If None or '1.0' no encryption has been done.
+ If '1.1' symmetric AES encryption has been done
+ salt (str): optional salt to be used
+
+ Returns:
+ Plain content of value (str)
+
+ """
+ await self.get_secret_key()
+ return self._decrypt_value(value, schema_version, salt)
+
+ def _join_secret_key(self, update_key: typing.Any) -> bytes:
+ """Join key with secret key.
+
+ Args:
+
+ update_key (str or bytes): str or bytes with the to update
+
+ Returns:
+
+ Joined key (bytes)
+ """
+ return self._join_keys(update_key, self.secret_key)
+
+ def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
+ """Join key with secret_key.
+
+ Args:
+
+ key (str or bytes): str or bytes of the key to update
+ secret_key (bytes): bytes of the secret key
+
+ Returns:
+
+ Joined key (bytes)
+ """
+ if isinstance(key, str):
+ update_key_bytes = key.encode(self.encoding_type)
+ else:
+ update_key_bytes = key
+ new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
+ for i, b in enumerate(update_key_bytes):
+ new_secret_key[i % 32] ^= b
+ return bytes(new_secret_key)
+
+ async def get_secret_key(self):
+ """Get secret key using the database key and the serial key in the DB.
+ The key is populated in the property self.secret_key.
+ """
+ if self.secret_key:
+ return
+ secret_key = None
+ if self.database_key:
+ secret_key = self._join_keys(self.database_key, None)
+ version_data = await self._admin_collection.find_one({"_id": "version"})
+ if version_data and version_data.get("serial"):
+ secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
+ self._secret_key = secret_key