# blob: b827d5123ad2b07a40f795e1d0aa531296562202 [file] [log] [blame]
# Copyright 2021 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import asyncio
from base64 import b64decode
import re
import typing
from Crypto.Cipher import AES
from motor.motor_asyncio import AsyncIOMotorClient
from n2vc.config import EnvironConfig
from n2vc.vca.connection_data import ConnectionData
from osm_common.dbmongo import DbMongo, DbException
# Name of the MongoDB database that MotorStore opens through its client.
DB_NAME: str = "osm"
class Store(abc.ABC):
    """Abstract interface to read and write VCA information in the OSM database."""

    @abc.abstractmethod
    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID
        :returns: ConnectionData with the information of the database
        """

    @abc.abstractmethod
    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. When not given, the endpoints of the default
                        VCA are updated instead.
        """

    @abc.abstractmethod
    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. When not given, the endpoints of the default
                        VCA are returned.
        :returns: List of endpoints
        """

    @abc.abstractmethod
    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA id for a VIM account

        :param: vim_id: Vim account ID
        :returns: VCA ID associated to the VIM account, or None if there is none
        """
class DbMongoStore(Store):
    """Store implementation on top of a (synchronous) osm_common DbMongo object."""

    def __init__(self, db: DbMongo):
        """
        Constructor

        :param: db: osm_common.dbmongo.DbMongo object
        """
        self.db = db

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID
        :returns: ConnectionData with the information of the database
        """
        data = self.db.get_one("vca", q_filter={"_id": vca_id})
        # Secrets are stored encrypted; the document id is used as salt.
        self.db.encrypt_decrypt_fields(
            data,
            "decrypt",
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. When not given, the endpoints of the default
                        VCA (the "juju" document) are updated.
        """
        if vca_id:
            data = self.db.get_one("vca", q_filter={"_id": vca_id})
            data["endpoints"] = endpoints
            self._update("vca", vca_id, data)
        else:
            # The default VCA. Its endpoints live in a dedicated "juju" document
            # (key "api_endpoints") instead of a regular VCA entry.
            # NOTE(review): MotorStore keeps this document in the "admin"
            # collection, while this class uses "vca" — confirm which is intended.
            juju_info = self._get_juju_info()
            # If the document does not exist yet, create it
            if not juju_info:
                try:
                    self.db.create("vca", {"_id": "juju"})
                except DbException:
                    # Race condition: another N2VC worker may have created it
                    # between the lookup above and the create call.
                    if not self._get_juju_info():
                        raise
            self.db.set_one(
                "vca",
                {"_id": "juju"},
                {"api_endpoints": endpoints},
            )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. When not given, the endpoints of the default
                        VCA (the "juju" document) are returned.
        :returns: List of endpoints (empty if none are stored)
        """
        endpoints = []
        if vca_id:
            # get_vca_connection_data is a coroutine: it must be awaited before
            # accessing the resulting ConnectionData.
            endpoints = (await self.get_vca_connection_data(vca_id)).endpoints
        else:
            juju_info = self._get_juju_info()
            if juju_info and "api_endpoints" in juju_info:
                endpoints = juju_info["api_endpoints"]
        return endpoints

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID
        :returns: VCA ID of the VIM account, or None if vim_id is empty, the
                  VIM account is not found, or it has no "vca" field
        """
        if not vim_id:
            return None
        vim_account = self.db.get_one(
            "vim_accounts",
            q_filter={"_id": vim_id},
            fail_on_empty=False,
        )
        # With fail_on_empty=False no document may be returned, so guard before
        # calling .get() on it.
        return vim_account.get("vca") if vim_account else None

    def _update(self, collection: str, id: str, data: dict):
        """
        Replace an object in the database

        :param: collection: Collection name
        :param: id: ID of the object
        :param: data: Object data (replaces the stored object)
        """
        self.db.replace(
            collection,
            id,
            data,
        )

    def _get_juju_info(self):
        """
        Get Juju information (the default VCA) from the "vca" collection

        :returns: the "juju" document, or a falsy value if it does not exist
                  (lookup uses fail_on_empty=False)
        """
        return self.db.get_one(
            "vca",
            q_filter={"_id": "juju"},
            fail_on_empty=False,
        )
