# Copyright 2021 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 |
1 |
import abc |
16 |
1 |
import asyncio |
17 |
1 |
from base64 import b64decode |
18 |
1 |
import re |
19 |
1 |
import typing |
20 |
|
|
21 |
1 |
from Crypto.Cipher import AES |
22 |
1 |
from motor.motor_asyncio import AsyncIOMotorClient |
23 |
1 |
from n2vc.config import EnvironConfig |
24 |
1 |
from n2vc.vca.connection_data import ConnectionData |
25 |
1 |
from osm_common.dbmongo import DbMongo, DbException |

# Name of the database that MotorStore opens on its Motor client
DB_NAME: str = "osm"


class Store(abc.ABC):
    """
    Abstract interface to read and write VCA information from the database.
    """

    @abc.abstractmethod
    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """

    @abc.abstractmethod
    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. The implementations in this module treat a
                        missing/None VCA ID as "the default VCA".
        """

    @abc.abstractmethod
    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. The implementations in this module treat a
                        missing/None VCA ID as "the default VCA".

        :returns: List of endpoints
        """

    @abc.abstractmethod
    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA id for a VIM account

        :param: vim_id: Vim account ID

        :returns: VCA ID associated with the VIM account. The implementations
                  in this module return None when there is none.
        """


class DbMongoStore(Store):
    """
    Store implementation on top of an osm_common DbMongo object.

    The DbMongo methods are called directly (they are not awaited); the
    coroutine methods below only exist to satisfy the Store interface.
    """

    def __init__(self, db: DbMongo):
        """
        Constructor

        :param: db: osm_common.dbmongo.DbMongo object
        """
        self.db = db

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        The "secret" and "cacert" fields of the stored record are decrypted
        in place (using the record's schema_version, and its _id as salt)
        before building the ConnectionData object.

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """
        data = self.db.get_one("vca", q_filter={"_id": vca_id})
        self.db.encrypt_decrypt_fields(
            data,
            "decrypt",
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not set, the "api_endpoints" field of the
                        default VCA (the "juju" document) is updated instead.
        """
        if vca_id:
            data = self.db.get_one("vca", q_filter={"_id": vca_id})
            data["endpoints"] = endpoints
            self._update("vca", vca_id, data)
        else:
            # The default VCA. Data for the endpoints is in a different place
            juju_info = self._get_juju_info()
            # If the "juju" document doesn't exist, then create it
            if not juju_info:
                try:
                    self.db.create(
                        "vca",
                        {"_id": "juju"},
                    )
                except DbException as e:
                    # Race condition: another N2VC worker may have created it
                    # between the lookup above and this create call
                    juju_info = self._get_juju_info()
                    if not juju_info:
                        raise e
            self.db.set_one(
                "vca",
                {"_id": "juju"},
                {"api_endpoints": endpoints},
            )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not set, the "api_endpoints" of the
                        default VCA (the "juju" document) are returned.

        :returns: List of endpoints (empty if the default VCA has none)
        """
        if vca_id:
            # get_vca_connection_data is a coroutine function: its result
            # must be awaited before accessing .endpoints
            connection_data = await self.get_vca_connection_data(vca_id)
            return connection_data.endpoints
        juju_info = self._get_juju_info()
        if juju_info and "api_endpoints" in juju_info:
            return juju_info["api_endpoints"]
        return []

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: The "vca" field of the VIM account, or None if vim_id is
                  not set, the VIM account is not found, or it has no VCA.
        """
        if not vim_id:
            return None
        vim_account = self.db.get_one(
            "vim_accounts",
            q_filter={"_id": vim_id},
            fail_on_empty=False,
        )
        # With fail_on_empty=False get_one may return no document at all;
        # guard before calling .get() on it (MotorStore does the same check)
        if not vim_account:
            return None
        return vim_account.get("vca")

    def _update(self, collection: str, id: str, data: dict):
        """
        Replace an object in the database

        :param: collection: Collection name
        :param: id: ID of the object
        :param: data: Object data
        """
        self.db.replace(
            collection,
            id,
            data,
        )

    def _get_juju_info(self):
        """
        Get Juju information (the default VCA) from the "vca" collection

        :returns: The "juju" document. With fail_on_empty=False, presumably
                  None when it does not exist (TODO confirm against DbMongo).
        """
        return self.db.get_one(
            "vca",
            q_filter={"_id": "juju"},
            fail_on_empty=False,
        )


class MotorStore(Store):
    """
    Store implementation using the asynchronous Motor client.

    Encrypted fields are decrypted locally, following the same logic as
    osm_common.
    """

    def __init__(self, uri: str, loop=None):
        """
        Constructor

        :param: uri: Connection string to connect to the database.
        :param: loop: Asyncio Loop
        """
        self._client = AsyncIOMotorClient(uri)
        self.loop = loop or asyncio.get_event_loop()
        # Computed lazily by get_secret_key() on the first decryption
        self._secret_key = None
        # Configuration read through EnvironConfig with the OSMLCM_/OSMMON_
        # prefixes; used to obtain "database_commonkey"
        self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])

    @property
    def _database(self):
        return self._client[DB_NAME]

    @property
    def _vca_collection(self):
        return self._database["vca"]

    @property
    def _admin_collection(self):
        return self._database["admin"]

    @property
    def _vim_accounts_collection(self):
        return self._database["vim_accounts"]

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database

        :raises: Exception if no VCA with that ID exists
        """
        data = await self._vca_collection.find_one({"_id": vca_id})
        if not data:
            raise Exception("vca with id {} not found".format(vca_id))
        await self.decrypt_fields(
            data,
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not set, the "juju" document of the admin
                        collection (the default VCA) is updated instead.

        :raises: Exception if vca_id is set but no VCA with that ID exists
        """
        if vca_id:
            data = await self._vca_collection.find_one({"_id": vca_id})
            if not data:
                # Same error as get_vca_connection_data, instead of a
                # TypeError when assigning into None
                raise Exception("vca with id {} not found".format(vca_id))
            data["endpoints"] = endpoints
            await self._vca_collection.replace_one({"_id": vca_id}, data)
        else:
            # The default VCA. Data for the endpoints is in a different place
            juju_info = await self._get_juju_info()
            # If the "juju" document doesn't exist, then create it
            if not juju_info:
                try:
                    await self._admin_collection.insert_one({"_id": "juju"})
                except Exception as e:
                    # Race condition: another N2VC worker may have created it
                    # between the lookup above and this insert
                    juju_info = await self._get_juju_info()
                    if not juju_info:
                        raise e

            # NOTE(review): replace_one overwrites the whole "juju" document,
            # whereas DbMongoStore only sets "api_endpoints". Confirm no other
            # fields of this document need to be preserved.
            await self._admin_collection.replace_one(
                {"_id": "juju"}, {"api_endpoints": endpoints}
            )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not set, the "api_endpoints" of the
                        default VCA (admin "juju" document) are returned.

        :returns: List of endpoints (empty if the default VCA has none)
        """
        endpoints = []
        if vca_id:
            endpoints = (await self.get_vca_connection_data(vca_id)).endpoints
        else:
            juju_info = await self._get_juju_info()
            if juju_info and "api_endpoints" in juju_info:
                endpoints = juju_info["api_endpoints"]
        return endpoints

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: The "vca" field of the VIM account, or None if vim_id is
                  not set, the VIM account is not found, or it has no VCA.
        """
        vca_id = None
        if vim_id:
            vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
            if vim_account and "vca" in vim_account:
                vca_id = vim_account["vca"]
        return vca_id

    async def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return await self._admin_collection.find_one({"_id": "juju"})

    # DECRYPT METHODS
    async def decrypt_fields(
        self,
        item: dict,
        fields: typing.List[str],
        schema_version: str = None,
        salt: str = None,
    ):
        """
        Decrypt fields

        Decrypt fields from a dictionary in place, recursing into nested
        dicts and lists. Follows the same logic as in osm_common.

        :param: item: Dictionary with the keys to be decrypted
        :param: fields: List of keys to decrypt. Each entry is used as a
                        case-insensitive regular expression searched in the
                        key names, and only string values are decrypted.
        :param: schema version: Schema version. (i.e. 1.11)
        :param: salt: Salt for the decryption
        """
        flags = re.I

        async def process(_item):
            if isinstance(_item, list):
                for elem in _item:
                    await process(elem)
            elif isinstance(_item, dict):
                # Only existing keys are reassigned, so the dict size does
                # not change while iterating over it
                for key, val in _item.items():
                    if isinstance(val, str):
                        if any(re.search(f, key, flags) for f in fields):
                            _item[key] = await self.decrypt(val, schema_version, salt)
                    else:
                        await process(val)

        await process(item)

    async def decrypt(self, value, schema_version=None, salt=None):
        """
        Decrypt an encrypted value

        :param value: value to be decrypted. It is a base64 string
        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
               If '1.1' symmetric AES encryption has been done
        :param salt: optional salt to be used
        :return: Plain content of value. The value is returned unchanged when
                 there is no secret key or no encryption was applied.
        :raises: DbException if the decrypted bytes are not valid UTF-8
        """
        await self.get_secret_key()
        if not self.secret_key or not schema_version or schema_version == "1.0":
            return value
        secret_key = self._join_secret_key(salt)
        encrypted_msg = b64decode(value)
        # ECB was the implicit default mode of PyCrypto's AES.new; PyCryptodome
        # requires the mode argument, so pass it explicitly for both.
        cipher = AES.new(secret_key, AES.MODE_ECB)
        decrypted_msg = cipher.decrypt(encrypted_msg)
        try:
            # The plaintext was padded with NUL characters before encryption
            unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
        except UnicodeDecodeError as e:
            raise DbException(
                "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                http_code=500,
            ) from e
        return unpadded_private_msg

    def _join_secret_key(self, update_key: typing.Any) -> bytes:
        """
        Join key with secret key

        :param: update_key: str or bytes to combine with self.secret_key

        :return: Joined key
        """
        return self._join_keys(update_key, self.secret_key)

    def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
        """
        Join key with secret_key

        XORs the bytes of `key` cyclically into a 32-byte copy of
        `secret_key` (or into 32 zero bytes when `secret_key` is empty).

        :param: key: str or bytes of the key to update
        :param: secret_key: bytes of the secret key

        :return: Joined key
        """
        if isinstance(key, str):
            update_key_bytes = key.encode()
        else:
            update_key_bytes = key
        new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
        for i, b in enumerate(update_key_bytes):
            new_secret_key[i % 32] ^= b
        return bytes(new_secret_key)

    @property
    def secret_key(self):
        return self._secret_key

    async def get_secret_key(self):
        """
        Get secret key using the database key and the serial key in the DB

        The key is populated in the property self.secret_key. It is only
        computed once; later calls return immediately if a key is already set.
        """
        if self.secret_key:
            return
        secret_key = None
        if self.database_key:
            secret_key = self._join_keys(self.database_key, None)
        version_data = await self._admin_collection.find_one({"_id": "version"})
        if version_data and version_data.get("serial"):
            secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
        self._secret_key = secret_key

    @property
    def database_key(self):
        return self._config["database_commonkey"]