| # Copyright 2021 Canonical Ltd. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import abc |
| import asyncio |
| from base64 import b64decode |
| import re |
| import typing |
| |
| from Crypto.Cipher import AES |
| from motor.motor_asyncio import AsyncIOMotorClient |
| from n2vc.config import EnvironConfig |
| from n2vc.vca.connection_data import ConnectionData |
| from osm_common.dbmongo import DbMongo, DbException |
| |
# Name of the database that MotorStore selects on its Motor client
| DB_NAME = "osm" |
| |
| |
class Store(abc.ABC):
    """
    Abstract interface of the stores that hold VCA information

    Every method is a coroutine so that implementations are free to perform
    asynchronous I/O.
    """

    @abc.abstractmethod
    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """

    @abc.abstractmethod
    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: typing.Optional[str] = None
    ):
        """
        Update VCA endpoints

        The signature mirrors the concrete stores of this module: the first
        parameter is named "endpoints" and the VCA ID is optional.

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. The stores of this module update the default
                        VCA when it is not given.
        """

    @abc.abstractmethod
    async def get_vca_endpoints(
        self, vca_id: typing.Optional[str] = None
    ) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID

        :returns: List of endpoints
        """

    @abc.abstractmethod
    async def get_vca_id(
        self, vim_id: typing.Optional[str] = None
    ) -> typing.Optional[str]:
        """
        Get VCA id for a VIM account

        :param: vim_id: Vim account ID

        :returns: VCA ID, or None when no VCA can be resolved
        """
| |
| |
class DbMongoStore(Store):
    """Store implementation that works on top of an osm_common DbMongo object"""

    def __init__(self, db: DbMongo):
        """
        Constructor

        :param: db: osm_common.dbmongo.DbMongo object
        """
        self.db = db

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        Reads the document of the VCA from the "vca" collection and asks the
        database object to decrypt its "secret" and "cacert" fields, using the
        document ID as salt.

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """
        data = self.db.get_one("vca", q_filter={"_id": vca_id})
        self.db.encrypt_decrypt_fields(
            data,
            "decrypt",
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not given, the endpoints of the default VCA
                        (document with ID "juju") are updated.
        """
        if vca_id:
            data = self.db.get_one("vca", q_filter={"_id": vca_id})
            data["endpoints"] = endpoints
            self._update("vca", vca_id, data)
            return

        # The default VCA. Its endpoints live in the "juju" document, under
        # the "api_endpoints" key. Create the document if it does not exist.
        if not self._get_juju_info():
            try:
                self.db.create(
                    "vca",
                    {"_id": "juju"},
                )
            except DbException:
                # Race condition: check if another N2VC worker has created it
                if not self._get_juju_info():
                    raise
        self.db.set_one(
            "vca",
            {"_id": "juju"},
            {"api_endpoints": endpoints},
        )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not given, the endpoints of the default VCA
                        are returned.

        :returns: List of endpoints (empty if the default VCA has none stored)
        """
        if vca_id:
            # get_vca_connection_data is a coroutine: it has to be awaited
            # before reading the attribute of its result
            return (await self.get_vca_connection_data(vca_id)).endpoints

        juju_info = self._get_juju_info()
        if juju_info and "api_endpoints" in juju_info:
            return juju_info["api_endpoints"]
        return []

    async def get_vca_id(self, vim_id: str = None) -> typing.Optional[str]:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: Value of the "vca" key of the VIM account, or None if no VIM
                  ID is given, the account is not found or it has no VCA
        """
        if not vim_id:
            return None
        vim_account = self.db.get_one(
            "vim_accounts",
            q_filter={"_id": vim_id},
            fail_on_empty=False,
        )
        # With fail_on_empty=False a missing document yields a falsy result
        # (as already assumed for the juju document), so guard before .get()
        return vim_account.get("vca") if vim_account else None

    def _update(self, collection: str, id: str, data: dict):
        """
        Update object in database

        The whole object is replaced by the given data.

        :param: collection: Collection name
        :param: id: ID of the object
        :param: data: Object data
        """
        self.db.replace(
            collection,
            id,
            data,
        )

    def _get_juju_info(self):
        """
        Get Juju information (the default VCA) from the vca collection

        NOTE(review): MotorStore keeps this document in the "admin" collection
        while this store uses "vca" - confirm which one is intended.

        :returns: The document with ID "juju", or a falsy value if missing
        """
        return self.db.get_one(
            "vca",
            q_filter={"_id": "juju"},
            fail_on_empty=False,
        )
