blob: cd6c6fbf8c6ce36ad3b1cbb7922513ce6a077d77 [file] [log] [blame]
David Garciaeb8943a2021-04-12 12:07:37 +02001# Copyright 2021 Canonical Ltd.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import abc
16import asyncio
17from base64 import b64decode
18import re
19import typing
20
21from Crypto.Cipher import AES
22from motor.motor_asyncio import AsyncIOMotorClient
23from n2vc.config import EnvironConfig
24from n2vc.vca.connection_data import ConnectionData
25from osm_common.dbmongo import DbMongo, DbException
26
# Name of the MongoDB database that MotorStore selects on its client.
DB_NAME = "osm"
28
29
class Store(abc.ABC):
    """
    Abstract interface for reading and writing VCA information.

    Concrete stores implement every coroutine declared here.
    """

    @abc.abstractmethod
    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """

    @abc.abstractmethod
    async def update_vca_endpoints(
        self, hosts: typing.List[str], vca_id: typing.Optional[str] = None
    ):
        """
        Update VCA endpoints

        :param: hosts: List of endpoints to write in the database
        :param: vca_id: VCA ID. The implementations in this module treat a
                        missing/empty value as "the default VCA".
        """

    @abc.abstractmethod
    async def get_vca_endpoints(
        self, vca_id: typing.Optional[str] = None
    ) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID

        :returns: List of endpoints
        """

    @abc.abstractmethod
    async def get_vca_id(
        self, vim_id: typing.Optional[str] = None
    ) -> typing.Optional[str]:
        """
        Get VCA id for a VIM account

        :param: vim_id: Vim account ID

        :returns: VCA ID, or None when it cannot be resolved
        """
67
68
class DbMongoStore(Store):
    """Store implementation backed by an osm_common DbMongo object."""

    def __init__(self, db: DbMongo):
        """
        Constructor

        :param: db: osm_common.dbmongo.DbMongo object
        """
        self.db = db

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        Reads the document from the "vca" collection and decrypts its
        "secret" and "cacert" fields in place before building the result.

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """
        data = self.db.get_one("vca", q_filter={"_id": vca_id})
        self.db.encrypt_decrypt_fields(
            data,
            "decrypt",
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not given, the default VCA ("juju") is updated.
        """
        if vca_id:
            data = self.db.get_one("vca", q_filter={"_id": vca_id})
            data["endpoints"] = endpoints
            self._update("vca", vca_id, data)
        else:
            # The default VCA. Data for the endpoints is in a different place
            juju_info = self._get_juju_info()
            # If it doesn't exist, then create it
            if not juju_info:
                try:
                    self.db.create(
                        "vca",
                        {"_id": "juju"},
                    )
                except DbException:
                    # Race condition: check if another N2VC worker has created it
                    juju_info = self._get_juju_info()
                    if not juju_info:
                        # Nobody created it: propagate the original error
                        raise
            self.db.set_one(
                "vca",
                {"_id": "juju"},
                {"api_endpoints": endpoints},
            )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not given, the default VCA ("juju") is read.

        :returns: List of endpoints (empty if none are stored for the default VCA)
        """
        endpoints = []
        if vca_id:
            # get_vca_connection_data is a coroutine: it must be awaited before
            # reading attributes of its result.
            connection_data = await self.get_vca_connection_data(vca_id)
            endpoints = connection_data.endpoints
        else:
            juju_info = self._get_juju_info()
            if juju_info and "api_endpoints" in juju_info:
                endpoints = juju_info["api_endpoints"]
        return endpoints

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: Value of the "vca" key of the VIM account, or None when
                  vim_id is empty, the account is not found or it has no "vca"
        """
        if not vim_id:
            return None
        vim_account = self.db.get_one(
            "vim_accounts",
            q_filter={"_id": vim_id},
            fail_on_empty=False,
        )
        # With fail_on_empty=False the lookup may come back empty
        if not vim_account:
            return None
        return vim_account.get("vca")

    def _update(self, collection: str, id: str, data: dict):
        """
        Update object in database

        :param: collection: Collection name
        :param: id: ID of the object
        :param: data: Object data
        """
        self.db.replace(
            collection,
            id,
            data,
        )

    def _get_juju_info(self):
        """
        Get Juju information (the default VCA)

        Looks up the document with _id "juju" in the "vca" collection.
        NOTE(review): MotorStore reads the same document from the "admin"
        collection - confirm which collection is the intended one.

        :returns: The document, or an empty/None value if it does not exist
        """
        return self.db.get_one(
            "vca",
            q_filter={"_id": "juju"},
            fail_on_empty=False,
        )
