Enable parallel execution and output of tox env
[osm/common.git] / osm_common / dbbase.py
index 7a98c76..d0d4fb0 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import yaml
+from base64 import b64decode, b64encode
+from copy import deepcopy
+from http import HTTPStatus
 import logging
 import re
-from http import HTTPStatus
-from copy import deepcopy
+from threading import Lock
+import typing
+
+
 from Crypto.Cipher import AES
-from base64 import b64decode, b64encode
+from motor.motor_asyncio import AsyncIOMotorClient
 from osm_common.common_utils import FakeLock
-from threading import Lock
+import yaml
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
 
-class DbException(Exception):
+DB_NAME = "osm"
 
+
+class DbException(Exception):
     def __init__(self, message, http_code=HTTPStatus.NOT_FOUND):
         self.http_code = http_code
         Exception.__init__(self, "database exception " + str(message))
 
 
 class DbBase(object):
-
-    def __init__(self, logger_name='db', lock=False):
+    def __init__(self, encoding_type="ascii", logger_name="db", lock=False):
         """
         Constructor of dbBase
         :param logger_name: logging name
@@ -48,6 +53,8 @@ class DbBase(object):
         """
         self.logger = logging.getLogger(logger_name)
         self.secret_key = None  # 32 bytes length array used for encrypt/decrypt
+        self.encrypt_mode = AES.MODE_ECB
+        self.encoding_type = encoding_type
         if not lock:
             self.lock = FakeLock()
         elif lock is True:
@@ -151,7 +158,18 @@ class DbBase(object):
         """
         raise DbException("Method 'create_list' not implemented")
 
-    def set_one(self, table, q_filter, update_dict, fail_on_empty=True, unset=None, pull=None, push=None):
+    def set_one(
+        self,
+        table,
+        q_filter,
+        update_dict,
+        fail_on_empty=True,
+        unset=None,
+        pull=None,
+        push=None,
+        push_list=None,
+        pull_list=None,
+    ):
         """
         Modifies an entry at database
         :param table: collection or table
@@ -165,11 +183,24 @@ class DbBase(object):
                      if exist in the array is removed. If not exist, it is ignored
         :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
                      is appended to the end of the array
+        :param pull_list: Same as pull but values are arrays where each item is removed from the array
+        :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
+                          whole array
         :return: Dict with the number of entries modified. None if no matching is found.
         """
         raise DbException("Method 'set_one' not implemented")
 
-    def set_list(self, table, q_filter, update_dict, unset=None, pull=None, push=None):
+    def set_list(
+        self,
+        table,
+        q_filter,
+        update_dict,
+        unset=None,
+        pull=None,
+        push=None,
+        push_list=None,
+        pull_list=None,
+    ):
         """
         Modifies al matching entries at database
         :param table: collection or table
@@ -181,6 +212,9 @@ class DbBase(object):
                      if exist in the array is removed. If not exist, it is ignored
         :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
                      is appended to the end of the array
+        :param pull_list: Same as pull but values are arrays where each item is removed from the array
+        :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
+                          whole array
         :return: Dict with the number of entries modified
         """
         raise DbException("Method 'set_list' not implemented")
@@ -212,7 +246,9 @@ class DbBase(object):
         else:
             update_key_bytes = update_key
 
-        new_secret_key = bytearray(self.secret_key) if self.secret_key else bytearray(32)
+        new_secret_key = (
+            bytearray(self.secret_key) if self.secret_key else bytearray(32)
+        )
         for i, b in enumerate(update_key_bytes):
             new_secret_key[i % 32] ^= b
         return bytes(new_secret_key)
@@ -235,68 +271,146 @@ class DbBase(object):
         """
         pass
 
