Feature 10950 Replace pycrypto with pycryptodome
[osm/common.git] / osm_common / dbbase.py
index 4021805..6b3a89a 100644 (file)
 # implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import asyncio
 from base64 import b64decode, b64encode
 from copy import deepcopy
 from http import HTTPStatus
 import logging
 import re
 from threading import Lock
 from base64 import b64decode, b64encode
 from copy import deepcopy
 from http import HTTPStatus
 import logging
 import re
 from threading import Lock
+import typing
+
 
 from Crypto.Cipher import AES
 
 from Crypto.Cipher import AES
+from motor.motor_asyncio import AsyncIOMotorClient
 from osm_common.common_utils import FakeLock
 import yaml
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
 
 from osm_common.common_utils import FakeLock
 import yaml
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
 
+DB_NAME = "osm"
+
+
 class DbException(Exception):
     def __init__(self, message, http_code=HTTPStatus.NOT_FOUND):
         self.http_code = http_code
 class DbException(Exception):
     def __init__(self, message, http_code=HTTPStatus.NOT_FOUND):
         self.http_code = http_code
@@ -36,7 +42,7 @@ class DbException(Exception):
 
 
 class DbBase(object):
 
 
 class DbBase(object):
-    def __init__(self, logger_name="db", lock=False):
+    def __init__(self, encoding_type="ascii", logger_name="db", lock=False):
         """
         Constructor of dbBase
         :param logger_name: logging name
         """
         Constructor of dbBase
         :param logger_name: logging name
@@ -47,6 +53,8 @@ class DbBase(object):
         """
         self.logger = logging.getLogger(logger_name)
         self.secret_key = None  # 32 bytes length array used for encrypt/decrypt
         """
         self.logger = logging.getLogger(logger_name)
         self.secret_key = None  # 32 bytes length array used for encrypt/decrypt
+        self.encrypt_mode = AES.MODE_ECB
+        self.encoding_type = encoding_type
         if not lock:
             self.lock = FakeLock()
         elif lock is True:
         if not lock:
             self.lock = FakeLock()
         elif lock is True:
@@ -263,51 +271,120 @@ class DbBase(object):
         """
         pass
 
         """
         pass
 
-    def encrypt(self, value, schema_version=None, salt=None):
-        """
-        Encrypt a value
-        :param value: value to be encrypted. It is string/unicode
-        :param schema_version: used for version control. If None or '1.0' no encryption is done.
-               If '1.1' symmetric AES encryption is done
-        :param salt: optional salt to be used. Must be str
-        :return: Encrypted content of value
+    @staticmethod
+    def pad_data(value: str) -> str:
+        if not isinstance(value, str):
+            raise DbException(
+                f"Incorrect data type: type({value}), string is expected."
+            )
+        return value + ("\0" * ((16 - len(value)) % 16))
+
+    @staticmethod
+    def unpad_data(value: str) -> str:
+        if not isinstance(value, str):
+            raise DbException(
+                f"Incorrect data type: type({value}), string is expected."
+            )
+        return value.rstrip("\0")
+
+    def _encrypt_value(self, value: str, schema_version: str, salt: str):
+        """Encrypt a value.
+
+        Args:
+            value   (str):                  value to be encrypted. It is string/unicode
+            schema_version  (str):          used for version control. If None or '1.0' no encryption is done.
+                                            If '1.1' symmetric AES encryption is done
+            salt    (str):                  optional salt to be used. Must be str
+
+        Returns:
+            Encrypted content of value  (str)
+
         """
         """
-        self.get_secret_key()
         if not self.secret_key or not schema_version or schema_version == "1.0":
             return value
         if not self.secret_key or not schema_version or schema_version == "1.0":
             return value
+
         else:
         else:
+            # Secret key as bytes
             secret_key = self._join_secret_key(salt)
             secret_key = self._join_secret_key(salt)
-            cipher = AES.new(secret_key)
-            padded_private_msg = value + ("\0" * ((16 - len(value)) % 16))
-            encrypted_msg = cipher.encrypt(padded_private_msg)
+            cipher = AES.new(secret_key, self.encrypt_mode)
+            # Padded data as string
+            padded_private_msg = self.pad_data(value)
+            # Padded data as bytes
+            padded_private_msg_bytes = padded_private_msg.encode(self.encoding_type)
+            # Encrypt padded data
+            encrypted_msg = cipher.encrypt(padded_private_msg_bytes)
+            # Base64 encoded encrypted data
             encoded_encrypted_msg = b64encode(encrypted_msg)
             encoded_encrypted_msg = b64encode(encrypted_msg)
