Bug 559, changes to make optional intenal database key 66/6866/3
authortierno <alfonso.tiernosepulveda@telefonica.com>
Mon, 12 Nov 2018 10:51:49 +0000 (11:51 +0100)
committertierno <alfonso.tiernosepulveda@telefonica.com>
Mon, 12 Nov 2018 14:34:28 +0000 (15:34 +0100)
Change-Id: I883bb3f874a917d5632b8aa9e937e08d2d7b5507
Signed-off-by: tierno <alfonso.tiernosepulveda@telefonica.com>
osm_common/__init__.py
osm_common/dbbase.py
osm_common/dbmemory.py
osm_common/dbmongo.py
osm_common/tests/test_dbbase.py

index 81567f8..858586c 100644 (file)
@@ -15,6 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-version = '0.1.12'
+version = '0.1.13'
 # TODO add package version filling commit id with 0's; e.g.:  '5.0.0.post11+00000000.dirty-1'
-date_version = '2018-11-05'
+date_version = '2018-11-12'
index 81586ec..b418079 100644 (file)
@@ -46,8 +46,7 @@ class DbBase(object):
             Lock object. Use thi Lock for the threads access protection
         """
         self.logger = logging.getLogger(logger_name)
-        self.master_password = None
-        self.secret_key = None
+        self.secret_key = None  # 32 bytes length array used for encrypt/decrypt
         if not lock:
             self.lock = FakeLock()
         elif lock is True:
@@ -66,7 +65,8 @@ class DbBase(object):
             name:   database name (mandatory)
             user:   database username
             password:   database password
-            masterpassword: database password used for sensible information encryption
+            commonkey: common OSM key used for sensible information encryption
+            materpassword: same as commonkey, for backward compatibility. Deprecated, to be removed in the future
         :param target_version: if provided it checks if database contains required version, raising exception otherwise.
         :return: None or raises DbException on error
         """
@@ -164,29 +164,36 @@ class DbBase(object):
         """
         raise DbException("Method 'replace' not implemented")
 
-    @staticmethod
-    def _join_passwords(passwd_byte, passwd_str):
+    def _join_secret_key(self, update_key):
         """
-        Modifies passwd_byte with the xor of passwd_str. Used for adding salt, join passwords, etc
-        :param passwd_byte: original password in bytes, 32 byte length
-        :param passwd_str: string salt to be added
-        :return: modified password in bytes
+        Returns a xor byte combination of the internal secret_key and the provided update_key.
+        It does not modify the internal secret_key. Used for adding salt, join keys, etc.
+        :param update_key: Can be a string, byte or None. Recommended a long one (e.g. 32 byte length)
+        :return: joined key in bytes with a 32 bytes length. Can be None if both internal secret_key and update_key
+                 are None
         """
-        if not passwd_str:
-            return passwd_byte
-        secret_key = bytearray(passwd_byte)
-        for i, b in enumerate(passwd_str.encode()):
-            secret_key[i % 32] ^= b
-        return bytes(secret_key)
-
-    def set_secret_key(self, secret_key):
+        if not update_key:
+            return self.secret_key
+        elif isinstance(update_key, str):
+            update_key_bytes = update_key.encode()
+        else:
+            update_key_bytes = update_key
+
+        new_secret_key = bytearray(self.secret_key) if self.secret_key else bytearray(32)
+        for i, b in enumerate(update_key_bytes):
+            new_secret_key[i % 32] ^= b
+        return bytes(new_secret_key)
+
+    def set_secret_key(self, new_secret_key, replace=False):
         """
-        Set internal secret key used for encryption
-        :param secret_key: byte array length 32 with the secret_key
+        Updates internal secret_key used for encryption, with a byte xor
+        :param new_secret_key: string or byte array. It is recommended a 32 byte length
+        :param replace: if True, old value of internal secret_key is ignored and replaced. If false, a byte xor is used
         :return: None
         """
-        assert (len(secret_key) == 32)
-        self.secret_key = self._join_passwords(secret_key, self.master_password)
+        if replace:
+            self.secret_key = None
+        self.secret_key = self._join_secret_key(new_secret_key)
 
     def encrypt(self, value, schema_version=None, salt=None):
         """
