| David Garcia | eb8943a | 2021-04-12 12:07:37 +0200 | [diff] [blame] | 1 | # Copyright 2021 Canonical Ltd. |
| 2 | # |
| 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | # you may not use this file except in compliance with the License. |
| 5 | # You may obtain a copy of the License at |
| 6 | # |
| 7 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | # |
| 9 | # Unless required by applicable law or agreed to in writing, software |
| 10 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | # See the License for the specific language governing permissions and |
| 13 | # limitations under the License. |
| 14 | |
| 15 | import abc |
| 16 | import asyncio |
| 17 | from base64 import b64decode |
| 18 | import re |
| 19 | import typing |
| 20 | |
| 21 | from Crypto.Cipher import AES |
| 22 | from motor.motor_asyncio import AsyncIOMotorClient |
| 23 | from n2vc.config import EnvironConfig |
| 24 | from n2vc.vca.connection_data import ConnectionData |
| 25 | from osm_common.dbmongo import DbMongo, DbException |
| 26 | |
# Name of the MongoDB database opened by MotorStore (self._client[DB_NAME])
DB_NAME = "osm"
| 28 | |
| 29 | |
class Store(abc.ABC):
    """
    Abstract interface to read and update VCA information stored in the database.
    """

    @abc.abstractmethod
    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """

    @abc.abstractmethod
    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If None, the endpoints of the default VCA are updated
        """

    @abc.abstractmethod
    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If None, the endpoints of the default VCA are returned

        :returns: List of endpoints
        """

    @abc.abstractmethod
    async def get_vca_id(self, vim_id: str = None) -> typing.Optional[str]:
        """
        Get VCA id for a VIM account

        :param: vim_id: Vim account ID

        :returns: VCA ID associated to the VIM account, or None if there is none
        """
| 67 | |
| 68 | |
class DbMongoStore(Store):
    """Store implementation backed by an osm_common DbMongo object."""

    def __init__(self, db: DbMongo):
        """
        Constructor

        :param: db: osm_common.dbmongo.DbMongo object
        """
        self.db = db

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """
        data = self.db.get_one("vca", q_filter={"_id": vca_id})
        # Decrypt the sensitive fields, salted with the record's _id.
        # The return value of encrypt_decrypt_fields is not used, so it is
        # expected to modify `data` in place.
        self.db.encrypt_decrypt_fields(
            data,
            "decrypt",
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If None, the endpoints of the default VCA
                        (the "juju" document of the admin collection) are updated
        """
        if vca_id:
            data = self.db.get_one("vca", q_filter={"_id": vca_id})
            data["endpoints"] = endpoints
            self._update("vca", vca_id, data)
            return

        # The default VCA. Its endpoints are stored in a different place: the
        # "juju" document of the admin collection (where MotorStore reads them
        # from), not the vca collection.
        if not self._get_juju_info():
            # The document doesn't exist yet: create it
            try:
                self.db.create("admin", {"_id": "juju"})
            except DbException:
                # Race condition: another N2VC worker may have created it
                # between the check above and the create call
                if not self._get_juju_info():
                    raise
        self.db.set_one(
            "admin",
            {"_id": "juju"},
            {"api_endpoints": endpoints},
        )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If None, the endpoints of the default VCA are returned

        :returns: List of endpoints (empty if the default VCA has none stored)
        """
        if vca_id:
            # get_vca_connection_data is a coroutine function: its result must
            # be awaited before reading attributes from it
            return (await self.get_vca_connection_data(vca_id)).endpoints
        juju_info = self._get_juju_info()
        if juju_info and "api_endpoints" in juju_info:
            return juju_info["api_endpoints"]
        return []

    async def get_vca_id(self, vim_id: str = None) -> typing.Optional[str]:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: VCA ID of the VIM account, or None if vim_id is empty, the VIM
                  account is not found, or it has no "vca" field
        """
        if not vim_id:
            return None
        vim_account = self.db.get_one(
            "vim_accounts",
            q_filter={"_id": vim_id},
            fail_on_empty=False,
        )
        # With fail_on_empty=False a missing account yields a falsy result
        # instead of an exception, so guard before calling .get()
        return vim_account.get("vca") if vim_account else None

    def _update(self, collection: str, id: str, data: dict):
        """
        Update object in database

        :param: collection: Collection name
        :param: id: ID of the object
        :param: data: Object data (replaces the stored object)
        """
        self.db.replace(
            collection,
            id,
            data,
        )

    def _get_juju_info(self):
        """
        Get Juju information (the default VCA) from the admin collection

        :returns: The "juju" document, or a falsy value if it does not exist
        """
        return self.db.get_one(
            "admin",
            q_filter={"_id": "juju"},
            fail_on_empty=False,
        )
| 184 | |
| 185 | |
class MotorStore(Store):
    """Store implementation that accesses the OSM database asynchronously with Motor."""

    def __init__(self, uri: str, loop=None):
        """
        Constructor

        :param: uri: Connection string to connect to the database.
        :param: loop: Asyncio Loop
        """
        self._client = AsyncIOMotorClient(uri)
        self.loop = loop or asyncio.get_event_loop()
        # Populated lazily by get_secret_key()
        self._secret_key = None
        # Configuration source for database_key; presumably reads environment
        # variables with these prefixes (see n2vc.config.EnvironConfig)
        self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])

    @property
    def _database(self):
        return self._client[DB_NAME]

    @property
    def _vca_collection(self):
        return self._database["vca"]

    @property
    def _admin_collection(self):
        return self._database["admin"]

    @property
    def _vim_accounts_collection(self):
        return self._database["vim_accounts"]

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database

        :raises: Exception if there is no vca with that ID
        """
        data = await self._vca_collection.find_one({"_id": vca_id})
        if not data:
            raise Exception("vca with id {} not found".format(vca_id))
        await self.decrypt_fields(
            data,
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If None, the endpoints of the default VCA
                        (the "juju" document of the admin collection) are updated

        :raises: Exception if vca_id is given but there is no vca with that ID
        """
        if vca_id:
            data = await self._vca_collection.find_one({"_id": vca_id})
            if not data:
                raise Exception("vca with id {} not found".format(vca_id))
            data["endpoints"] = endpoints
            await self._vca_collection.replace_one({"_id": vca_id}, data)
            return

        # The default VCA. Data for the endpoints is in a different place
        juju_info = await self._get_juju_info()
        # If it doesn't exist, then create it
        if not juju_info:
            try:
                await self._admin_collection.insert_one({"_id": "juju"})
            except Exception as e:
                # Race condition: check if another N2VC worker has created it
                juju_info = await self._get_juju_info()
                if not juju_info:
                    raise e

        await self._admin_collection.replace_one(
            {"_id": "juju"}, {"api_endpoints": endpoints}
        )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If None, the endpoints of the default VCA are returned

        :returns: List of endpoints (empty if the default VCA has none stored)
        """
        endpoints = []
        if vca_id:
            endpoints = (await self.get_vca_connection_data(vca_id)).endpoints
        else:
            juju_info = await self._get_juju_info()
            if juju_info and "api_endpoints" in juju_info:
                endpoints = juju_info["api_endpoints"]
        return endpoints

    async def get_vca_id(self, vim_id: str = None) -> typing.Optional[str]:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: VCA ID of the VIM account, or None if vim_id is empty, the VIM
                  account is not found, or it has no "vca" field
        """
        vca_id = None
        if vim_id:
            vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
            if vim_account and "vca" in vim_account:
                vca_id = vim_account["vca"]
        return vca_id

    async def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return await self._admin_collection.find_one({"_id": "juju"})

    # DECRYPT METHODS
    async def decrypt_fields(
        self,
        item: dict,
        fields: typing.List[str],
        schema_version: str = None,
        salt: str = None,
    ):
        """
        Decrypt fields

        Decrypt fields from a dictionary, in place. Follows the same logic as in
        osm_common: nested dicts and lists are walked recursively, and every
        string value whose key matches (case-insensitive regex search) any of
        `fields` is replaced by its decrypted value.

        :param: item: Dictionary with the keys to be decrypted
        :param: fields: List of keys (regex patterns) to decrypt
        :param: schema_version: Schema version. (i.e. 1.1)
        :param: salt: Salt for the decryption
        """
        flags = re.I

        async def process(_item):
            if isinstance(_item, list):
                for elem in _item:
                    await process(elem)
            elif isinstance(_item, dict):
                # Only values of existing keys are reassigned, so the dict size
                # does not change while iterating
                for key, val in _item.items():
                    if isinstance(val, str):
                        if any(re.search(f, key, flags) for f in fields):
                            _item[key] = await self.decrypt(val, schema_version, salt)
                    else:
                        await process(val)

        await process(item)

    async def decrypt(self, value, schema_version=None, salt=None):
        """
        Decrypt an encrypted value

        :param value: value to be decrypted. It is a base64 string
        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
                               If '1.1' symmetric AES encryption has been done
        :param salt: optional salt to be used

        :return: Plain content of value

        :raises: DbException if the decrypted bytes are not valid UTF-8
                 (typically a different key was used to encrypt)
        """
        await self.get_secret_key()
        if not self.secret_key or not schema_version or schema_version == "1.0":
            # No key configured, or the value was stored unencrypted
            return value
        secret_key = self._join_secret_key(salt)
        encrypted_msg = b64decode(value)
        # The mode is passed explicitly: ECB was the implicit default of
        # PyCrypto's AES.new, while pycryptodome requires the mode argument.
        cipher = AES.new(secret_key, AES.MODE_ECB)
        decrypted_msg = cipher.decrypt(encrypted_msg)
        try:
            # Encrypted values are padded with NUL characters; strip them
            unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
        except UnicodeDecodeError:
            raise DbException(
                "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                http_code=500,
            )
        return unpadded_private_msg

    def _join_secret_key(self, update_key: typing.Any) -> bytes:
        """
        Join key with secret key

        :param: update_key: str or bytes with the key to update. None is
                            treated as an empty key

        :return: Joined key
        """
        return self._join_keys(update_key, self.secret_key)

    def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
        """
        Join key with secret_key

        The bytes of `key` are XORed cyclically (index modulo 32) into a copy of
        `secret_key`, or into 32 zero bytes when `secret_key` is empty/None.

        :param: key: str or bytes of the key to update. None is treated as an
                     empty key (the secret key is returned unchanged)
        :param: secret_key: bytes of the secret key

        :return: Joined key
        """
        if key is None:
            update_key_bytes = b""
        elif isinstance(key, str):
            update_key_bytes = key.encode()
        else:
            update_key_bytes = key
        new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
        for i, b in enumerate(update_key_bytes):
            new_secret_key[i % 32] ^= b
        return bytes(new_secret_key)

    @property
    def secret_key(self):
        return self._secret_key

    async def get_secret_key(self):
        """
        Get secret key using the database key and the serial key in the DB
        The key is populated in the property self.secret_key
        """
        if self.secret_key:
            return
        secret_key = None
        if self.database_key:
            secret_key = self._join_keys(self.database_key, None)
        version_data = await self._admin_collection.find_one({"_id": "version"})
        if version_data and version_data.get("serial"):
            secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
        self._secret_key = secret_key

    @property
    def database_key(self):
        # NOTE(review): indexing (not .get) means this raises if
        # "database_commonkey" is absent from the config — confirm it is always set
        return self._config["database_commonkey"]