X-Git-Url: https://osm.etsi.org/gitweb/?a=blobdiff_plain;f=osm_common%2Fdbmongo.py;h=2e94a5a989d8af3e5a2b9ab0cca32278827a9c75;hb=cfc5272864156b706d7147fc4e7c0fe46dc386c8;hp=9b5bc57e8c6bfac246e5caec44b9273acf1e4685;hpb=87858cab98b3b169fc891fd2e0a0ba10f8b46127;p=osm%2Fcommon.git diff --git a/osm_common/dbmongo.py b/osm_common/dbmongo.py index 9b5bc57..2e94a5a 100644 --- a/osm_common/dbmongo.py +++ b/osm_common/dbmongo.py @@ -22,6 +22,7 @@ from osm_common.dbbase import DbException, DbBase from http import HTTPStatus from time import time, sleep from copy import deepcopy +from base64 import b64decode __author__ = "Alfonso Tierno " @@ -61,21 +62,25 @@ class DbMongo(DbBase): conn_initial_timout = 120 conn_timout = 10 - def __init__(self, logger_name='db', master_password=None): - super().__init__(logger_name, master_password) + def __init__(self, logger_name='db'): + super().__init__(logger_name) self.client = None self.db = None - def db_connect(self, config): + def db_connect(self, config, target_version=None): """ Connect to database :param config: Configuration of database + :param target_version: if provided it checks if database contains required version, raising exception otherwise. :return: None or raises DbException on error """ try: if "logger_name" in config: self.logger = logging.getLogger(config["logger_name"]) + self.master_password = config.get("masterpassword") self.client = MongoClient(config["host"], config["port"]) + # TODO add as parameters also username=config.get("user"), password=config.get("password")) + # when all modules are ready self.db = self.client[config["name"]] if "loglevel" in config: self.logger.setLevel(getattr(logging, config['loglevel'])) @@ -83,7 +88,19 @@ class DbMongo(DbBase): now = time() while True: try: - self.db.users.find_one({"username": "admin"}) + version_data = self.get_one("admin", {"_id": "version"}, fail_on_empty=False, fail_on_more=True) + # check database status is ok + if version_data and version_data.get("status") != 'ENABLED': + raise DbException("Wrong database status '{}'".format(version_data.get("status")), + http_code=HTTPStatus.INTERNAL_SERVER_ERROR) + # check version + db_version = None if not version_data else version_data.get("version") + if target_version and target_version != db_version: + raise DbException("Invalid database version {}. Expected {}".format(db_version, target_version)) + # get serial + if version_data and version_data.get("serial"): + self.set_secret_key(b64decode(version_data["serial"])) + self.logger.info("Connected to database {} version {}".format(config["name"], db_version)) return except errors.ConnectionFailure as e: if time() - now >= self.conn_initial_timout: