From c5297e4f2313738e2c8df4902339d647f9ada75f Mon Sep 17 00:00:00 2001 From: tierno Date: Wed, 11 Dec 2019 12:32:41 +0000 Subject: [PATCH] Fix 976. Get serial key after database is inited Change-Id: Ic6692c5eabdb3ff7d8b1a7fc6501321dc69ea43a Signed-off-by: tierno --- osm_common/dbbase.py | 10 ++++++++++ osm_common/dbmongo.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/osm_common/dbbase.py b/osm_common/dbbase.py index 95250c1..7428ed9 100644 --- a/osm_common/dbbase.py +++ b/osm_common/dbbase.py @@ -212,6 +212,13 @@ class DbBase(object): self.secret_key = None self.secret_key = self._join_secret_key(new_secret_key) + def get_secret_key(self): + """ + Get the database secret key in case it is not done when "connect" is called. It can happens when database is + empty after an initial install. It should skip if secret is already obtained. + """ + pass + def encrypt(self, value, schema_version=None, salt=None): """ Encrypt a value @@ -221,6 +228,7 @@ class DbBase(object): :param salt: optional salt to be used. Must be str :return: Encrypted content of value """ + self.get_secret_key() if not self.secret_key or not schema_version or schema_version == '1.0': return value else: @@ -240,6 +248,7 @@ class DbBase(object): :param salt: optional salt to be used :return: Plain content of value """ + self.get_secret_key() if not self.secret_key or not schema_version or schema_version == '1.0': return value else: @@ -257,6 +266,7 @@ class DbBase(object): def encrypt_decrypt_fields(self, item, action, fields=None, flags=re.I, schema_version=None, salt=None): if not fields: return + self.get_secret_key() actions = ['encrypt', 'decrypt'] if action.lower() not in actions: raise DbException("Unknown action ({}): Must be one of {}".format(action, actions), diff --git a/osm_common/dbmongo.py b/osm_common/dbmongo.py index 86ec7b7..6eb5ef5 100644 --- a/osm_common/dbmongo.py +++ b/osm_common/dbmongo.py @@ -66,6 +66,22 @@ class DbMongo(DbBase): super().__init__(logger_name, lock) self.client = None self.db = None + self.database_key = None + self.secret_obtained = False + # ^ This is used to know if database serial has been got. Database is inited by NBI, who generates the serial + # In case it is not ready when connected, it should be got later on before any decrypt operation + + def get_secret_key(self): + if self.secret_obtained: + return + + self.secret_key = None + if self.database_key: + self.set_secret_key(self.database_key) + version_data = self.get_one("admin", {"_id": "version"}, fail_on_empty=False, fail_on_more=True) + if version_data and version_data.get("serial"): + self.set_secret_key(b64decode(version_data["serial"])) + self.secret_obtained = True def db_connect(self, config, target_version=None): """ @@ -79,6 +95,7 @@ class DbMongo(DbBase): self.logger = logging.getLogger(config["logger_name"]) master_key = config.get("commonkey") or config.get("masterpassword") if master_key: + self.database_key = master_key self.set_secret_key(master_key) if config.get("uri"): self.client = MongoClient(config["uri"]) @@ -104,6 +121,7 @@ class DbMongo(DbBase): raise DbException("Invalid database version {}. Expected {}".format(db_version, target_version)) # get serial if version_data and version_data.get("serial"): + self.secret_obtained = True self.set_secret_key(b64decode(version_data["serial"])) self.logger.info("Connected to database {} version {}".format(config["name"], db_version)) return -- 2.17.1