Code Coverage

Cobertura Coverage Report > n2vc >

store.py

Trend

File Coverage summary

NameClassesLinesConditionals
store.py
100%
1/1
98%
158/162
100%
0/0

Coverage Breakdown by Class

NameLinesConditionals
store.py
98%
158/162
N/A

Source

n2vc/store.py
1 # Copyright 2021 Canonical Ltd.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #     Unless required by applicable law or agreed to in writing, software
10 #     distributed under the License is distributed on an "AS IS" BASIS,
11 #     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 #     See the License for the specific language governing permissions and
13 #     limitations under the License.
14
15 1 import abc
16 1 import asyncio
17 1 from base64 import b64decode
18 1 import re
19 1 import typing
20
21 1 from Crypto.Cipher import AES
22 1 from motor.motor_asyncio import AsyncIOMotorClient
23 1 from n2vc.config import EnvironConfig
24 1 from n2vc.vca.connection_data import ConnectionData
25 1 from osm_common.dbmongo import DbMongo, DbException
26
27 1 DB_NAME = "osm"
28
29
30 1 class Store(abc.ABC):
31 1     @abc.abstractmethod
32 1     async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
33         """
34         Get VCA connection data
35
36         :param: vca_id: VCA ID
37
38         :returns: ConnectionData with the information of the database
39         """
40
41 1     @abc.abstractmethod
42 1     async def update_vca_endpoints(self, hosts: typing.List[str], vca_id: str):
43         """
44         Update VCA endpoints
45
46         :param: endpoints: List of endpoints to write in the database
47         :param: vca_id: VCA ID
48         """
49
50 1     @abc.abstractmethod
51 1     async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
52         """
53         Get list if VCA endpoints
54
55         :param: vca_id: VCA ID
56
57         :returns: List of endpoints
58         """
59
60 1     @abc.abstractmethod
61 1     async def get_vca_id(self, vim_id: str = None) -> str:
62         """
63         Get VCA id for a VIM account
64
65         :param: vim_id: Vim account ID
66         """
67
68
69 1 class DbMongoStore(Store):
70 1     def __init__(self, db: DbMongo):
71         """
72         Constructor
73
74         :param: db: osm_common.dbmongo.DbMongo object
75         """
76 1         self.db = db
77
78 1     async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
79         """
80         Get VCA connection data
81
82         :param: vca_id: VCA ID
83
84         :returns: ConnectionData with the information of the database
85         """
86 1         data = self.db.get_one("vca", q_filter={"_id": vca_id})
87 1         self.db.encrypt_decrypt_fields(
88             data,
89             "decrypt",
90             ["secret", "cacert"],
91             schema_version=data["schema_version"],
92             salt=data["_id"],
93         )
94 1         return ConnectionData(**data)
95
96 1     async def update_vca_endpoints(
97         self, endpoints: typing.List[str], vca_id: str = None
98     ):
99         """
100         Update VCA endpoints
101
102         :param: endpoints: List of endpoints to write in the database
103         :param: vca_id: VCA ID
104         """
105 1         if vca_id:
106 1             data = self.db.get_one("vca", q_filter={"_id": vca_id})
107 1             data["endpoints"] = endpoints
108 1             self._update("vca", vca_id, data)
109         else:
110             # The default VCA. Data for the endpoints is in a different place
111 1             juju_info = self._get_juju_info()
112             # If it doesn't, then create it
113 1             if not juju_info:
114 1                 try:
115 1                     self.db.create(
116                         "vca",
117                         {"_id": "juju"},
118                     )
119 1                 except DbException as e:
120                     # Racing condition: check if another N2VC worker has created it
121 1                     juju_info = self._get_juju_info()
122 1                     if not juju_info:
123 1                         raise e
124 1             self.db.set_one(
125                 "vca",
126                 {"_id": "juju"},
127                 {"api_endpoints": endpoints},
128             )
129
130 1     async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
131         """
132         Get list if VCA endpoints
133
134         :param: vca_id: VCA ID
135
136         :returns: List of endpoints
137         """
138 1         endpoints = []
139 1         if vca_id:
140 1             endpoints = self.get_vca_connection_data(vca_id).endpoints
141         else:
142 1             juju_info = self._get_juju_info()
143 1             if juju_info and "api_endpoints" in juju_info:
144 1                 endpoints = juju_info["api_endpoints"]
145 1         return endpoints
146
147 1     async def get_vca_id(self, vim_id: str = None) -> str:
148         """
149         Get VCA ID from the database for a given VIM account ID
150
151         :param: vim_id: VIM account ID
152         """
153 1         return (
154             self.db.get_one(
155                 "vim_accounts",
156                 q_filter={"_id": vim_id},
157                 fail_on_empty=False,
158             ).get("vca")
159             if vim_id
160             else None
161         )
162
163 1     def _update(self, collection: str, id: str, data: dict):
164         """
165         Update object in database
166
167         :param: collection: Collection name
168         :param: id: ID of the object
169         :param: data: Object data
170         """
171 1         self.db.replace(
172             collection,
173             id,
174             data,
175         )
176
177 1     def _get_juju_info(self):
178         """Get Juju information (the default VCA) from the admin collection"""
179 1         return self.db.get_one(
180             "vca",
181             q_filter={"_id": "juju"},
182             fail_on_empty=False,
183         )
184
185
186 1 class MotorStore(Store):
187 1     def __init__(self, uri: str, loop=None):
188         """
189         Constructor
190
191         :param: uri: Connection string to connect to the database.
192         :param: loop: Asyncio Loop
193         """
194 1         self._client = AsyncIOMotorClient(uri)
195 1         self.loop = loop or asyncio.get_event_loop()
196 1         self._secret_key = None
197 1         self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])
198
199 1     @property
200 1     def _database(self):
201 1         return self._client[DB_NAME]
202
203 1     @property
204 1     def _vca_collection(self):
205 1         return self._database["vca"]
206
207 1     @property
208 1     def _admin_collection(self):
209 1         return self._database["admin"]
210
211 1     @property
212 1     def _vim_accounts_collection(self):
213 1         return self._database["vim_accounts"]
214
215 1     async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
216         """
217         Get VCA connection data
218
219         :param: vca_id: VCA ID
220
221         :returns: ConnectionData with the information of the database
222         """
223 1         data = await self._vca_collection.find_one({"_id": vca_id})
224 1         if not data:
225 1             raise Exception("vca with id {} not found".format(vca_id))
226 1         await self.decrypt_fields(
227             data,
228             ["secret", "cacert"],
229             schema_version=data["schema_version"],
230             salt=data["_id"],
231         )
232 1         return ConnectionData(**data)
233
234 1     async def update_vca_endpoints(
235         self, endpoints: typing.List[str], vca_id: str = None
236     ):
237         """
238         Update VCA endpoints
239
240         :param: endpoints: List of endpoints to write in the database
241         :param: vca_id: VCA ID
242         """
243 1         if vca_id:
244 1             data = await self._vca_collection.find_one({"_id": vca_id})
245 1             data["endpoints"] = endpoints
246 1             await self._vca_collection.replace_one({"_id": vca_id}, data)
247         else:
248             # The default VCA. Data for the endpoints is in a different place
249 1             juju_info = await self._get_juju_info()
250             # If it doesn't, then create it
251 1             if not juju_info:
252 1                 try:
253 1                     await self._admin_collection.insert_one({"_id": "juju"})
254 1                 except Exception as e:
255                     # Racing condition: check if another N2VC worker has created it
256 1                     juju_info = await self._get_juju_info()
257 1                     if not juju_info:
258 1                         raise e
259
260 1             await self._admin_collection.replace_one(
261                 {"_id": "juju"}, {"api_endpoints": endpoints}
262             )
263
264 1     async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
265         """
266         Get list if VCA endpoints
267
268         :param: vca_id: VCA ID
269
270         :returns: List of endpoints
271         """
272 1         endpoints = []
273 1         if vca_id:
274 1             endpoints = (await self.get_vca_connection_data(vca_id)).endpoints
275         else:
276 1             juju_info = await self._get_juju_info()
277 1             if juju_info and "api_endpoints" in juju_info:
278 1                 endpoints = juju_info["api_endpoints"]
279 1         return endpoints
280
281 1     async def get_vca_id(self, vim_id: str = None) -> str:
282         """
283         Get VCA ID from the database for a given VIM account ID
284
285         :param: vim_id: VIM account ID
286         """
287 1         vca_id = None
288 1         if vim_id:
289 1             vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
290 1             if vim_account and "vca" in vim_account:
291 1                 vca_id = vim_account["vca"]
292 1         return vca_id
293
294 1     async def _get_juju_info(self):
295         """Get Juju information (the default VCA) from the admin collection"""
296 1         return await self._admin_collection.find_one({"_id": "juju"})
297
298     # DECRYPT METHODS
299 1     async def decrypt_fields(
300         self,
301         item: dict,
302         fields: typing.List[str],
303         schema_version: str = None,
304         salt: str = None,
305     ):
306         """
307         Decrypt fields
308
309         Decrypt fields from a dictionary. Follows the same logic as in osm_common.
310
311         :param: item: Dictionary with the keys to be decrypted
312         :param: fields: List of keys to decrypt
313         :param: schema version: Schema version. (i.e. 1.11)
314         :param: salt: Salt for the decryption
315         """
316 1         flags = re.I
317
318 1         async def process(_item):
319 1             if isinstance(_item, list):
320 1                 for elem in _item:
321 0                     await process(elem)
322 1             elif isinstance(_item, dict):
323 1                 for key, val in _item.items():
324 1                     if isinstance(val, str):
325 1                         if any(re.search(f, key, flags) for f in fields):
326 1                             _item[key] = await self.decrypt(val, schema_version, salt)
327                     else:
328 1                         await process(val)
329
330 1         await process(item)
331
332 1     async def decrypt(self, value, schema_version=None, salt=None):
333         """
334         Decrypt an encrypted value
335         :param value: value to be decrypted. It is a base64 string
336         :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
337                If '1.1' symmetric AES encryption has been done
338         :param salt: optional salt to be used
339         :return: Plain content of value
340         """
341 1         await self.get_secret_key()
342 1         if not self.secret_key or not schema_version or schema_version == "1.0":
343 0             return value
344         else:
345 1             secret_key = self._join_secret_key(salt)
346 1             encrypted_msg = b64decode(value)
347 1             cipher = AES.new(secret_key)
348 1             decrypted_msg = cipher.decrypt(encrypted_msg)
349 1             try:
350 1                 unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
351 0             except UnicodeDecodeError:
352 0                 raise DbException(
353                     "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
354                     http_code=500,
355                 )
356 1             return unpadded_private_msg
357
358 1     def _join_secret_key(self, update_key: typing.Any) -> bytes:
359         """
360         Join key with secret key
361
362         :param: update_key: str or bytes with the to update
363
364         :return: Joined key
365         """
366 1         return self._join_keys(update_key, self.secret_key)
367
368 1     def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
369         """
370         Join key with secret_key
371
372         :param: key: str or bytesof the key to update
373         :param: secret_key: bytes of the secret key
374
375         :return: Joined key
376         """
377 1         if isinstance(key, str):
378 1             update_key_bytes = key.encode()
379         else:
380 1             update_key_bytes = key
381 1         new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
382 1         for i, b in enumerate(update_key_bytes):
383 1             new_secret_key[i % 32] ^= b
384 1         return bytes(new_secret_key)
385
386 1     @property
387 1     def secret_key(self):
388 1         return self._secret_key
389
390 1     async def get_secret_key(self):
391         """
392         Get secret key using the database key and the serial key in the DB
393         The key is populated in the property self.secret_key
394         """
395 1         if self.secret_key:
396 1             return
397 1         secret_key = None
398 1         if self.database_key:
399 1             secret_key = self._join_keys(self.database_key, None)
400 1         version_data = await self._admin_collection.find_one({"_id": "version"})
401 1         if version_data and version_data.get("serial"):
402 1             secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
403 1         self._secret_key = secret_key
404
405 1     @property
406 1     def database_key(self):
407 1         return self._config["database_commonkey"]