Bug 559 adding encrypt/decrypt methods
[osm/common.git] / osm_common / tests / test_dbbase.py
index 2da4506..64bfb3e 100644 (file)
@@ -2,6 +2,7 @@ import http
 import pytest
 import unittest
 from osm_common.dbbase import DbBase, DbException, deep_update
+from os import urandom
 
 
 def exception_message(message):
@@ -20,7 +21,9 @@ def test_constructor():
 
 
 def test_db_connect(db_base):
-    db_base.db_connect(None)
+    with pytest.raises(DbException) as excinfo:
+        db_base.db_connect(None)
+    assert str(excinfo.value).startswith(exception_message("Method 'db_connect' not implemented"))
 
 
 def test_db_disconnect(db_base):
@@ -62,6 +65,57 @@ def test_del_one(db_base):
     assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
 
 
+class TestEncryption(unittest.TestCase):
+    def setUp(self):
+        master_password = "Setting a long master password with numbers 123 and capitals AGHBNHD and symbols %&8)!'"
+        db_base1 = DbBase(master_password=master_password)
+        db_base2 = DbBase()
+        # set self.secret_key obtained when connect
+        db_base1.secret_key = DbBase._join_passwords(urandom(32), db_base1.master_password)
+        db_base2.secret_key = DbBase._join_passwords(urandom(32), db_base2.master_password)
+        self.db_base = [db_base1, db_base2]
+
+    def test_encrypt_decrypt(self):
+        TEST = (
+            ("plain text 1 ! ", None),
+            ("plain text 2 with salt ! ", "1afd5d1a-4a7e-4d9c-8c65-251290183106"),
+            ("plain text 3 with usalt ! ", u"1afd5d1a-4a7e-4d9c-8c65-251290183106"),
+            (u"plain unicode 4 ! ", None),
+            (u"plain unicode 5 with salt ! ", "1a000d1a-4a7e-4d9c-8c65-251290183106"),
+            (u"plain unicode 6 with usalt ! ", u"1abcdd1a-4a7e-4d9c-8c65-251290183106"),
+        )
+        for db_base in self.db_base:
+            for value, salt in TEST:
+                # no encryption
+                encrypted = db_base.encrypt(value, schema_version='1.0', salt=salt)
+                self.assertEqual(encrypted, value, "value '{}' has been encrypted".format(value))
+                decrypted = db_base.decrypt(encrypted, schema_version='1.0', salt=salt)
+                self.assertEqual(decrypted, value, "value '{}' has been decrypted".format(value))
+
+                # encrypt/decrypt
+                encrypted = db_base.encrypt(value, schema_version='1.1', salt=salt)
+                self.assertNotEqual(encrypted, value, "value '{}' has not been encrypted".format(value))
+                self.assertIsInstance(encrypted, str, "Encrypted is not ascii text")
+                decrypted = db_base.decrypt(encrypted, schema_version='1.1', salt=salt)
+                self.assertEqual(decrypted, value, "value is not equal after encryption/decryption")
+
+    def test_encrypt_decrypt_salt(self):
+        value = "value to be encrypted!"
+        encrypted = []
+        for db_base in self.db_base:
+            for salt in (None, "salt 1", "1afd5d1a-4a7e-4d9c-8c65-251290183106"):
+                # encrypt/decrypt
+                encrypted.append(db_base.encrypt(value, schema_version='1.1', salt=salt))
+                self.assertNotEqual(encrypted[-1], value, "value '{}' has not been encrypted".format(value))
+                self.assertIsInstance(encrypted[-1], str, "Encrypted is not ascii text")
+                decrypted = db_base.decrypt(encrypted[-1], schema_version='1.1', salt=salt)
+                self.assertEqual(decrypted, value, "value is not equal after encryption/decryption")
+        for i in range(0, len(encrypted)):
+            for j in range(i+1, len(encrypted)):
+                self.assertNotEqual(encrypted[i], encrypted[j],
+                                    "encryption with different salt contains different result")
+
+
 class TestDeepUpdate(unittest.TestCase):
     def test_update_dict(self):
         # Original, patch, expected result
@@ -161,3 +215,7 @@ class TestDeepUpdate(unittest.TestCase):
                 deep_update(t[0], t[1])
             except DbException as e:
                 print(e)
+
+
+if __name__ == '__main__':
+    unittest.main()