Feature 10239: Distributed VCA
- Add vca_id to all calls that invoke libjuju, so that it is possible to
talk either to the default VCA or to the VCA associated with the VIM
- Add store.py: abstraction to talk to the database.
- DbMongoStore: uses the db from osm_common to talk to the database
- MotorStore: uses motor, an asynchronous MongoDB client, to talk to the
database
- Add vca/connection.py: represents the data needed to connect to the VCA
- Add EnvironConfig in config.py: class to read the configuration from the
environment, so that LCM does not have to pass it
Change-Id: I28625e0c56ce408114022c83d4b7cacbb649434c
Signed-off-by: David Garcia <david.garcia@canonical.com>
diff --git a/n2vc/store.py b/n2vc/store.py
new file mode 100644
index 0000000..b827d51
--- /dev/null
+++ b/n2vc/store.py
@@ -0,0 +1,390 @@
+# Copyright 2021 Canonical Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import asyncio
+from base64 import b64decode
+import re
+import typing
+
+from Crypto.Cipher import AES
+from motor.motor_asyncio import AsyncIOMotorClient
+from n2vc.config import EnvironConfig
+from n2vc.vca.connection_data import ConnectionData
+from osm_common.dbmongo import DbMongo, DbException
+
# Name of the MongoDB database whose collections MotorStore accesses
DB_NAME = "osm"
+
+
class Store(abc.ABC):
    """Abstract interface to read and write VCA information in the database."""

    @abc.abstractmethod
    async def get_vca_connection_data(self, vca_id: str) -> "ConnectionData":
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """

    @abc.abstractmethod
    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If None, the endpoints of the default VCA
                        are updated
        """

    @abc.abstractmethod
    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If None, the default VCA is used

        :returns: List of endpoints
        """

    @abc.abstractmethod
    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA id for a VIM account

        :param: vim_id: Vim account ID

        :returns: VCA ID associated with the VIM account, or None
        """
+
+
class DbMongoStore(Store):
    """Store implementation on top of the osm_common DbMongo client."""

    def __init__(self, db: DbMongo):
        """
        Constructor

        :param: db: osm_common.dbmongo.DbMongo object
        """
        self.db = db

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        The "secret" and "cacert" fields are decrypted before building the
        result.

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """
        data = self.db.get_one("vca", q_filter={"_id": vca_id})
        self.db.encrypt_decrypt_fields(
            data,
            "decrypt",
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If None, the endpoints of the default VCA are
                        updated (document "juju" of the admin collection)
        """
        if vca_id:
            data = self.db.get_one("vca", q_filter={"_id": vca_id})
            data["endpoints"] = endpoints
            self._update("vca", vca_id, data)
        else:
            # The default VCA. Its endpoints live in the "juju" document of the
            # admin collection (where _get_juju_info reads them from), not in
            # the vca collection.
            juju_info = self._get_juju_info()
            # If it doesn't exist, then create it
            if not juju_info:
                try:
                    self.db.create(
                        "admin",
                        {"_id": "juju"},
                    )
                except DbException as e:
                    # Racing condition: check if another N2VC worker has created it
                    juju_info = self._get_juju_info()
                    if not juju_info:
                        raise e
            self.db.set_one(
                "admin",
                {"_id": "juju"},
                {"api_endpoints": endpoints},
            )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If None, the endpoints of the default VCA
                        are returned

        :returns: List of endpoints (empty if none are stored)
        """
        endpoints = []
        if vca_id:
            # get_vca_connection_data is a coroutine function: its result must
            # be awaited before accessing its attributes.
            endpoints = (await self.get_vca_connection_data(vca_id)).endpoints
        else:
            juju_info = self._get_juju_info()
            if juju_info and "api_endpoints" in juju_info:
                endpoints = juju_info["api_endpoints"]
        return endpoints

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: VCA ID associated with the VIM account, or None if vim_id is
                  empty, the VIM account does not exist or it has no VCA
        """
        if not vim_id:
            return None
        vim_account = self.db.get_one(
            "vim_accounts",
            q_filter={"_id": vim_id},
            fail_on_empty=False,
        )
        # With fail_on_empty=False a missing VIM account yields a falsy result
        return vim_account.get("vca") if vim_account else None

    def _update(self, collection: str, id: str, data: dict):
        """
        Update object in database

        :param: collection: Collection name
        :param: id: ID of the object
        :param: data: Object data (replaces the stored object)
        """
        self.db.replace(
            collection,
            id,
            data,
        )

    def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return self.db.get_one(
            "admin",
            q_filter={"_id": "juju"},
            fail_on_empty=False,
        )
