Fix 976. Get serial key after database is inited
[osm/common.git] / osm_common / dbbase.py
index 3959383..7428ed9 100644 (file)
 
 import yaml
 import logging
 
 import yaml
 import logging
+import re
 from http import HTTPStatus
 from copy import deepcopy
 from Crypto.Cipher import AES
 from base64 import b64decode, b64encode
 from http import HTTPStatus
 from copy import deepcopy
 from Crypto.Cipher import AES
 from base64 import b64decode, b64encode
+from osm_common.common_utils import FakeLock
+from threading import Lock
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
@@ -34,20 +37,37 @@ class DbException(Exception):
 
 class DbBase(object):
 
 
 class DbBase(object):
 
-    def __init__(self, logger_name='db', master_password=None):
+    def __init__(self, logger_name='db', lock=False):
         """
         """
-        Constructor od dbBase
+        Constructor of dbBase
         :param logger_name: logging name
         :param logger_name: logging name
-        :param master_password: master password used for encrypt decrypt methods
+        :param lock: Used to protect simultaneous access to the same instance class by several threads:
+            False, None: Do not protect, this object will only be accessed by one thread
+            True: This object needs to be protected by several threads accessing.
+            Lock object. Use thi Lock for the threads access protection
         """
         self.logger = logging.getLogger(logger_name)
         """
         self.logger = logging.getLogger(logger_name)
-        self.master_password = master_password
-        self.secret_key = None
+        self.secret_key = None  # 32 bytes length array used for encrypt/decrypt
+        if not lock:
+            self.lock = FakeLock()
+        elif lock is True:
+            self.lock = Lock()
+        elif isinstance(lock, Lock):
+            self.lock = lock
+        else:
+            raise ValueError("lock parameter must be a Lock classclass or boolean")
 
     def db_connect(self, config, target_version=None):
         """
         Connect to database
 
     def db_connect(self, config, target_version=None):
         """
         Connect to database
-        :param config: Configuration of database
+        :param config: Configuration of database. Contains among others:
+            host:   database hosst (mandatory)
+            port:   database port (mandatory)
+            name:   database name (mandatory)
+            user:   database username
+            password:   database password
+            commonkey: common OSM key used for sensible information encryption
+            materpassword: same as commonkey, for backward compatibility. Deprecated, to be removed in the future
         :param target_version: if provided it checks if database contains required version, raising exception otherwise.
         :return: None or raises DbException on error
         """
         :param target_version: if provided it checks if database contains required version, raising exception otherwise.
         :return: None or raises DbException on error
         """
@@ -69,6 +89,16 @@ class DbBase(object):
         """
         raise DbException("Method 'get_list' not implemented")
 
         """
         raise DbException("Method 'get_list' not implemented")
 
+    def count(self, table, q_filter=None):
+        """
+        Count the number of entries matching q_filter
+        :param table: collection or table
+        :param q_filter: Filter
+        :return: number of entries found (can be zero)
+        :raise: DbException on error
+        """
+        raise DbException("Method 'count' not implemented")
+
     def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
         """
         Obtain one entry matching q_filter
     def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
         """
         Obtain one entry matching q_filter
@@ -111,7 +141,7 @@ class DbBase(object):
         """
         raise DbException("Method 'create' not implemented")
 
         """
         raise DbException("Method 'create' not implemented")
 
-    def set_one(self, table, q_filter, update_dict, fail_on_empty=True):
+    def set_one(self, table, q_filter, update_dict, fail_on_empty=True, unset=None, pull=None, push=None):
         """
         Modifies an entry at database
         :param table: collection or table
         """
         Modifies an entry at database
         :param table: collection or table
@@ -119,6 +149,12 @@ class DbBase(object):
         :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
         :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
         it raises a DbException
         :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
         :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
         it raises a DbException
+        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+                      ignored. If not exist, it is ignored
+        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
+                     if exist in the array is removed. If not exist, it is ignored
+        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
+                     is appended to the end of the array
         :return: Dict with the number of entries modified. None if no matching is found.
         """
         raise DbException("Method 'set_one' not implemented")
         :return: Dict with the number of entries modified. None if no matching is found.
         """
         raise DbException("Method 'set_one' not implemented")
@@ -145,29 +181,43 @@ class DbBase(object):
         """
         raise DbException("Method 'replace' not implemented")
 
         """
         raise DbException("Method 'replace' not implemented")
 
-    @staticmethod
-    def _join_passwords(passwd_byte, passwd_str):
+    def _join_secret_key(self, update_key):
         """
         """
-        Modifies passwd_byte with the xor of passwd_str. Used for adding salt, join passwords, etc
-        :param passwd_byte: original password in bytes, 32 byte length
-        :param passwd_str: string salt to be added
-        :return: modified password in bytes
+        Returns a xor byte combination of the internal secret_key and the provided update_key.
+        It does not modify the internal secret_key. Used for adding salt, join keys, etc.
+        :param update_key: Can be a string, byte or None. Recommended a long one (e.g. 32 byte length)
+        :return: joined key in bytes with a 32 bytes length. Can be None if both internal secret_key and update_key
+                 are None
         """
         """
-        if not passwd_str:
-            return passwd_byte
-        secret_key = bytearray(passwd_byte)
-        for i, b in enumerate(passwd_str.encode()):
-            secret_key[i % 32] ^= b
-        return bytes(secret_key)
+        if not update_key:
+            return self.secret_key
+        elif isinstance(update_key, str):
+            update_key_bytes = update_key.encode()
+        else:
+            update_key_bytes = update_key
 
 
-    def set_secret_key(self, secret_key):
+        new_secret_key = bytearray(self.secret_key) if self.secret_key else bytearray(32)
+        for i, b in enumerate(update_key_bytes):
+            new_secret_key[i % 32] ^= b
+        return bytes(new_secret_key)
+
+    def set_secret_key(self, new_secret_key, replace=False):
         """
         """
