Including upstream requirements
[osm/N2VC.git] / n2vc / store.py
1 # Copyright 2021 Canonical Ltd.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import abc
16 import asyncio
17 from base64 import b64decode
18 import re
19 import typing
20
21 from Crypto.Cipher import AES
22 from motor.motor_asyncio import AsyncIOMotorClient
23 from n2vc.config import EnvironConfig
24 from n2vc.vca.connection_data import ConnectionData
25 from osm_common.dbmongo import DbMongo, DbException
26
# Name of the MongoDB database that MotorStore opens on its Motor client
# (see MotorStore._database).
DB_NAME = "osm"
28
29
class Store(abc.ABC):
    """
    Abstract interface for reading and writing VCA data in the database.

    Every method is an abstract coroutine; concrete stores must implement
    all of them before they can be instantiated.
    """

    @abc.abstractmethod
    async def get_vca_connection_data(self, vca_id: str) -> "ConnectionData":
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """

    @abc.abstractmethod
    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: typing.Optional[str] = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. The implementations in this module treat a
                        missing/empty value as "the default VCA".
        """

    @abc.abstractmethod
    async def get_vca_endpoints(
        self, vca_id: typing.Optional[str] = None
    ) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID

        :returns: List of endpoints
        """

    @abc.abstractmethod
    async def get_vca_id(
        self, vim_id: typing.Optional[str] = None
    ) -> typing.Optional[str]:
        """
        Get VCA id for a VIM account

        :param: vim_id: Vim account ID

        :returns: VCA ID, or None when it cannot be determined
        """
67
68
class DbMongoStore(Store):
    """
    Store implementation on top of an osm_common ``DbMongo`` object.

    The ``DbMongo`` methods are invoked directly (they are not awaited) from
    the coroutines of this class.
    """

    def __init__(self, db: DbMongo):
        """
        Constructor

        :param: db: osm_common.dbmongo.DbMongo object
        """
        self.db = db

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData built from the "vca" document with that id
        """
        data = self.db.get_one("vca", q_filter={"_id": vca_id})
        # "secret" and "cacert" are decrypted inside `data` itself, using the
        # document's own schema version and its id as salt.
        self.db.encrypt_decrypt_fields(
            data,
            "decrypt",
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. When empty, the default VCA (document with
                        _id "juju") is updated instead.

        :raises: DbException: if the default VCA document can neither be
                              created nor found afterwards
        """
        if vca_id:
            data = self.db.get_one("vca", q_filter={"_id": vca_id})
            data["endpoints"] = endpoints
            self._update("vca", vca_id, data)
            return

        # The default VCA. Data for the endpoints is in a different place
        juju_info = self._get_juju_info()
        # If it doesn't exist, then create it
        if not juju_info:
            try:
                self.db.create(
                    "vca",
                    {"_id": "juju"},
                )
            except DbException:
                # Race condition: check if another N2VC worker has created it
                juju_info = self._get_juju_info()
                if not juju_info:
                    # Bare raise keeps the original traceback
                    raise
        self.db.set_one(
            "vca",
            {"_id": "juju"},
            {"api_endpoints": endpoints},
        )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. When empty, the endpoints of the default VCA
                        are returned.

        :returns: List of endpoints (empty if the default VCA has none stored)
        """
        if vca_id:
            # get_vca_connection_data is a coroutine: it must be awaited
            # before its result can be used.
            return (await self.get_vca_connection_data(vca_id)).endpoints

        juju_info = self._get_juju_info()
        if juju_info and "api_endpoints" in juju_info:
            return juju_info["api_endpoints"]
        return []

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: Value of the "vca" key of the VIM account, or None when no
                  vim_id is given, the account is not found, or it has no
                  "vca" key
        """
        if not vim_id:
            return None
        vim_account = self.db.get_one(
            "vim_accounts",
            q_filter={"_id": vim_id},
            fail_on_empty=False,
        )
        # NOTE(review): with fail_on_empty=False the lookup presumably returns
        # None for an unknown account (the same pattern is checked for
        # falsiness in update_vca_endpoints), so guard before calling .get().
        return vim_account.get("vca") if vim_account else None

    def _update(self, collection: str, id: str, data: dict):
        """
        Update object in database

        :param: collection: Collection name
        :param: id: ID of the object
        :param: data: Object data
        """
        self.db.replace(
            collection,
            id,
            data,
        )

    def _get_juju_info(self):
        """
        Get Juju information (the default VCA) from the "vca" collection

        NOTE(review): MotorStore keeps the equivalent document in the "admin"
        collection; confirm which location is the intended one.

        :returns: The document with _id "juju"; the lookup is done with
                  fail_on_empty=False, so a falsy result is expected when the
                  document does not exist
        """
        return self.db.get_one(
            "vca",
            q_filter={"_id": "juju"},
            fail_on_empty=False,
        )
184
185
class MotorStore(Store):
    """
    Store implementation on top of Motor's ``AsyncIOMotorClient``.

    Besides the Store interface it re-implements the field decryption that
    DbMongoStore delegates to osm_common (see the DECRYPT METHODS section).
    """

    def __init__(self, uri: str, loop=None):
        """
        Constructor

        :param: uri: Connection string to connect to the database.
        :param: loop: Asyncio Loop
        """
        self._client = AsyncIOMotorClient(uri)
        self.loop = loop or asyncio.get_event_loop()
        # Lazily derived by the `secret_key` property on first use
        self._secret_key = None
        # NOTE(review): the prefixes are presumably environment variable
        # prefixes understood by EnvironConfig - confirm.
        self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])

    @property
    def _database(self):
        """Database named DB_NAME on the Motor client"""
        return self._client[DB_NAME]

    @property
    def _vca_collection(self):
        """The "vca" collection"""
        return self._database["vca"]

    @property
    def _admin_collection(self):
        """The "admin" collection"""
        return self._database["admin"]

    @property
    def _vim_accounts_collection(self):
        """The "vim_accounts" collection"""
        return self._database["vim_accounts"]

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData built from the "vca" document with that id

        :raises: Exception: if there is no VCA with that id
        """
        data = await self._vca_collection.find_one({"_id": vca_id})
        if not data:
            raise Exception("vca with id {} not found".format(vca_id))
        await self.decrypt_fields(
            data,
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. When empty, the default VCA (document with
                        _id "juju" in the admin collection) is updated.

        :raises: Exception: if vca_id is given but no such VCA exists
        """
        if vca_id:
            data = await self._vca_collection.find_one({"_id": vca_id})
            if not data:
                # Same error as get_vca_connection_data, instead of failing
                # with a TypeError when subscripting None below.
                raise Exception("vca with id {} not found".format(vca_id))
            data["endpoints"] = endpoints
            await self._vca_collection.replace_one({"_id": vca_id}, data)
            return

        # The default VCA. Data for the endpoints is in a different place
        juju_info = await self._get_juju_info()
        # If it doesn't exist, then create it
        if not juju_info:
            try:
                await self._admin_collection.insert_one({"_id": "juju"})
            except Exception:
                # Race condition: check if another N2VC worker has created it
                juju_info = await self._get_juju_info()
                if not juju_info:
                    # Bare raise keeps the original traceback
                    raise

        # $set only touches "api_endpoints" (like DbMongoStore's set_one);
        # replacing the whole document would drop any other field stored in it.
        await self._admin_collection.update_one(
            {"_id": "juju"}, {"$set": {"api_endpoints": endpoints}}
        )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. When empty, the endpoints of the default VCA
                        are returned.

        :returns: List of endpoints (empty if the default VCA has none stored)
        """
        if vca_id:
            return (await self.get_vca_connection_data(vca_id)).endpoints

        juju_info = await self._get_juju_info()
        if juju_info and "api_endpoints" in juju_info:
            return juju_info["api_endpoints"]
        return []

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: Value of the "vca" key of the VIM account, or None when no
                  vim_id is given, the account is not found, or it has no
                  "vca" key
        """
        vca_id = None
        if vim_id:
            vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
            if vim_account and "vca" in vim_account:
                vca_id = vim_account["vca"]
        return vca_id

    async def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return await self._admin_collection.find_one({"_id": "juju"})

    # DECRYPT METHODS
    async def decrypt_fields(
        self,
        item: dict,
        fields: typing.List[str],
        schema_version: str = None,
        salt: str = None,
    ):
        """
        Decrypt fields

        Decrypt fields from a dictionary. Follows the same logic as in osm_common.
        The dictionary is modified in place; nested dictionaries and lists
        are visited recursively.

        :param: item: Dictionary with the keys to be decrypted
        :param: fields: List of keys to decrypt. Each entry is used as a
                        case-insensitive regular expression searched in the
                        key name.
        :param: schema version: Schema version. (i.e. 1.11)
        :param: salt: Salt for the decryption
        """
        flags = re.I

        async def process(_item):
            if isinstance(_item, list):
                for elem in _item:
                    await process(elem)
            elif isinstance(_item, dict):
                # Only values of existing keys are reassigned, so the dict
                # size does not change while it is being iterated.
                for key, val in _item.items():
                    if isinstance(val, str):
                        if any(re.search(f, key, flags) for f in fields):
                            _item[key] = await self.decrypt(val, schema_version, salt)
                    else:
                        await process(val)

        await process(item)

    async def decrypt(self, value, schema_version=None, salt=None):
        """
        Decrypt an encrypted value
        :param value: value to be decrypted. It is a base64 string
        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
               If '1.1' symmetric AES encryption has been done
        :param salt: optional salt to be used
        :return: Plain content of value
        :raises DbException: if the decrypted bytes cannot be decoded as text
        """
        if not await self.secret_key or not schema_version or schema_version == "1.0":
            return value

        secret_key = self._join_secret_key(salt)
        encrypted_msg = b64decode(value)
        # The mode is passed explicitly: PyCryptodome requires it, while the
        # legacy PyCrypto default was ECB, so behavior is the same on both.
        cipher = AES.new(secret_key, AES.MODE_ECB)
        decrypted_msg = cipher.decrypt(encrypted_msg)
        try:
            # Trailing NUL characters are stripped from the decoded text
            return decrypted_msg.decode().rstrip("\0")
        except UnicodeDecodeError as e:
            raise DbException(
                "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                http_code=500,
            ) from e

    def _join_secret_key(self, update_key: typing.Any):
        """
        Join secret key

        Returns a new 32-byte key: a copy of the current secret key (or 32
        zero bytes if there is none) with every byte of update_key XORed into
        position ``i % 32``. ``self._secret_key`` itself is not modified.

        :param: update_key: str or bytes with the data to mix into the key.
                            None mixes nothing in.

        :returns: bytes with the resulting key
        """
        if update_key is None:
            # e.g. decrypt() called without a salt
            update_key_bytes = b""
        elif isinstance(update_key, str):
            update_key_bytes = update_key.encode()
        else:
            update_key_bytes = update_key
        new_secret_key = (
            bytearray(self._secret_key) if self._secret_key else bytearray(32)
        )
        for i, b in enumerate(update_key_bytes):
            new_secret_key[i % 32] ^= b
        return bytes(new_secret_key)

    @property
    async def secret_key(self):
        """
        Secret key used for decryption (awaitable property)

        On first use the key is derived from the configured database common
        key (if any) and then from the base64 "serial" of the "version"
        document in the admin collection (if any). The result is cached in
        ``self._secret_key``.

        :returns: bytes with the key, or None if neither source is available
        """
        if self._secret_key:
            return self._secret_key

        # Do the only await BEFORE touching self._secret_key, so that a
        # half-derived key is never visible to another coroutine.
        version_data = await self._admin_collection.find_one({"_id": "version"})
        if self._secret_key:
            # Another coroutine derived the key while this one was waiting;
            # mixing again on top of it would produce a wrong key.
            return self._secret_key

        if self.database_key:
            self._secret_key = self._join_secret_key(self.database_key)
        if version_data and version_data.get("serial"):
            self._secret_key = self._join_secret_key(
                b64decode(version_data["serial"])
            )
        return self._secret_key

    @property
    def database_key(self):
        """Value of "database_commonkey" in the configuration"""
        return self._config["database_commonkey"]