184
185
class MotorStore(Store):
    """Store implementation that talks to MongoDB through Motor (asyncio)."""

    def __init__(self, uri: str, loop=None):
        """
        Constructor

        :param: uri: Connection string to connect to the database.
        :param: loop: Asyncio Loop
        """
        self._client = AsyncIOMotorClient(uri)
        self.loop = loop or asyncio.get_event_loop()
        # Filled lazily by get_secret_key()
        self._secret_key = None
        self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])

    @property
    def _database(self):
        """Database handle selected by DB_NAME"""
        return self._client[DB_NAME]

    @property
    def _vca_collection(self):
        """Handle of the "vca" collection"""
        return self._database["vca"]

    @property
    def _admin_collection(self):
        """Handle of the "admin" collection"""
        return self._database["admin"]

    @property
    def _vim_accounts_collection(self):
        """Handle of the "vim_accounts" collection"""
        return self._database["vim_accounts"]

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database

        :raises: Exception if there is no VCA with that ID
        """
        data = await self._vca_collection.find_one({"_id": vca_id})
        if not data:
            raise Exception("vca with id {} not found".format(vca_id))
        await self.decrypt_fields(
            data,
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not given, the default VCA ("juju") is updated.

        :raises: Exception if vca_id is given and no such VCA exists
        """
        if vca_id:
            data = await self._vca_collection.find_one({"_id": vca_id})
            if not data:
                # Same error as get_vca_connection_data, instead of failing
                # with a TypeError when subscripting None
                raise Exception("vca with id {} not found".format(vca_id))
            data["endpoints"] = endpoints
            await self._vca_collection.replace_one({"_id": vca_id}, data)
        else:
            # The default VCA. Data for the endpoints is in a different place
            juju_info = await self._get_juju_info()
            # If it doesn't exist, then create it
            if not juju_info:
                try:
                    await self._admin_collection.insert_one({"_id": "juju"})
                except Exception:
                    # Race condition: check if another N2VC worker has created it
                    juju_info = await self._get_juju_info()
                    if not juju_info:
                        # Nobody created it: propagate the original error
                        raise

            # Only set the endpoints; replacing the whole document would drop
            # any other field stored in it.
            await self._admin_collection.update_one(
                {"_id": "juju"}, {"$set": {"api_endpoints": endpoints}}
            )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not given, the default VCA ("juju") is read.

        :returns: List of endpoints (empty if none are stored for the default VCA)
        """
        endpoints = []
        if vca_id:
            endpoints = (await self.get_vca_connection_data(vca_id)).endpoints
        else:
            juju_info = await self._get_juju_info()
            if juju_info and "api_endpoints" in juju_info:
                endpoints = juju_info["api_endpoints"]
        return endpoints

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: Value of the "vca" key of the VIM account, or None when
                  vim_id is empty, the account is not found or it has no "vca"
        """
        vca_id = None
        if vim_id:
            vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
            if vim_account and "vca" in vim_account:
                vca_id = vim_account["vca"]
        return vca_id

    async def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return await self._admin_collection.find_one({"_id": "juju"})

    # DECRYPT METHODS
    async def decrypt_fields(
        self,
        item: dict,
        fields: typing.List[str],
        schema_version: str = None,
        salt: str = None,
    ):
        """
        Decrypt fields

        Decrypt fields from a dictionary. Follows the same logic as in osm_common.
        The dictionary is modified in place; nested lists and dictionaries are
        traversed recursively.

        :param: item: Dictionary with the keys to be decrypted
        :param: fields: List of keys to decrypt. Each entry is used as a
                        case-insensitive regular expression searched in the key.
        :param: schema version: Schema version. (i.e. 1.11)
        :param: salt: Salt for the decryption
        """
        flags = re.I

        async def process(_item):
            if isinstance(_item, list):
                for elem in _item:
                    await process(elem)
            elif isinstance(_item, dict):
                # Only values of existing keys are reassigned, so the dict is
                # not resized while it is being iterated.
                for key, val in _item.items():
                    if isinstance(val, str):
                        if any(re.search(f, key, flags) for f in fields):
                            _item[key] = await self.decrypt(val, schema_version, salt)
                    else:
                        await process(val)

        await process(item)

    async def decrypt(self, value, schema_version=None, salt=None):
        """
        Decrypt an encrypted value
        :param value: value to be decrypted. It is a base64 string
        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
               If '1.1' symmetric AES encryption has been done
        :param salt: optional salt to be used
        :return: Plain content of value
        :raises: DbException if the decrypted bytes cannot be decoded as text
        """
        await self.get_secret_key()
        if not self.secret_key or not schema_version or schema_version == "1.0":
            return value
        else:
            secret_key = self._join_secret_key(salt)
            encrypted_msg = b64decode(value)
            # Pass the mode explicitly: ECB was the implicit default in legacy
            # PyCrypto, while PyCryptodome requires the argument.
            cipher = AES.new(secret_key, AES.MODE_ECB)
            decrypted_msg = cipher.decrypt(encrypted_msg)
            try:
                # Trailing NUL bytes are padding and are stripped
                unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
            except UnicodeDecodeError as err:
                raise DbException(
                    "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                    http_code=500,
                ) from err
            return unpadded_private_msg

    def _join_secret_key(self, update_key: typing.Any) -> bytes:
        """
        Join key with secret key

        :param: update_key: str or bytes with the key to update

        :return: Joined key
        """
        return self._join_keys(update_key, self.secret_key)

    def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
        """
        Join key with secret_key

        Every byte of key is XOR-ed into position (index modulo 32) of a copy
        of secret_key. When secret_key is empty or None, a zero-filled 32-byte
        buffer is used as the starting point.

        :param: key: str or bytes of the key to update
        :param: secret_key: bytes of the secret key

        :return: Joined key
        """
        if isinstance(key, str):
            update_key_bytes = key.encode()
        else:
            update_key_bytes = key
        new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
        for i, b in enumerate(update_key_bytes):
            new_secret_key[i % 32] ^= b
        return bytes(new_secret_key)

    @property
    def secret_key(self):
        """Cached secret key; None until get_secret_key() has populated it"""
        return self._secret_key

    async def get_secret_key(self):
        """
        Get secret key using the database key and the serial key in the DB
        The key is populated in the property self.secret_key
        """
        if self.secret_key:
            # Already computed
            return
        secret_key = None
        if self.database_key:
            secret_key = self._join_keys(self.database_key, None)
            version_data = await self._admin_collection.find_one({"_id": "version"})
            if version_data and version_data.get("serial"):
                secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
        self._secret_key = secret_key

    @property
    def database_key(self):
        """Value of "database_commonkey" in the environment configuration"""
        return self._config["database_commonkey"]