class MotorStore(Store):
    """Store implementation using an asynchronous Motor (MongoDB) client."""

    def __init__(self, uri: str, loop=None):
        """
        Constructor

        :param: uri: Connection string to connect to the database.
        :param: loop: Asyncio Loop
        """
        self._client = AsyncIOMotorClient(uri)
        self.loop = loop or asyncio.get_event_loop()
        # Lazily computed by the `secret_key` property
        self._secret_key = None
        self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])

    @property
    def _database(self):
        return self._client[DB_NAME]

    @property
    def _vca_collection(self):
        return self._database["vca"]

    @property
    def _admin_collection(self):
        return self._database["admin"]

    @property
    def _vim_accounts_collection(self):
        return self._database["vim_accounts"]

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID
        :returns: ConnectionData with the information of the database
        :raises: Exception if the VCA does not exist
        """
        data = await self._vca_collection.find_one({"_id": vca_id})
        if not data:
            raise Exception("vca with id {} not found".format(vca_id))
        # Secrets are stored encrypted; the document id is used as salt.
        await self.decrypt_fields(
            data,
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. When not given, the endpoints of the default
                        VCA (the "juju" document) are updated.
        :raises: Exception if vca_id is given and the VCA does not exist
        """
        if vca_id:
            data = await self._vca_collection.find_one({"_id": vca_id})
            if not data:
                raise Exception("vca with id {} not found".format(vca_id))
            data["endpoints"] = endpoints
            await self._vca_collection.replace_one({"_id": vca_id}, data)
        else:
            # The default VCA. Its endpoints live in a dedicated "juju" document
            # of the admin collection (key "api_endpoints").
            juju_info = await self._get_juju_info()
            # If the document does not exist yet, create it
            if not juju_info:
                try:
                    await self._admin_collection.insert_one({"_id": "juju"})
                except Exception:
                    # Race condition: another N2VC worker may have created it
                    # between the lookup above and the insert.
                    if not await self._get_juju_info():
                        raise
            # $set only touches "api_endpoints"; replace_one would wipe any other
            # field of the document (DbMongoStore uses set_one for the same reason).
            await self._admin_collection.update_one(
                {"_id": "juju"}, {"$set": {"api_endpoints": endpoints}}
            )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. When not given, the endpoints of the default
                        VCA (the "juju" document) are returned.
        :returns: List of endpoints (empty if none are stored)
        """
        endpoints = []
        if vca_id:
            endpoints = (await self.get_vca_connection_data(vca_id)).endpoints
        else:
            juju_info = await self._get_juju_info()
            if juju_info and "api_endpoints" in juju_info:
                endpoints = juju_info["api_endpoints"]
        return endpoints

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID
        :returns: VCA ID of the VIM account, or None if vim_id is empty, the
                  VIM account is not found, or it has no "vca" field
        """
        vca_id = None
        if vim_id:
            vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
            if vim_account and "vca" in vim_account:
                vca_id = vim_account["vca"]
        return vca_id

    async def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return await self._admin_collection.find_one({"_id": "juju"})

    # DECRYPT METHODS

    async def decrypt_fields(
        self,
        item: dict,
        fields: typing.List[str],
        schema_version: str = None,
        salt: str = None,
    ):
        """
        Decrypt fields

        Decrypt fields from a dictionary, in place. Follows the same logic as in
        osm_common: every string value whose key matches (case-insensitive
        regex search) any entry of `fields` is decrypted; nested lists and
        dicts are walked recursively.

        :param: item: Dictionary with the keys to be decrypted
        :param: fields: List of keys (regex patterns) to decrypt
        :param: schema_version: Schema version. (i.e. 1.11)
        :param: salt: Salt for the decryption
        """
        flags = re.I

        async def process(_item):
            if isinstance(_item, list):
                for elem in _item:
                    await process(elem)
            elif isinstance(_item, dict):
                for key, val in _item.items():
                    if isinstance(val, str):
                        if any(re.search(f, key, flags) for f in fields):
                            # Overwriting an existing key does not change the
                            # dict size, so it is safe while iterating.
                            _item[key] = await self.decrypt(val, schema_version, salt)
                    else:
                        await process(val)

        await process(item)

    async def decrypt(self, value, schema_version=None, salt=None):
        """
        Decrypt an encrypted value

        :param value: value to be decrypted. It is a base64 string
        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
                               If '1.1' symmetric AES encryption has been done
        :param salt: optional salt to be used
        :return: Plain content of value
        :raises: DbException if the decrypted bytes are not valid UTF-8
                 (typically a different common key was used to encrypt)
        """
        if not await self.secret_key or not schema_version or schema_version == "1.0":
            return value
        secret_key = self._join_secret_key(salt)
        encrypted_msg = b64decode(value)
        # ECB was the implicit default mode of PyCrypto's AES.new(key);
        # PyCryptodome requires the mode explicitly, so state it to stay
        # compatible with values encrypted by osm_common.
        cipher = AES.new(secret_key, AES.MODE_ECB)
        decrypted_msg = cipher.decrypt(encrypted_msg)
        try:
            # Plaintext was right-padded with NUL bytes up to the block size
            unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
        except UnicodeDecodeError:
            raise DbException(
                "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                http_code=500,
            )
        return unpadded_private_msg

    def _join_secret_key(self, update_key: typing.Any):
        """
        Join secret key

        XOR-combines the current secret key (or 32 zero bytes if there is none)
        with `update_key`. The internal secret key is not modified.

        :param: update_key: str, bytes or None with the key to combine. None
                            leaves the current key unchanged.
        :returns: 32-byte combined key
        """
        if update_key is None:
            update_key_bytes = b""
        elif isinstance(update_key, str):
            update_key_bytes = update_key.encode()
        else:
            update_key_bytes = update_key
        new_secret_key = (
            bytearray(self._secret_key) if self._secret_key else bytearray(32)
        )
        for i, b in enumerate(update_key_bytes):
            new_secret_key[i % 32] ^= b
        return bytes(new_secret_key)

    @property
    async def secret_key(self):
        """
        Secret key used for decryption (awaitable property)

        Computed once from the database common key and the serial stored in the
        admin collection's "version" document, then cached in self._secret_key.
        May be None if neither is available.
        """
        if self._secret_key:
            return self._secret_key
        if self.database_key:
            self._secret_key = self._join_secret_key(self.database_key)
        version_data = await self._admin_collection.find_one({"_id": "version"})
        if version_data and version_data.get("serial"):
            self._secret_key = self._join_secret_key(
                b64decode(version_data["serial"])
            )
        return self._secret_key

    @property
    def database_key(self):
        """Database common key from the environment config, or None if unset."""
        try:
            return self._config["database_commonkey"]
        except KeyError:
            # Missing key means "no common key", as `secret_key` already expects
            return None