Support for mongodb replicaset connection (HA).
[osm/common.git] / osm_common / dbbase.py
index 5fef9ee..8d694a2 100644 (file)
 
 import yaml
 import logging
 
 import yaml
 import logging
+import re
 from http import HTTPStatus
 from copy import deepcopy
 from Crypto.Cipher import AES
 from base64 import b64decode, b64encode
 from http import HTTPStatus
 from copy import deepcopy
 from Crypto.Cipher import AES
 from base64 import b64decode, b64encode
+from osm_common.common_utils import FakeLock
+from threading import Lock
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
@@ -34,25 +37,37 @@ class DbException(Exception):
 
 class DbBase(object):
 
 
 class DbBase(object):
 
-    def __init__(self, logger_name='db'):
+    def __init__(self, logger_name='db', lock=False):
         """
         """
-        Constructor od dbBase
+        Constructor of dbBase
         :param logger_name: logging name
         :param logger_name: logging name
+        :param lock: Used to protect simultaneous access to the same instance class by several threads:
+            False, None: Do not protect, this object will only be accessed by one thread
+            True: This object needs to be protected by several threads accessing.
+            Lock object. Use thi Lock for the threads access protection
         """
         self.logger = logging.getLogger(logger_name)
         """
         self.logger = logging.getLogger(logger_name)
-        self.master_password = None
-        self.secret_key = None
+        self.secret_key = None  # 32 bytes length array used for encrypt/decrypt
+        if not lock:
+            self.lock = FakeLock()
+        elif lock is True:
+            self.lock = Lock()
+        elif isinstance(lock, Lock):
+            self.lock = lock
+        else:
+            raise ValueError("lock parameter must be a Lock classclass or boolean")
 
     def db_connect(self, config, target_version=None):
         """
         Connect to database
         :param config: Configuration of database. Contains among others:
 
     def db_connect(self, config, target_version=None):
         """
         Connect to database
         :param config: Configuration of database. Contains among others:
-            host:   database hosst (mandatory)
+            host:   database host (mandatory)
             port:   database port (mandatory)
             name:   database name (mandatory)
             user:   database username
             password:   database password
             port:   database port (mandatory)
             name:   database name (mandatory)
             user:   database username
             password:   database password
-            masterpassword: database password used for sensible information encryption
+            commonkey: common OSM key used for sensible information encryption
+            materpassword: same as commonkey, for backward compatibility. Deprecated, to be removed in the future
         :param target_version: if provided it checks if database contains required version, raising exception otherwise.
         :return: None or raises DbException on error
         """
         :param target_version: if provided it checks if database contains required version, raising exception otherwise.
         :return: None or raises DbException on error
         """
@@ -74,6 +89,16 @@ class DbBase(object):
         """
         raise DbException("Method 'get_list' not implemented")
 
         """
         raise DbException("Method 'get_list' not implemented")
 
+    def count(self, table, q_filter=None):
+        """
+        Count the number of entries matching q_filter
+        :param table: collection or table
+        :param q_filter: Filter
+        :return: number of entries found (can be zero)
+        :raise: DbException on error
+        """
+        raise DbException("Method 'count' not implemented")
+
     def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
         """
         Obtain one entry matching q_filter
     def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
         """
         Obtain one entry matching q_filter
@@ -112,11 +137,22 @@ class DbBase(object):
         Add a new entry at database
         :param table: collection or table
         :param indata: content to be added
         Add a new entry at database
         :param table: collection or table
         :param indata: content to be added
-        :return: database id of the inserted element. Raises a DbException on error
+        :return: database '_id' of the inserted element. Raises a DbException on error
         """
         raise DbException("Method 'create' not implemented")
 
         """
         raise DbException("Method 'create' not implemented")
 
-    def set_one(self, table, q_filter, update_dict, fail_on_empty=True):
+    def create_list(self, table, indata_list):
+        """
+        Add several entries at once
+        :param table: collection or table
+        :param indata_list: list of elements to insert. Each element must be a dictionary.
+            An '_id' key based on random uuid is added at each element if missing
+        :return: list of inserted '_id's. Exception on error
+        """
+        raise DbException("Method 'create_list' not implemented")
+
+    def set_one(self, table, q_filter, update_dict, fail_on_empty=True, unset=None, pull=None, push=None,
+                push_list=None, pull_list=None):
         """
         Modifies an entry at database
         :param table: collection or table
         """
         Modifies an entry at database
         :param table: collection or table