-    def encrypt(self, value, schema_version=None, salt=None):
-        """
-        Encrypt a value
-        :param value: value to be encrypted. It is string/unicode
-        :param schema_version: used for version control. If None or '1.0' no encryption is done.
-               If '1.1' symmetric AES encryption is done
-        :param salt: optional salt to be used. Must be str
-        :return: Encrypted content of value
+    @staticmethod
+    def pad_data(value: str) -> str:
+        if not isinstance(value, str):
+            raise DbException(
+                f"Incorrect data type: type({value}), string is expected."
+            )
+        return value + ("\0" * ((16 - len(value)) % 16))
+
+    @staticmethod
+    def unpad_data(value: str) -> str:
+        if not isinstance(value, str):
+            raise DbException(
+                f"Incorrect data type: type({value}), string is expected."
+            )
+        return value.rstrip("\0")
+
+    def _encrypt_value(self, value: str, schema_version: str, salt: str):
+        """Encrypt a value.
+
+        Args:
+            value   (str):                  value to be encrypted. It is string/unicode
+            schema_version  (str):          used for version control. If None or '1.0' no encryption is done.
+                                            If '1.1' symmetric AES encryption is done
+            salt    (str):                  optional salt to be used. Must be str
+
+        Returns:
+            Encrypted content of value  (str)
+
         """
-        self.get_secret_key()
-        if not self.secret_key or not schema_version or schema_version == '1.0':
+        if not self.secret_key or not schema_version or schema_version == "1.0":
             return value
+
         else:
+            # Secret key as bytes
             secret_key = self._join_secret_key(salt)
-            cipher = AES.new(secret_key)
-            padded_private_msg = value + ('\0' * ((16-len(value)) % 16))
-            encrypted_msg = cipher.encrypt(padded_private_msg)
+            cipher = AES.new(secret_key, self.encrypt_mode)
+            # Padded data as string
+            padded_private_msg = self.pad_data(value)
+            # Padded data as bytes
+            padded_private_msg_bytes = padded_private_msg.encode(self.encoding_type)
+            # Encrypt padded data
+            encrypted_msg = cipher.encrypt(padded_private_msg_bytes)
+            # Base64 encoded encrypted data
             encoded_encrypted_msg = b64encode(encrypted_msg)
-            return encoded_encrypted_msg.decode("ascii")
+            # Converting to string
+            return encoded_encrypted_msg.decode(self.encoding_type)
+
+    def encrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
+        """Encrypt a value.
+
+        Args:
+            value   (str):                  value to be encrypted. It is string/unicode
+            schema_version  (str):          used for version control. If None or '1.0' no encryption is done.
+                                            If '1.1' symmetric AES encryption is done
+            salt    (str):                  optional salt to be used. Must be str
+
+        Returns:
+            Encrypted content of value  (str)
 
-    def decrypt(self, value, schema_version=None, salt=None):
-        """
-        Decrypt an encrypted value
-        :param value: value to be decrypted. It is a base64 string
-        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
-               If '1.1' symmetric AES encryption has been done
-        :param salt: optional salt to be used
-        :return: Plain content of value
         """
         self.get_secret_key()
-        if not self.secret_key or not schema_version or schema_version == '1.0':
+        return self._encrypt_value(value, schema_version, salt)
+
+    def _decrypt_value(self, value: str, schema_version: str, salt: str) -> str:
+        """Decrypt an encrypted value.
+        Args:
+
+            value   (str):                      value to be decrypted. It is a base64 string
+            schema_version  (str):              used for known encryption method used.
+                                                If None or '1.0' no encryption has been done.
+                                                If '1.1' symmetric AES encryption has been done
+            salt    (str):                      optional salt to be used
+
+        Returns:
+            Plain content of value  (str)
+
+        """
+        if not self.secret_key or not schema_version or schema_version == "1.0":
             return value
+
         else:
             secret_key = self._join_secret_key(salt)
+            # Decoding encrypted data, output bytes
             encrypted_msg = b64decode(value)
-            cipher = AES.new(secret_key)
+            cipher = AES.new(secret_key, self.encrypt_mode)
+            # Decrypted data, output bytes
             decrypted_msg = cipher.decrypt(encrypted_msg)
             try:
-                unpadded_private_msg = decrypted_msg.decode().rstrip('\0')
+                # Converting to string
+                private_msg = decrypted_msg.decode(self.encoding_type)
             except UnicodeDecodeError:
-                raise DbException("Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
-                                  http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            return unpadded_private_msg
+                raise DbException(
+                    "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
+                    http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                )
+            # Unpadded data as string
+            return self.unpad_data(private_msg)
 
-    def encrypt_decrypt_fields(self, item, action, fields=None, flags=re.I, schema_version=None, salt=None):
+    def decrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
+        """Decrypt an encrypted value.
+        Args:
+
+            value   (str):                      value to be decrypted. It is a base64 string
+            schema_version  (str):              used for known encryption method used.
+                                                If None or '1.0' no encryption has been done.
+                                                If '1.1' symmetric AES encryption has been done
+            salt    (str):                      optional salt to be used
+
+        Returns:
+            Plain content of value  (str)
+
+        """
+        self.get_secret_key()
+        return self._decrypt_value(value, schema_version, salt)
+
+    def encrypt_decrypt_fields(
+        self, item, action, fields=None, flags=None, schema_version=None, salt=None
+    ):
         if not fields:
             return
         self.get_secret_key()
-        actions = ['encrypt', 'decrypt']
+        actions = ["encrypt", "decrypt"]
         if action.lower() not in actions:
-            raise DbException("Unknown action ({}): Must be one of {}".format(action, actions),
-                              http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-        method = self.encrypt if action.lower() == 'encrypt' else self.decrypt
-
-        def process(item):
-            if isinstance(item, list):
-                for elem in item:
+            raise DbException(
+                "Unknown action ({}): Must be one of {}".format(action, actions),
+                http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+            )
+        method = self.encrypt if action.lower() == "encrypt" else self.decrypt
+        if flags is None:
+            flags = re.I
+
+        def process(_item):
+            if isinstance(_item, list):
+                for elem in _item:
                     process(elem)
-            elif isinstance(item, dict):
-                for key, val in item.items():
-                    if any(re.search(f, key, flags) for f in fields) and isinstance(val, str):
-                        item[key] = method(val, schema_version, salt)
+            elif isinstance(_item, dict):
+                for key, val in _item.items():
+                    if isinstance(val, str):
+                        if any(re.search(f, key, flags) for f in fields):
+                            _item[key] = method(val, schema_version, salt)
                     else:
                         process(val)
 
@@ -332,6 +446,7 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
     :param key_list: This is used internally for recursive calls. Do not fill this parameter.
     :return: none or raises and exception only at array modification when there is a bad format or conflict.
     """
+
     def _deep_update_array(array_to_change, _dict_reference, _key_list):
         to_append = {}
         to_insert_at_index = {}
@@ -343,26 +458,33 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
             _key_list[-1] = str(k)
             if not isinstance(k, str) or not k.startswith("$"):
                 if array_edition is True:
-                    raise DbException("Found array edition (keys starting with '$') and pure dictionary edition in the"
-                                      " same dict at '{}'".format(":".join(_key_list[:-1])))
+                    raise DbException(
+                        "Found array edition (keys starting with '$') and pure dictionary edition in the"
+                        " same dict at '{}'".format(":".join(_key_list[:-1]))
+                    )
                 array_edition = False
                 continue
             else:
                 if array_edition is False:
-                    raise DbException("Found array edition (keys starting with '$') and pure dictionary edition in the"
-                                      " same dict at '{}'".format(":".join(_key_list[:-1])))
+                    raise DbException(
+                        "Found array edition (keys starting with '$') and pure dictionary edition in the"
+                        " same dict at '{}'".format(":".join(_key_list[:-1]))
+                    )
                 array_edition = True
             insert = False
             indexes = []  # indexes to edit or insert
             kitem = k[1:]
-            if kitem.startswith('+'):
+            if kitem.startswith("+"):
                 insert = True
                 kitem = kitem[1:]
                 if _dict_reference[k] is None:
-                    raise DbException("A value of None has not sense for insertions at '{}'".format(
-                        ":".join(_key_list)))
+                    raise DbException(
+                        "A value of None has not sense for insertions at '{}'".format(
+                            ":".join(_key_list)
+                        )
+                    )
 
-            if kitem.startswith('[') and kitem.endswith(']'):
+            if kitem.startswith("[") and kitem.endswith("]"):
                 try:
                     index = int(kitem[1:-1])
                     if index < 0:
@@ -371,18 +493,29 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
                         index = 0  # skip outside index edition
                     indexes.append(index)
                 except Exception:
-                    raise DbException("Wrong format at '{}'. Expecting integer index inside quotes".format(
-                        ":".join(_key_list)))
+                    raise DbException(
+                        "Wrong format at '{}'. Expecting integer index inside quotes".format(
+                            ":".join(_key_list)
+                        )
+                    )
             elif kitem:
                 # match_found_skip = False
                 try:
                     filter_in = yaml.safe_load(kitem)
                 except Exception:
-                    raise DbException("Wrong format at '{}'. Expecting '$<yaml-format>'".format(":".join(_key_list)))
+                    raise DbException(
+                        "Wrong format at '{}'. Expecting '$<yaml-format>'".format(
+                            ":".join(_key_list)
+                        )
+                    )
                 if isinstance(filter_in, dict):
                     for index, item in enumerate(array_to_change):
                         for filter_k, filter_v in filter_in.items():
-                            if not isinstance(item, dict) or filter_k not in item or item[filter_k] != filter_v:
+                            if (
+                                not isinstance(item, dict)
+                                or filter_k not in item
+                                or item[filter_k] != filter_v
+                            ):
                                 break
                         else:  # match found
                             if insert:
@@ -408,20 +541,35 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
                 # if match_found_skip:
                 #     continue
             elif not insert:
-                raise DbException("Wrong format at '{}'. Expecting '$+', '$[<index]' or '$[<filter>]'".format(
-                    ":".join(_key_list)))
+                raise DbException(
+                    "Wrong format at '{}'. Expecting '$+', '$[<index]' or '$[<filter>]'".format(
+                        ":".join(_key_list)
+                    )
+                )
             for index in indexes:
                 if insert:
-                    if index in to_insert_at_index and to_insert_at_index[index] != _dict_reference[k]:
+                    if (
+                        index in to_insert_at_index
+                        and to_insert_at_index[index] != _dict_reference[k]
+                    ):
                         # Several different insertions on the same item of the array
-                        raise DbException("Conflict at '{}'. Several insertions on same array index {}".format(
-                            ":".join(_key_list), index))
+                        raise DbException(
+                            "Conflict at '{}'. Several insertions on same array index {}".format(
+                                ":".join(_key_list), index
+                            )
+                        )
                     to_insert_at_index[index] = _dict_reference[k]
                 else:
-                    if index in indexes_to_edit_delete and values_to_edit_delete[index] != _dict_reference[k]:
+                    if (
+                        index in indexes_to_edit_delete
+                        and values_to_edit_delete[index] != _dict_reference[k]
+                    ):
                         # Several different editions on the same item of the array
-                        raise DbException("Conflict at '{}'. Several editions on array index {}".format(
-                            ":".join(_key_list), index))
+                        raise DbException(
+                            "Conflict at '{}'. Several editions on array index {}".format(
+                                ":".join(_key_list), index
+                            )
+                        )
                     indexes_to_edit_delete.append(index)
                     values_to_edit_delete[index] = _dict_reference[k]
             if not indexes:
@@ -438,22 +586,38 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
             try:
                 if values_to_edit_delete[index] is None:  # None->Anything
                     try:
-                        del (array_to_change[index])
+                        del array_to_change[index]
                     except IndexError:
                         pass  # it is not consider an error if this index does not exist
-                elif not isinstance(values_to_edit_delete[index], dict):  # NotDict->Anything
+                elif not isinstance(
+                    values_to_edit_delete[index], dict
+                ):  # NotDict->Anything
                     array_to_change[index] = deepcopy(values_to_edit_delete[index])
                 elif isinstance(array_to_change[index], dict):  # Dict->Dict
-                    deep_update_rfc7396(array_to_change[index], values_to_edit_delete[index], _key_list)
+                    deep_update_rfc7396(
+                        array_to_change[index], values_to_edit_delete[index], _key_list
+                    )
                 else:  # Dict->NotDict
-                    if isinstance(array_to_change[index], list):  # Dict->List. Check extra array edition
-                        if _deep_update_array(array_to_change[index], values_to_edit_delete[index], _key_list):
+                    if isinstance(
+                        array_to_change[index], list
+                    ):  # Dict->List. Check extra array edition
+                        if _deep_update_array(
+                            array_to_change[index],
+                            values_to_edit_delete[index],
+                            _key_list,
+                        ):
                             continue
                     array_to_change[index] = deepcopy(values_to_edit_delete[index])
                     # calling deep_update_rfc7396 to delete the None values
-                    deep_update_rfc7396(array_to_change[index], values_to_edit_delete[index], _key_list)
+                    deep_update_rfc7396(
+                        array_to_change[index], values_to_edit_delete[index], _key_list
+                    )
             except IndexError:
-                raise DbException("Array edition index out of range at '{}'".format(":".join(_key_list)))
+                raise DbException(
+                    "Array edition index out of range at '{}'".format(
+                        ":".join(_key_list)
+                    )
+                )
 
         # insertion with indexes
         to_insert_indexes = list(to_insert_at_index.keys())
@@ -480,7 +644,7 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
     key_list.append("")
     for k in dict_reference:
         key_list[-1] = str(k)
-        if dict_reference[k] is None:   # None->Anything
+        if dict_reference[k] is None:  # None->Anything
             if k in dict_to_change:
                 del dict_to_change[k]
         elif not isinstance(dict_reference[k], dict):  # NotDict->Anything
@@ -491,8 +655,10 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
             deep_update_rfc7396(dict_to_change[k], dict_reference[k], key_list)
         elif isinstance(dict_to_change[k], dict):  # Dict->Dict
             deep_update_rfc7396(dict_to_change[k], dict_reference[k], key_list)
-        else:       # Dict->NotDict
-            if isinstance(dict_to_change[k], list):  # Dict->List. Check extra array edition
+        else:  # Dict->NotDict
+            if isinstance(
+                dict_to_change[k], list
+            ):  # Dict->List. Check extra array edition
                 if _deep_update_array(dict_to_change[k], dict_reference[k], key_list):
                     continue
             dict_to_change[k] = deepcopy(dict_reference[k])
@@ -502,5 +668,163 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
 
 
 def deep_update(dict_to_change, dict_reference):
-    """ Maintained for backward compatibility. Use deep_update_rfc7396 instead"""
+    """Maintained for backward compatibility. Use deep_update_rfc7396 instead"""
     return deep_update_rfc7396(dict_to_change, dict_reference)
+
+
+class Encryption(DbBase):
+    def __init__(self, uri, config, encoding_type="ascii", logger_name="db"):
+        """Constructor.
+
+        Args:
+            uri (str):              Connection string to connect to the database.
+            config  (dict):         Additional database info
+            encoding_type   (str):  ascii, utf-8 etc.
+            logger_name (str):      Logger name
+
+        """
+        self._secret_key = None  # 32 bytes length array used for encrypt/decrypt
+        self.encrypt_mode = AES.MODE_ECB
+        super(Encryption, self).__init__(
+            encoding_type=encoding_type, logger_name=logger_name
+        )
+        self._client = AsyncIOMotorClient(uri)
+        self._config = config
+
+    @property
+    def secret_key(self):
+        return self._secret_key
+
+    @secret_key.setter
+    def secret_key(self, value):
+        self._secret_key = value
+
+    @property
+    def _database(self):
+        return self._client[DB_NAME]
+
+    @property
+    def _admin_collection(self):
+        return self._database["admin"]
+
+    @property
+    def database_key(self):
+        return self._config.get("database_commonkey")
+
+    async def decrypt_fields(
+        self,
+        item: dict,
+        fields: typing.List[str],
+        schema_version: str = None,
+        salt: str = None,
+    ) -> None:
+        """Decrypt fields from a dictionary. Follows the same logic as in osm_common.
+
+        Args:
+
+            item    (dict):         Dictionary with the keys to be decrypted
+            fields  (list):          List of keys to decrypt
+            schema version  (str):  Schema version. (i.e. 1.11)
+            salt    (str):          Salt for the decryption
+
+        """
+        flags = re.I
+
+        async def process(_item):
+            if isinstance(_item, list):
+                for elem in _item:
+                    await process(elem)
+            elif isinstance(_item, dict):
+                for key, val in _item.items():
+                    if isinstance(val, str):
+                        if any(re.search(f, key, flags) for f in fields):
+                            _item[key] = await self.decrypt(val, schema_version, salt)
+                    else:
+                        await process(val)
+
+        await process(item)
+
+    async def encrypt(
+        self, value: str, schema_version: str = None, salt: str = None
+    ) -> str:
+        """Encrypt a value.
+
+        Args:
+            value   (str):                  value to be encrypted. It is string/unicode
+            schema_version  (str):          used for version control. If None or '1.0' no encryption is done.
+                                            If '1.1' symmetric AES encryption is done
+            salt    (str):                  optional salt to be used. Must be str
+
+        Returns:
+            Encrypted content of value  (str)
+
+        """
+        await self.get_secret_key()
+        return self._encrypt_value(value, schema_version, salt)
+
+    async def decrypt(
+        self, value: str, schema_version: str = None, salt: str = None
+    ) -> str:
+        """Decrypt an encrypted value.
+        Args:
+
+            value   (str):                      value to be decrypted. It is a base64 string
+            schema_version  (str):              used for known encryption method used.
+                                                If None or '1.0' no encryption has been done.
+                                                If '1.1' symmetric AES encryption has been done
+            salt    (str):                      optional salt to be used
+
+        Returns:
+            Plain content of value  (str)
+
+        """
+        await self.get_secret_key()
+        return self._decrypt_value(value, schema_version, salt)
+
+    def _join_secret_key(self, update_key: typing.Any) -> bytes:
+        """Join key with secret key.
+
+        Args:
+
+        update_key  (str or bytes):    str or bytes with the to update
+
+        Returns:
+
+            Joined key  (bytes)
+        """
+        return self._join_keys(update_key, self.secret_key)
+
+    def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
+        """Join key with secret_key.
+
+        Args:
+
+            key (str or bytes):         str or bytes of the key to update
+            secret_key  (bytes):         bytes of the secret key
+
+        Returns:
+
+            Joined key  (bytes)
+        """
+        if isinstance(key, str):
+            update_key_bytes = key.encode(self.encoding_type)
+        else:
+            update_key_bytes = key
+        new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
+        for i, b in enumerate(update_key_bytes):
+            new_secret_key[i % 32] ^= b
+        return bytes(new_secret_key)
+
+    async def get_secret_key(self):
+        """Get secret key using the database key and the serial key in the DB.
+        The key is populated in the property self.secret_key.
+        """
+        if self.secret_key:
+            return
+        secret_key = None
+        if self.database_key:
+            secret_key = self._join_keys(self.database_key, None)
+        version_data = await self._admin_collection.find_one({"_id": "version"})
+        if version_data and version_data.get("serial"):
+            secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
+        self._secret_key = secret_key