-        Set internal secret key used for encryption
-        :param secret_key: byte array length 32 with the secret_key
+        Updates internal secret_key used for encryption, with a byte xor
+        :param new_secret_key: string or byte array. It is recommended a 32 byte length
+        :param replace: if True, old value of internal secret_key is ignored and replaced. If false, a byte xor is used
         :return: None
         """
         :return: None
         """
-        assert (len(secret_key) == 32)
-        self.secret_key = self._join_passwords(secret_key, self.master_password)
+        if replace:
+            self.secret_key = None
+        self.secret_key = self._join_secret_key(new_secret_key)
+
+    def get_secret_key(self):
+        """
+        Get the database secret key in case it is not done when "connect" is called. It can happens when database is
+        empty after an initial install. It should skip if secret is already obtained.
+        """
+        pass
 
     def encrypt(self, value, schema_version=None, salt=None):
         """
 
     def encrypt(self, value, schema_version=None, salt=None):
         """
@@ -178,12 +228,11 @@ class DbBase(object):
         :param salt: optional salt to be used. Must be str
         :return: Encrypted content of value
         """
         :param salt: optional salt to be used. Must be str
         :return: Encrypted content of value
         """
-        if not schema_version or schema_version == '1.0':
+        self.get_secret_key()
+        if not self.secret_key or not schema_version or schema_version == '1.0':
             return value
         else:
             return value
         else:
-            if not self.secret_key:
-                raise DbException("Cannot encrypt. Missing secret_key", http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            secret_key = self._join_passwords(self.secret_key, salt)
+            secret_key = self._join_secret_key(salt)
             cipher = AES.new(secret_key)
             padded_private_msg = value + ('\0' * ((16-len(value)) % 16))
             encrypted_msg = cipher.encrypt(padded_private_msg)
             cipher = AES.new(secret_key)
             padded_private_msg = value + ('\0' * ((16-len(value)) % 16))
             encrypted_msg = cipher.encrypt(padded_private_msg)
@@ -199,18 +248,44 @@ class DbBase(object):
         :param salt: optional salt to be used
         :return: Plain content of value
         """
         :param salt: optional salt to be used
         :return: Plain content of value
         """
-        if not schema_version or schema_version == '1.0':
+        self.get_secret_key()
+        if not self.secret_key or not schema_version or schema_version == '1.0':
             return value
         else:
             return value
         else:
-            if not self.secret_key:
-                raise DbException("Cannot decrypt. Missing secret_key", http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            secret_key = self._join_passwords(self.secret_key, salt)
+            secret_key = self._join_secret_key(salt)
             encrypted_msg = b64decode(value)
             cipher = AES.new(secret_key)
             decrypted_msg = cipher.decrypt(encrypted_msg)
             encrypted_msg = b64decode(value)
             cipher = AES.new(secret_key)
             decrypted_msg = cipher.decrypt(encrypted_msg)
-            unpadded_private_msg = decrypted_msg.decode().rstrip('\0')
+            try:
+                unpadded_private_msg = decrypted_msg.decode().rstrip('\0')
+            except UnicodeDecodeError:
+                raise DbException("Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
+                                  http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
             return unpadded_private_msg
 
             return unpadded_private_msg
 
+    def encrypt_decrypt_fields(self, item, action, fields=None, flags=re.I, schema_version=None, salt=None):
+        if not fields:
+            return
+        self.get_secret_key()
+        actions = ['encrypt', 'decrypt']
+        if action.lower() not in actions:
+            raise DbException("Unknown action ({}): Must be one of {}".format(action, actions),
+                              http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+        method = self.encrypt if action.lower() == 'encrypt' else self.decrypt
+
+        def process(item):
+            if isinstance(item, list):
+                for elem in item:
+                    process(elem)
+            elif isinstance(item, dict):
+                for key, val in item.items():
+                    if any(re.search(f, key, flags) for f in fields) and isinstance(val, str):
+                        item[key] = method(val, schema_version, salt)
+                    else:
+                        process(val)
+
+        process(item)
+
 
 def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
     """
 
 def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
     """
@@ -227,7 +302,7 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
                     Nothing happens if no match is found. If the value is None the matched elements are deleted.
         $key: val   In case a dictionary is passed in yaml format, if looks for all items in the array dict_to_change
                     that are dictionaries and contains this <key> equal to <val>. Several keys can be used by yaml
                     Nothing happens if no match is found. If the value is None the matched elements are deleted.
         $key: val   In case a dictionary is passed in yaml format, if looks for all items in the array dict_to_change
                     that are dictionaries and contains this <key> equal to <val>. Several keys can be used by yaml
-                    format '{key: val, key: val, ...}'; and all of them mast match. Nothing happens if no match is
+                    format '{key: val, key: val, ...}'; and all of them must match. Nothing happens if no match is
                     found. If value is None the matched items are deleted, otherwise they are edited.
         $+val       If no match if found (see '$val'), the value is appended to the array. If any match is found nothing
                     is changed. A value of None has not sense.
                     found. If value is None the matched items are deleted, otherwise they are edited.
         $+val       If no match if found (see '$val'), the value is appended to the array. If any match is found nothing
                     is changed. A value of None has not sense.