@@ -124,16 +160,34 @@ class DbBase(object):
         :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
         :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
         it raises a DbException
         :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
         :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
         it raises a DbException
+        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+                      ignored. If not exist, it is ignored
+        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
+                     if exist in the array is removed. If not exist, it is ignored
+        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
+                     is appended to the end of the array
+        :param pull_list: Same as pull but values are arrays where each item is removed from the array
+        :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
+                          whole array
         :return: Dict with the number of entries modified. None if no matching is found.
         """
         raise DbException("Method 'set_one' not implemented")
 
         :return: Dict with the number of entries modified. None if no matching is found.
         """
         raise DbException("Method 'set_one' not implemented")
 
-    def set_list(self, table, q_filter, update_dict):
+    def set_list(self, table, q_filter, update_dict, unset=None, pull=None, push=None, push_list=None, pull_list=None):
         """
         Modifies al matching entries at database
         :param table: collection or table
         :param q_filter: Filter
         :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
         """
         Modifies al matching entries at database
         :param table: collection or table
         :param q_filter: Filter
         :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
+        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+                      ignored. If not exist, it is ignored
+        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
+                     if exist in the array is removed. If not exist, it is ignored
+        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
+                     is appended to the end of the array
+        :param pull_list: Same as pull but values are arrays where each item is removed from the array
+        :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
+                          whole array
         :return: Dict with the number of entries modified
         """
         raise DbException("Method 'set_list' not implemented")
         :return: Dict with the number of entries modified
         """
         raise DbException("Method 'set_list' not implemented")
@@ -150,29 +204,43 @@ class DbBase(object):
         """
         raise DbException("Method 'replace' not implemented")
 
         """
         raise DbException("Method 'replace' not implemented")
 
-    @staticmethod
-    def _join_passwords(passwd_byte, passwd_str):
+    def _join_secret_key(self, update_key):
         """
         """
-        Modifies passwd_byte with the xor of passwd_str. Used for adding salt, join passwords, etc
-        :param passwd_byte: original password in bytes, 32 byte length
-        :param passwd_str: string salt to be added
-        :return: modified password in bytes
+        Returns a xor byte combination of the internal secret_key and the provided update_key.
+        It does not modify the internal secret_key. Used for adding salt, join keys, etc.
+        :param update_key: Can be a string, byte or None. Recommended a long one (e.g. 32 byte length)
+        :return: joined key in bytes with a 32 bytes length. Can be None if both internal secret_key and update_key
+                 are None
         """
         """
-        if not passwd_str:
-            return passwd_byte
-        secret_key = bytearray(passwd_byte)
-        for i, b in enumerate(passwd_str.encode()):
-            secret_key[i % 32] ^= b
-        return bytes(secret_key)
+        if not update_key:
+            return self.secret_key
+        elif isinstance(update_key, str):
+            update_key_bytes = update_key.encode()
+        else:
+            update_key_bytes = update_key
+
+        new_secret_key = bytearray(self.secret_key) if self.secret_key else bytearray(32)
+        for i, b in enumerate(update_key_bytes):
+            new_secret_key[i % 32] ^= b
+        return bytes(new_secret_key)
 
 
-    def set_secret_key(self, secret_key):
+    def set_secret_key(self, new_secret_key, replace=False):
         """
         """
-        Set internal secret key used for encryption
-        :param secret_key: byte array length 32 with the secret_key
+        Updates internal secret_key used for encryption, with a byte xor
+        :param new_secret_key: string or byte array. It is recommended a 32 byte length
+        :param replace: if True, old value of internal secret_key is ignored and replaced. If false, a byte xor is used
         :return: None
         """
         :return: None
         """
-        assert (len(secret_key) == 32)
-        self.secret_key = self._join_passwords(secret_key, self.master_password)
+        if replace:
+            self.secret_key = None
+        self.secret_key = self._join_secret_key(new_secret_key)
+
+    def get_secret_key(self):
+        """
+        Get the database secret key in case it is not done when "connect" is called. It can happens when database is
+        empty after an initial install. It should skip if secret is already obtained.
+        """
+        pass
 
     def encrypt(self, value, schema_version=None, salt=None):
         """
 
     def encrypt(self, value, schema_version=None, salt=None):
         """
