Adding more parameters to db_set for array edition
[osm/common.git] / osm_common / dbbase.py
index cdb1644..d199dde 100644 (file)
@@ -19,6 +19,10 @@ import yaml
 import logging
 from http import HTTPStatus
 from copy import deepcopy
+from Crypto.Cipher import AES
+from base64 import b64decode, b64encode
+from osm_common.common_utils import FakeLock
+from threading import Lock
 
 __author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
 
@@ -32,22 +36,41 @@ class DbException(Exception):
 
 class DbBase(object):
 
-    def __init__(self, logger_name='db', master_password=None):
+    def __init__(self, logger_name='db', lock=False):
         """
-        Constructor od dbBase
+        Constructor of dbBase
         :param logger_name: logging name
-        :param master_password: master password used for encrypt decrypt methods
+        :param lock: Used to protect simultaneous access to the same instance class by several threads:
+            False, None: Do not protect, this object will only be accessed by one thread
+            True: This object needs to be protected by several threads accessing.
+            Lock object. Use thi Lock for the threads access protection
         """
         self.logger = logging.getLogger(logger_name)
-        self.master_password = master_password
-
-    def db_connect(self, config):
+        self.secret_key = None  # 32 bytes length array used for encrypt/decrypt
+        if not lock:
+            self.lock = FakeLock()
+        elif lock is True:
+            self.lock = Lock()
+        elif isinstance(lock, Lock):
+            self.lock = lock
+        else:
+            raise ValueError("lock parameter must be a Lock classclass or boolean")
+
+    def db_connect(self, config, target_version=None):
         """
         Connect to database
-        :param config: Configuration of database
+        :param config: Configuration of database. Contains among others:
+            host:   database hosst (mandatory)
+            port:   database port (mandatory)
+            name:   database name (mandatory)
+            user:   database username
+            password:   database password
+            commonkey: common OSM key used for sensible information encryption
+            materpassword: same as commonkey, for backward compatibility. Deprecated, to be removed in the future
+        :param target_version: if provided it checks if database contains required version, raising exception otherwise.
         :return: None or raises DbException on error
         """
-        pass
+        raise DbException("Method 'db_connect' not implemented")
 
     def db_disconnect(self):
         """
@@ -107,7 +130,7 @@ class DbBase(object):
         """
         raise DbException("Method 'create' not implemented")
 
-    def set_one(self, table, q_filter, update_dict, fail_on_empty=True):
+    def set_one(self, table, q_filter, update_dict, fail_on_empty=True, unset=None, pull=None, push=None):
         """
         Modifies an entry at database
         :param table: collection or table
@@ -115,6 +138,12 @@ class DbBase(object):
         :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
         :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
         it raises a DbException
+        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+                      ignored. If not exist, it is ignored
+        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
+                     if exist in the array is removed. If not exist, it is ignored
+        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
+                     is appended to the end of the array
         :return: Dict with the number of entries modified. None if no matching is found.
         """
         raise DbException("Method 'set_one' not implemented")
@@ -141,27 +170,74 @@ class DbBase(object):
         """
         raise DbException("Method 'replace' not implemented")
 
-    def encrypt(self, value, salt=None):
+    def _join_secret_key(self, update_key):
+        """
+        Returns a xor byte combination of the internal secret_key and the provided update_key.
+        It does not modify the internal secret_key. Used for adding salt, join keys, etc.
+        :param update_key: Can be a string, byte or None. Recommended a long one (e.g. 32 byte length)
+        :return: joined key in bytes with a 32 bytes length. Can be None if both internal secret_key and update_key
+                 are None
+        """
+        if not update_key:
+            return self.secret_key
+        elif isinstance(update_key, str):
+            update_key_bytes = update_key.encode()
+        else:
+            update_key_bytes = update_key
+
+        new_secret_key = bytearray(self.secret_key) if self.secret_key else bytearray(32)
+        for i, b in enumerate(update_key_bytes):
+            new_secret_key[i % 32] ^= b
+        return bytes(new_secret_key)
+
+    def set_secret_key(self, new_secret_key, replace=False):
+        """
+        Updates internal secret_key used for encryption, with a byte xor
+        :param new_secret_key: string or byte array. It is recommended a 32 byte length
+        :param replace: if True, old value of internal secret_key is ignored and replaced. If false, a byte xor is used
+        :return: None
+        """
+        if replace:
+            self.secret_key = None
+        self.secret_key = self._join_secret_key(new_secret_key)
+
+    def encrypt(self, value, schema_version=None, salt=None):
         """
         Encrypt a value
-        :param value: value to be encrypted
-        :param salt: optional salt to be used
+        :param value: value to be encrypted. It is string/unicode
+        :param schema_version: used for version control. If None or '1.0' no encryption is done.
+               If '1.1' symmetric AES encryption is done
+        :param salt: optional salt to be used. Must be str
         :return: Encrypted content of value
         """
-        # for the moment return same value. until all modules call this method
-        return value
-        # raise DbException("Method 'encrypt' not implemented")
-
-    def decrypt(self, value, salt=None):
+        if not self.secret_key or not schema_version or schema_version == '1.0':
+            return value
+        else:
+            secret_key = self._join_secret_key(salt)
+            cipher = AES.new(secret_key)
+            padded_private_msg = value + ('\0' * ((16-len(value)) % 16))
+            encrypted_msg = cipher.encrypt(padded_private_msg)
+            encoded_encrypted_msg = b64encode(encrypted_msg)
+            return encoded_encrypted_msg.decode("ascii")
+
+    def decrypt(self, value, schema_version=None, salt=None):
         """
         Decrypt an encrypted value
-        :param value: value to be decrypted
+        :param value: value to be decrypted. It is a base64 string
+        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
+               If '1.1' symmetric AES encryption has been done
         :param salt: optional salt to be used
         :return: Plain content of value
         """
-        # for the moment return same value. until all modules call this method
-        return value
-        # raise DbException("Method 'decrypt' not implemented")
+        if not self.secret_key or not schema_version or schema_version == '1.0':
+            return value
+        else:
+            secret_key = self._join_secret_key(salt)
+            encrypted_msg = b64decode(value)
+            cipher = AES.new(secret_key)
+            decrypted_msg = cipher.decrypt(encrypted_msg)
+            unpadded_private_msg = decrypted_msg.decode().rstrip('\0')
+            return unpadded_private_msg
 
 
 def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):