| |
| |
class MotorStore(Store):
    """Store implementation that talks to MongoDB through the Motor client"""

    def __init__(self, uri: str, loop=None):
        """
        Constructor

        :param: uri: Connection string to connect to the database.
        :param: loop: Asyncio Loop
        """
        self._client = AsyncIOMotorClient(uri)
        self.loop = loop or asyncio.get_event_loop()
        # Cached by the secret_key property the first time a key is resolved
        self._secret_key = None
        self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])

    @property
    def _database(self):
        """Database with name DB_NAME of the Motor client"""
        return self._client[DB_NAME]

    @property
    def _vca_collection(self):
        """Collection "vca" of the database"""
        return self._database["vca"]

    @property
    def _admin_collection(self):
        """Collection "admin" of the database"""
        return self._database["admin"]

    @property
    def _vim_accounts_collection(self):
        """Collection "vim_accounts" of the database"""
        return self._database["vim_accounts"]

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database

        :raises: Exception if there is no VCA with that ID
        """
        data = await self._vca_collection.find_one({"_id": vca_id})
        if not data:
            raise Exception("vca with id {} not found".format(vca_id))
        await self.decrypt_fields(
            data,
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not given, the endpoints of the default VCA
                        (document "juju" of the admin collection) are updated.

        :raises: Exception if a VCA ID is given but no such VCA exists
        """
        if vca_id:
            data = await self._vca_collection.find_one({"_id": vca_id})
            if not data:
                # Same error as get_vca_connection_data, instead of a
                # TypeError when subscripting None
                raise Exception("vca with id {} not found".format(vca_id))
            data["endpoints"] = endpoints
            await self._vca_collection.replace_one({"_id": vca_id}, data)
            return

        # The default VCA. Data for the endpoints is in a different place:
        # the "juju" document of the admin collection. Create it if missing.
        if not await self._get_juju_info():
            try:
                await self._admin_collection.insert_one({"_id": "juju"})
            except Exception:
                # Race condition: check if another N2VC worker has created it
                if not await self._get_juju_info():
                    raise

        # $set only touches "api_endpoints"; replacing the document would
        # drop any other field stored in it
        await self._admin_collection.update_one(
            {"_id": "juju"}, {"$set": {"api_endpoints": endpoints}}
        )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not given, the endpoints of the default VCA
                        are returned.

        :returns: List of endpoints (empty if the default VCA has none stored)
        """
        if vca_id:
            return (await self.get_vca_connection_data(vca_id)).endpoints

        juju_info = await self._get_juju_info()
        if juju_info and "api_endpoints" in juju_info:
            return juju_info["api_endpoints"]
        return []

    async def get_vca_id(self, vim_id: str = None) -> typing.Optional[str]:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: Value of the "vca" key of the VIM account, or None if no VIM
                  ID is given, the account is not found or it has no VCA
        """
        if not vim_id:
            return None
        vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
        if vim_account and "vca" in vim_account:
            return vim_account["vca"]
        return None

    async def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return await self._admin_collection.find_one({"_id": "juju"})

    # DECRYPT METHODS
    async def decrypt_fields(
        self,
        item: dict,
        fields: typing.List[str],
        schema_version: str = None,
        salt: str = None,
    ):
        """
        Decrypt fields

        Decrypt fields from a dictionary. Follows the same logic as in osm_common.
        The item is modified in place; nested lists and dictionaries are
        walked recursively.

        :param: item: Dictionary with the keys to be decrypted
        :param: fields: List of keys to decrypt. Each entry is used as a
                        case-insensitive regular expression searched in the key.
        :param: schema version: Schema version. (i.e. 1.11)
        :param: salt: Salt for the decryption
        """
        # Compile the patterns once instead of on every visited key
        patterns = [re.compile(field, re.I) for field in fields]

        async def process(_item):
            if isinstance(_item, list):
                for elem in _item:
                    await process(elem)
            elif isinstance(_item, dict):
                # Only values of existing keys are reassigned, so iterating
                # over items() while updating is safe
                for key, val in _item.items():
                    if isinstance(val, str):
                        if any(pattern.search(key) for pattern in patterns):
                            _item[key] = await self.decrypt(val, schema_version, salt)
                    else:
                        await process(val)

        await process(item)

    async def decrypt(self, value, schema_version=None, salt=None):
        """
        Decrypt an encrypted value
        :param value: value to be decrypted. It is a base64 string
        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
               If '1.1' symmetric AES encryption has been done
        :param salt: optional salt to be used
        :return: Plain content of value
        :raises: DbException if the decrypted bytes cannot be decoded as text
        """
        if not await self.secret_key or not schema_version or schema_version == "1.0":
            return value

        secret_key = self._join_secret_key(salt)
        encrypted_msg = b64decode(value)
        # The mode is given explicitly: PyCryptodome has no default mode,
        # and ECB is what the legacy PyCrypto default resolved to
        cipher = AES.new(secret_key, AES.MODE_ECB)
        decrypted_msg = cipher.decrypt(encrypted_msg)
        try:
            # Trailing NUL characters are padding, not part of the message
            return decrypted_msg.decode().rstrip("\0")
        except UnicodeDecodeError as error:
            raise DbException(
                "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                http_code=500,
            ) from error

    def _join_secret_key(self, update_key: typing.Any):
        """
        Join secret key

        XORs the bytes of update_key, cyclically, over a copy of the current
        32-byte secret key (all zeros if there is no secret key yet). The
        stored secret key is not modified.

        :param: update_key: str or bytes with the to update. None is treated
                            as an empty key, so the current key is returned.

        :returns: bytes with the joined key
        """
        if update_key is None:
            update_key_bytes = b""
        elif isinstance(update_key, str):
            update_key_bytes = update_key.encode()
        else:
            update_key_bytes = update_key
        new_secret_key = (
            bytearray(self._secret_key) if self._secret_key else bytearray(32)
        )
        for i, b in enumerate(update_key_bytes):
            new_secret_key[i % 32] ^= b
        return bytes(new_secret_key)

    @property
    async def secret_key(self):
        """
        Secret key used for decryption (awaitable property)

        The first access derives the key from the configured common key and,
        if present, from the "serial" stored in the "version" document of the
        admin collection. A resolved key is cached for later accesses.

        :returns: bytes with the secret key, or None if it cannot be derived
        """
        if self._secret_key:
            return self._secret_key

        if self.database_key:
            self._secret_key = self._join_secret_key(self.database_key)
        version_data = await self._admin_collection.find_one({"_id": "version"})
        if version_data and version_data.get("serial"):
            # The serial is joined on top of the key derived above
            self._secret_key = self._join_secret_key(
                b64decode(version_data["serial"])
            )
        return self._secret_key

    @property
    def database_key(self):
        """Value of "database_commonkey" in the environment configuration"""
        return self._config["database_commonkey"]