+
+
class MotorStore(Store):
    """Store implementation on top of motor, an asynchronous MongoDB client."""

    def __init__(self, uri: str, loop=None):
        """
        Constructor

        :param: uri: Connection string to connect to the database.
        :param: loop: Asyncio Loop
        """
        self._client = AsyncIOMotorClient(uri)
        self.loop = loop or asyncio.get_event_loop()
        # Computed lazily by the secret_key property
        self._secret_key = None
        self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])

    @property
    def _database(self):
        return self._client[DB_NAME]

    @property
    def _vca_collection(self):
        return self._database["vca"]

    @property
    def _admin_collection(self):
        return self._database["admin"]

    @property
    def _vim_accounts_collection(self):
        return self._database["vim_accounts"]

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        The "secret" and "cacert" fields are decrypted before building the
        result.

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database

        :raises: Exception if the VCA does not exist
        """
        data = await self._vca_collection.find_one({"_id": vca_id})
        if not data:
            raise Exception("vca with id {} not found".format(vca_id))
        await self.decrypt_fields(
            data,
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If None, the endpoints of the default VCA are
                        updated (document "juju" of the admin collection)

        :raises: Exception if vca_id is given and the VCA does not exist
        """
        if vca_id:
            # An atomic $set avoids a read-modify-write race with other writers
            # and leaves the rest of the VCA document untouched.
            result = await self._vca_collection.update_one(
                {"_id": vca_id}, {"$set": {"endpoints": endpoints}}
            )
            if not result.matched_count:
                raise Exception("vca with id {} not found".format(vca_id))
        else:
            # The default VCA. Data for the endpoints is in a different place
            juju_info = await self._get_juju_info()
            # If it doesn't exist, then create it
            if not juju_info:
                try:
                    await self._admin_collection.insert_one({"_id": "juju"})
                except Exception as e:
                    # Racing condition: check if another N2VC worker has created it
                    juju_info = await self._get_juju_info()
                    if not juju_info:
                        raise e
            # $set keeps any other field of the "juju" document; replace_one
            # would discard them.
            await self._admin_collection.update_one(
                {"_id": "juju"}, {"$set": {"api_endpoints": endpoints}}
            )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If None, the endpoints of the default VCA
                        are returned

        :returns: List of endpoints (empty if none are stored)
        """
        endpoints = []
        if vca_id:
            endpoints = (await self.get_vca_connection_data(vca_id)).endpoints
        else:
            juju_info = await self._get_juju_info()
            if juju_info and "api_endpoints" in juju_info:
                endpoints = juju_info["api_endpoints"]
        return endpoints

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: VCA ID associated with the VIM account, or None if vim_id is
                  empty, the VIM account does not exist or it has no VCA
        """
        vca_id = None
        if vim_id:
            vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
            if vim_account and "vca" in vim_account:
                vca_id = vim_account["vca"]
        return vca_id

    async def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return await self._admin_collection.find_one({"_id": "juju"})

    # DECRYPT METHODS
    async def decrypt_fields(
        self,
        item: dict,
        fields: typing.List[str],
        schema_version: str = None,
        salt: str = None,
    ):
        """
        Decrypt fields

        Decrypt fields from a dictionary, in place. Nested dictionaries and
        lists are traversed recursively; every string value whose key matches
        (case-insensitive regex search) any entry of `fields` is decrypted.
        Follows the same logic as in osm_common.

        :param: item: Dictionary with the keys to be decrypted
        :param: fields: List of keys (regex patterns) to decrypt
        :param: schema_version: Schema version. (i.e. 1.11)
        :param: salt: Salt for the decryption
        """
        flags = re.I

        async def process(_item):
            if isinstance(_item, list):
                for elem in _item:
                    await process(elem)
            elif isinstance(_item, dict):
                for key, val in _item.items():
                    if isinstance(val, str):
                        if any(re.search(f, key, flags) for f in fields):
                            _item[key] = await self.decrypt(val, schema_version, salt)
                    else:
                        await process(val)

        await process(item)

    async def decrypt(self, value, schema_version=None, salt=None):
        """
        Decrypt an encrypted value

        :param value: value to be decrypted. It is a base64 string
        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
                               If '1.1' symmetric AES encryption has been done
        :param salt: optional salt to be used
        :return: Plain content of value (returned unchanged if it is not
                 encrypted or no secret key is available)
        :raises DbException: if the decrypted bytes are not valid UTF-8
        """
        # Check the cheap conditions first so that unencrypted values never
        # trigger the secret key lookup (environment + database access).
        if not schema_version or schema_version == "1.0":
            return value
        if not await self.secret_key:
            return value
        secret_key = self._join_secret_key(salt)
        encrypted_msg = b64decode(value)
        # pycrypto's AES.new defaults to ECB; pycryptodome requires the mode to
        # be given explicitly. Passing it works with both libraries.
        cipher = AES.new(secret_key, AES.MODE_ECB)
        decrypted_msg = cipher.decrypt(encrypted_msg)
        try:
            # Encrypted values are padded with NUL characters
            unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
        except UnicodeDecodeError:
            raise DbException(
                "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                http_code=500,
            )
        return unpadded_private_msg

    def _join_secret_key(self, update_key: typing.Any):
        """
        Join secret key

        XOR the bytes of `update_key` into a 32-byte copy of the current
        secret key (or into 32 zero bytes if there is none yet).

        :param: update_key: str or bytes with the key to join
        :returns: the resulting 32-byte key (self._secret_key is not modified)
        """
        if isinstance(update_key, str):
            update_key_bytes = update_key.encode()
        else:
            update_key_bytes = update_key
        new_secret_key = (
            bytearray(self._secret_key) if self._secret_key else bytearray(32)
        )
        for i, b in enumerate(update_key_bytes):
            new_secret_key[i % 32] ^= b
        return bytes(new_secret_key)

    @property
    async def secret_key(self):
        """
        Secret key used for decryption (awaitable property), computed lazily.

        Built by joining the database common key from the environment config
        and the "serial" of the "version" document in the admin collection.
        Evaluates to None when neither of them is available.
        """
        if self._secret_key:
            return self._secret_key
        else:
            if self.database_key:
                self._secret_key = self._join_secret_key(self.database_key)
            version_data = await self._admin_collection.find_one({"_id": "version"})
            if version_data and version_data.get("serial"):
                self._secret_key = self._join_secret_key(
                    b64decode(version_data["serial"])
                )
            return self._secret_key

    @property
    def database_key(self):
        """Database common key, looked up in the environment config."""
        # NOTE(review): the exact environment variable name depends on how
        # EnvironConfig maps prefixed variables to keys — confirm.
        return self._config["database_commonkey"]