@@ -183,12 +251,11 @@ class DbBase(object):
         :param salt: optional salt to be used. Must be str
         :return: Encrypted content of value
         """
         :param salt: optional salt to be used. Must be str
         :return: Encrypted content of value
         """
-        if not schema_version or schema_version == '1.0':
+        self.get_secret_key()
+        if not self.secret_key or not schema_version or schema_version == '1.0':
             return value
         else:
             return value
         else:
-            if not self.secret_key:
-                raise DbException("Cannot encrypt. Missing secret_key", http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            secret_key = self._join_passwords(self.secret_key, salt)
+            secret_key = self._join_secret_key(salt)
             cipher = AES.new(secret_key)
             padded_private_msg = value + ('\0' * ((16-len(value)) % 16))
             encrypted_msg = cipher.encrypt(padded_private_msg)
             cipher = AES.new(secret_key)
             padded_private_msg = value + ('\0' * ((16-len(value)) % 16))
             encrypted_msg = cipher.encrypt(padded_private_msg)
@@ -204,18 +271,46 @@ class DbBase(object):
         :param salt: optional salt to be used
         :return: Plain content of value
         """
         :param salt: optional salt to be used
         :return: Plain content of value
         """
-        if not schema_version or schema_version == '1.0':
+        self.get_secret_key()
+        if not self.secret_key or not schema_version or schema_version == '1.0':
             return value
         else:
             return value
         else:
-            if not self.secret_key:
-                raise DbException("Cannot decrypt. Missing secret_key", http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            secret_key = self._join_passwords(self.secret_key, salt)
+            secret_key = self._join_secret_key(salt)
             encrypted_msg = b64decode(value)
             cipher = AES.new(secret_key)
             decrypted_msg = cipher.decrypt(encrypted_msg)
             encrypted_msg = b64decode(value)
             cipher = AES.new(secret_key)
             decrypted_msg = cipher.decrypt(encrypted_msg)
-            unpadded_private_msg = decrypted_msg.decode().rstrip('\0')
+            try:
+                unpadded_private_msg = decrypted_msg.decode().rstrip('\0')
+            except UnicodeDecodeError:
+                raise DbException("Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
+                                  http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
             return unpadded_private_msg
 
             return unpadded_private_msg
 
+    def encrypt_decrypt_fields(self, item, action, fields=None, flags=None, schema_version=None, salt=None):
+        if not fields:
+            return
+        self.get_secret_key()
+        actions = ['encrypt', 'decrypt']
+        if action.lower() not in actions:
+            raise DbException("Unknown action ({}): Must be one of {}".format(action, actions),
+                              http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+        method = self.encrypt if action.lower() == 'encrypt' else self.decrypt
+        if flags is None:
+            flags = re.I
+
+        def process(_item):
+            if isinstance(_item, list):
+                for elem in _item:
+                    process(elem)
+            elif isinstance(_item, dict):
+                for key, val in _item.items():
+                    if isinstance(val, str):
+                        if any(re.search(f, key, flags) for f in fields):
+                            _item[key] = method(val, schema_version, salt)
+                    else:
+                        process(val)
+        process(item)
+
 
 def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
     """
 
 def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
     """
@@ -232,7 +327,7 @@ def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
                     Nothing happens if no match is found. If the value is None the matched elements are deleted.
         $key: val   In case a dictionary is passed in yaml format, if looks for all items in the array dict_to_change
                     that are dictionaries and contains this <key> equal to <val>. Several keys can be used by yaml
                     Nothing happens if no match is found. If the value is None the matched elements are deleted.
         $key: val   In case a dictionary is passed in yaml format, if looks for all items in the array dict_to_change
                     that are dictionaries and contains this <key> equal to <val>. Several keys can be used by yaml
-                    format '{key: val, key: val, ...}'; and all of them mast match. Nothing happens if no match is
+                    format '{key: val, key: val, ...}'; and all of them must match. Nothing happens if no match is
                     found. If value is None the matched items are deleted, otherwise they are edited.
         $+val       If no match if found (see '$val'), the value is appended to the array. If any match is found nothing
                     is changed. A value of None has not sense.
                     found. If value is None the matched items are deleted, otherwise they are edited.
         $+val       If no match if found (see '$val'), the value is appended to the array. If any match is found nothing
                     is changed. A value of None has not sense.