@@ -197,12 +204,10 @@ class DbBase(object):
         :param salt: optional salt to be used. Must be str
         :return: Encrypted content of value
         """
-        if not schema_version or schema_version == '1.0':
+        if not self.secret_key or not schema_version or schema_version == '1.0':
             return value
         else:
-            if not self.secret_key:
-                raise DbException("Cannot encrypt. Missing secret_key", http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            secret_key = self._join_passwords(self.secret_key, salt)
+            secret_key = self._join_secret_key(salt)
             cipher = AES.new(secret_key)
             padded_private_msg = value + ('\0' * ((16-len(value)) % 16))
             encrypted_msg = cipher.encrypt(padded_private_msg)
@@ -218,12 +223,10 @@ class DbBase(object):
         :param salt: optional salt to be used
         :return: Plain content of value
         """
-        if not schema_version or schema_version == '1.0':
+        if not self.secret_key or not schema_version or schema_version == '1.0':
             return value
         else:
-            if not self.secret_key:
-                raise DbException("Cannot decrypt. Missing secret_key", http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-            secret_key = self._join_passwords(self.secret_key, salt)
+            secret_key = self._join_secret_key(salt)
             encrypted_msg = b64decode(value)
             cipher = AES.new(secret_key)
             decrypted_msg = cipher.decrypt(encrypted_msg)
index b796e81..c2d9a11 100644 (file)
@@ -38,7 +38,9 @@ class DbMemory(DbBase):
         """
         if "logger_name" in config:
             self.logger = logging.getLogger(config["logger_name"])
-        self.master_password = config.get("masterpassword")
+        master_key = config.get("commonkey") or config.get("masterpassword")
+        if master_key:
+            self.set_secret_key(master_key)
 
     @staticmethod
     def _format_filter(q_filter):
index 8140521..0f89c96 100644 (file)
@@ -77,7 +77,9 @@ class DbMongo(DbBase):
         try:
             if "logger_name" in config:
                 self.logger = logging.getLogger(config["logger_name"])
-            self.master_password = config.get("masterpassword")
+            master_key = config.get("commonkey") or config.get("masterpassword")
+            if master_key:
+                self.set_secret_key(master_key)
             self.client = MongoClient(config["host"], config["port"])
             # TODO add as parameters also username=config.get("user"), password=config.get("password"))
             # when all modules are ready
index 918ced7..3a8861d 100644 (file)
@@ -67,13 +67,17 @@ def test_del_one(db_base):
 
 class TestEncryption(unittest.TestCase):
     def setUp(self):
-        master_password = "Setting a long master password with numbers 123 and capitals AGHBNHD and symbols %&8)!'"
+        master_key = "Setting a long master key with numbers 123 and capitals AGHBNHD and symbols %&8)!'"
         db_base1 = DbBase()
         db_base2 = DbBase()
+        db_base3 = DbBase()
         # set self.secret_key obtained when connect
-        db_base1.secret_key = DbBase._join_passwords(urandom(32), master_password)
-        db_base2.secret_key = DbBase._join_passwords(urandom(32), None)
-        self.db_base = [db_base1, db_base2]
+        db_base1.set_secret_key(master_key, replace=True)
+        db_base1.set_secret_key(urandom(32))
+        db_base2.set_secret_key(None, replace=True)
+        db_base2.set_secret_key(urandom(30))
+        db_base3.set_secret_key(master_key)
+        self.db_bases = [db_base1, db_base2, db_base3]
 
     def test_encrypt_decrypt(self):
         TEST = (
@@ -84,7 +88,7 @@ class TestEncryption(unittest.TestCase):
             (u"plain unicode 5 with salt ! ", "1a000d1a-4a7e-4d9c-8c65-251290183106"),
             (u"plain unicode 6 with usalt ! ", u"1abcdd1a-4a7e-4d9c-8c65-251290183106"),
         )
-        for db_base in self.db_base:
+        for db_base in self.db_bases:
             for value, salt in TEST:
                 # no encryption
                 encrypted = db_base.encrypt(value, schema_version='1.0', salt=salt)
@@ -102,7 +106,7 @@ class TestEncryption(unittest.TestCase):
     def test_encrypt_decrypt_salt(self):
         value = "value to be encrypted!"
         encrypted = []
-        for db_base in self.db_base:
+        for db_base in self.db_bases:
             for salt in (None, "salt 1", "1afd5d1a-4a7e-4d9c-8c65-251290183106"):
                 # encrypt/decrypt
                 encrypted.append(db_base.encrypt(value, schema_version='1.1', salt=salt))
@@ -113,7 +117,7 @@ class TestEncryption(unittest.TestCase):
         for i in range(0, len(encrypted)):
             for j in range(i+1, len(encrypted)):
                 self.assertNotEqual(encrypted[i], encrypted[j],
-                                    "encryption with different salt contains different result")
+                                    "encryption with different salt must contain different result")
 
 
 class TestDeepUpdate(unittest.TestCase):