-            return encoded_encrypted_msg.decode("ascii")
+            # Converting to string
+            return encoded_encrypted_msg.decode(self.encoding_type)
+
+    def encrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
+        """Encrypt a value.
+
+        Args:
+            value   (str):                  value to be encrypted. It is string/unicode
+            schema_version  (str):          used for version control. If None or '1.0' no encryption is done.
+                                            If '1.1' symmetric AES encryption is done
+            salt    (str):                  optional salt to be used. Must be str
+
+        Returns:
+            Encrypted content of value  (str)
 
 
-    def decrypt(self, value, schema_version=None, salt=None):
-        """
-        Decrypt an encrypted value
-        :param value: value to be decrypted. It is a base64 string
-        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
-               If '1.1' symmetric AES encryption has been done
-        :param salt: optional salt to be used
-        :return: Plain content of value
         """
         self.get_secret_key()
         """
         self.get_secret_key()
+        return self._encrypt_value(value, schema_version, salt)
+
+    def _decrypt_value(self, value: str, schema_version: str, salt: str) -> str:
+        """Decrypt an encrypted value.
+        Args:
+
+            value   (str):                      value to be decrypted. It is a base64 string
+            schema_version  (str):              used for known encryption method used.
+                                                If None or '1.0' no encryption has been done.
+                                                If '1.1' symmetric AES encryption has been done
+            salt    (str):                      optional salt to be used
+
+        Returns:
+            Plain content of value  (str)
+
+        """
         if not self.secret_key or not schema_version or schema_version == "1.0":
             return value
         if not self.secret_key or not schema_version or schema_version == "1.0":
             return value
+
         else:
             secret_key = self._join_secret_key(salt)
         else:
             secret_key = self._join_secret_key(salt)
+            # Decoding encrypted data, output bytes
             encrypted_msg = b64decode(value)
             encrypted_msg = b64decode(value)
-            cipher = AES.new(secret_key)
+            cipher = AES.new(secret_key, self.encrypt_mode)
+            # Decrypted data, output bytes
             decrypted_msg = cipher.decrypt(encrypted_msg)
             try:
             decrypted_msg = cipher.decrypt(encrypted_msg)
             try:
-                unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
+                # Converting to string
+                private_msg = decrypted_msg.decode(self.encoding_type)
             except UnicodeDecodeError:
                 raise DbException(
                     "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                     http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
             except UnicodeDecodeError:
                 raise DbException(
                     "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                     http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
-            return unpadded_private_msg
+            # Unpadded data as string
+            return self.unpad_data(private_msg)
+
+    def decrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
+        """Decrypt an encrypted value.
+        Args:
+
+            value   (str):                      value to be decrypted. It is a base64 string
+            schema_version  (str):              used for known encryption method used.
+                                                If None or '1.0' no encryption has been done.
+                                                If '1.1' symmetric AES encryption has been done
+            salt    (str):                      optional salt to be used
+
+        Returns:
+            Plain content of value  (str)
+
+        """
+        self.get_secret_key()
+        return self._decrypt_value(value, schema_version, salt)
 
     def encrypt_decrypt_fields(
         self, item, action, fields=None, flags=None, schema_version=None, salt=None
 
     def encrypt_decrypt_fields(
         self, item, action, fields=None, flags=None, schema_version=None, salt=None
@@ -593,3 +670,163 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
 def deep_update(dict_to_change, dict_reference):
     """Maintained for backward compatibility. Use deep_update_rfc7396 instead"""
     return deep_update_rfc7396(dict_to_change, dict_reference)
 def deep_update(dict_to_change, dict_reference):
     """Maintained for backward compatibility. Use deep_update_rfc7396 instead"""
     return deep_update_rfc7396(dict_to_change, dict_reference)
+
+
+class Encryption(DbBase):
+    def __init__(self, uri, config, encoding_type="ascii", loop=None, logger_name="db"):
+        """Constructor.
+
+        Args:
+            uri (str):              Connection string to connect to the database.
+            config  (dict):         Additional database info
+            encoding_type   (str):  ascii, utf-8 etc.
+            loop    (object):       Asyncio Loop
+            logger_name (str):      Logger name
+
+        """
+        self.loop = loop or asyncio.get_event_loop()
+        self._secret_key = None  # 32 bytes length array used for encrypt/decrypt
+        self.encrypt_mode = AES.MODE_ECB
+        super(Encryption, self).__init__(
+            encoding_type=encoding_type, logger_name=logger_name
+        )
+        self._client = AsyncIOMotorClient(uri)
+        self._config = config
+
+    @property
+    def secret_key(self):
+        return self._secret_key
+
+    @secret_key.setter
+    def secret_key(self, value):
+        self._secret_key = value
+
+    @property
+    def _database(self):
+        return self._client[DB_NAME]
+
+    @property
+    def _admin_collection(self):
+        return self._database["admin"]
+
+    @property
+    def database_key(self):
+        return self._config.get("database_commonkey")
+
+    async def decrypt_fields(
+        self,
+        item: dict,
+        fields: typing.List[str],
+        schema_version: str = None,
+        salt: str = None,
+    ) -> None:
+        """Decrypt fields from a dictionary. Follows the same logic as in osm_common.
+
+        Args:
+
+            item    (dict):         Dictionary with the keys to be decrypted
+            fields  (list):          List of keys to decrypt
+            schema version  (str):  Schema version. (i.e. 1.11)
+            salt    (str):          Salt for the decryption
+
+        """
+        flags = re.I
+
+        async def process(_item):
+            if isinstance(_item, list):
+                for elem in _item:
+                    await process(elem)
+            elif isinstance(_item, dict):
+                for key, val in _item.items():
+                    if isinstance(val, str):
+                        if any(re.search(f, key, flags) for f in fields):
+                            _item[key] = await self.decrypt(val, schema_version, salt)
+                    else:
+                        await process(val)
+
+        await process(item)
+
+    async def encrypt(
+        self, value: str, schema_version: str = None, salt: str = None
+    ) -> str:
+        """Encrypt a value.
+
+        Args:
+            value   (str):                  value to be encrypted. It is string/unicode
+            schema_version  (str):          used for version control. If None or '1.0' no encryption is done.
+                                            If '1.1' symmetric AES encryption is done
+            salt    (str):                  optional salt to be used. Must be str
+
+        Returns:
+            Encrypted content of value  (str)
+
+        """
+        await self.get_secret_key()
+        return self._encrypt_value(value, schema_version, salt)
+
+    async def decrypt(
+        self, value: str, schema_version: str = None, salt: str = None
+    ) -> str:
+        """Decrypt an encrypted value.
+        Args:
+
+            value   (str):                      value to be decrypted. It is a base64 string
+            schema_version  (str):              used for known encryption method used.
+                                                If None or '1.0' no encryption has been done.
+                                                If '1.1' symmetric AES encryption has been done
+            salt    (str):                      optional salt to be used
+
+        Returns:
+            Plain content of value  (str)
+
+        """
+        await self.get_secret_key()
+        return self._decrypt_value(value, schema_version, salt)
+
+    def _join_secret_key(self, update_key: typing.Any) -> bytes:
+        """Join key with secret key.
+
+        Args:
+
+        update_key  (str or bytes):    str or bytes with the to update
+
+        Returns:
+
+            Joined key  (bytes)
+        """
+        return self._join_keys(update_key, self.secret_key)
+
+    def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
+        """Join key with secret_key.
+
+        Args:
+
+            key (str or bytes):         str or bytes of the key to update
+            secret_key  (bytes):         bytes of the secret key
+
+        Returns:
+
+            Joined key  (bytes)
+        """
+        if isinstance(key, str):
+            update_key_bytes = key.encode(self.encoding_type)
+        else:
+            update_key_bytes = key
+        new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
+        for i, b in enumerate(update_key_bytes):
+            new_secret_key[i % 32] ^= b
+        return bytes(new_secret_key)
+
+    async def get_secret_key(self):
+        """Get secret key using the database key and the serial key in the DB.
+        The key is populated in the property self.secret_key.
+        """
+        if self.secret_key:
+            return
+        secret_key = None
+        if self.database_key:
+            secret_key = self._join_keys(self.database_key, None)
+        version_data = await self._admin_collection.find_one({"_id": "version"})
+        if version_data and version_data.get("serial"):
+            secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
+        self._secret_key = secret_key