c09468c65eef6f1d265fc9107f99ebd1a05bfa8b
[osm/N2VC.git] / modules / libjuju / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import weakref
12 import websockets
13 from concurrent.futures import CancelledError
14 from http.client import HTTPSConnection
15 from pathlib import Path
16
17 import asyncio
18 import yaml
19
20 from juju import tag, utils
21 from juju.client import client
22 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
23 from juju.utils import IdQueue
24
# Module-wide logger for the connection machinery. The logger name is the
# literal "websocket" (not __name__); keep it, since external logging
# configuration may refer to that name.
log = logging.getLogger(name="websocket")
26
27
class Monitor:
    """
    Monitor helper class for our Connection class.

    Holds a weak reference to a Connection plus a set of asyncio events
    that the Connection's open/close logic and its receiver and pinger
    tasks update. By inspecting these, this class determines whether the
    connection is in an 'error', 'connected', 'disconnecting' or
    'disconnected' state.

    Use this class to stay up to date on the health of a connection,
    and take appropriate action if the connection errors out due to
    network issues or other unexpected circumstances.

    """
    ERROR = 'error'
    CONNECTED = 'connected'
    DISCONNECTING = 'disconnecting'
    DISCONNECTED = 'disconnected'

    def __init__(self, connection):
        # Weak reference so the monitor never keeps its connection alive.
        self.connection = weakref.ref(connection)
        # Held by Connection.reconnect() for the duration of a reconnect.
        self.reconnecting = asyncio.Lock(loop=connection.loop)
        # Set by Connection.close(); cleared again by Connection.open().
        self.close_called = asyncio.Event(loop=connection.loop)
        # Set when the receiver / pinger tasks have exited. Both start out
        # set because neither task is running before the first open().
        self.receiver_stopped = asyncio.Event(loop=connection.loop)
        self.pinger_stopped = asyncio.Event(loop=connection.loop)
        self.receiver_stopped.set()
        self.pinger_stopped.set()

    @property
    def status(self):
        """
        Determine the status of the connection and receiver, and return
        ERROR, CONNECTED, DISCONNECTING or DISCONNECTED as appropriate.

        For simplicity, we only consider ourselves to be connected
        after the Connection class has setup a receiver task. This
        only happens after the websocket is open, and the connection
        isn't usable until that receiver has been started.

        """
        connection = self.connection()

        # the connection instance was destroyed but someone kept
        # a separate reference to the monitor for some reason
        if not connection:
            return self.DISCONNECTED

        # connection cleanly disconnected or not yet opened
        if not connection.ws:
            return self.DISCONNECTED

        # close called but not yet complete
        if self.close_called.is_set():
            return self.DISCONNECTING

        # connection closed uncleanly (we didn't call connection.close)
        if self.receiver_stopped.is_set() or not connection.ws.open:
            return self.ERROR

        # everything is fine!
        return self.CONNECTED
89
90
class Connection:
    """
    Websocket client for the Juju API.

    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    Note: Any connection method or constructor can accept an optional `loop`
    argument to override the default event loop from `asyncio.get_event_loop`.
    """

    # Sentinel meaning "use MAX_FRAME_SIZE". Presumably a sentinel is used
    # (rather than None) so that None can still be passed through to the
    # websocket library explicitly.
    DEFAULT_FRAME_SIZE = 'default_frame_size'
    # Maximum size for a single frame: 2**22 bytes (4 MiB).
    MAX_FRAME_SIZE = 2**22

    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
        """Store connection parameters; no network I/O happens here.

        :param str endpoint: ``host:port`` of the API server.
        :param uuid: Model UUID, or None to talk to the controller API.
        :param str username: User name (ignored when *macaroons* is given).
        :param str password: Password (ignored when *macaroons* is given).
        :param cacert: CA certificate data handed to the SSL context.
        :param macaroons: Macaroons to log in with instead of a password.
        :param loop: Event loop; defaults to ``asyncio.get_event_loop()``.
        :param max_frame_size: Websocket ``max_size``. The
            ``DEFAULT_FRAME_SIZE`` sentinel selects ``MAX_FRAME_SIZE``; any
            other value (including None) is used unchanged.
        """
        # self.endpoint / self.cacert track what is currently in use (they
        # change when following a redirect in _try_endpoint); the
        # underscored copies keep the originals so _connect() always
        # restarts from them.
        self.endpoint = endpoint
        self._endpoint = endpoint
        self.uuid = uuid
        if macaroons:
            # Macaroon auth replaces username/password auth entirely.
            self.macaroons = macaroons
            self.username = ''
            self.password = ''
        else:
            self.macaroons = []
            self.username = username
            self.password = password
        self.cacert = cacert
        self._cacert = cacert
        self.loop = loop or asyncio.get_event_loop()

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        self.facades = {}
        self.messages = IdQueue(loop=self.loop)
        self.monitor = Monitor(connection=self)
        if max_frame_size is self.DEFAULT_FRAME_SIZE:
            max_frame_size = self.MAX_FRAME_SIZE
        self.max_frame_size = max_frame_size

    @property
    def is_open(self):
        """True when the monitor reports the connection as CONNECTED."""
        return self.monitor.status == Monitor.CONNECTED

    def _get_ssl(self, cert=None):
        """Build the SSL context used for websocket and HTTPS connections.

        :param cert: CA certificate data (``cadata``) to trust, or None.
        """
        # NOTE(review): the ssl docs describe Purpose.CLIENT_AUTH as the
        # purpose for *server-side* sockets, so the controller certificate
        # is probably not verified here. Left unchanged because switching to
        # SERVER_AUTH also turns on hostname checking, which Juju controller
        # certificates may not satisfy -- confirm before changing.
        return ssl.create_default_context(
            purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)

    async def open(self):
        """Open the websocket to ``self.endpoint`` and start the receiver.

        The model API is used when ``self.uuid`` is set, otherwise the
        controller API.

        :return: self
        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        kw = dict()
        kw['ssl'] = self._get_ssl(self.cacert)
        kw['loop'] = self.loop
        kw['max_size'] = self.max_frame_size
        self.addr = url
        self.ws = await websockets.connect(url, **kw)
        # Reset the monitor flags before the receiver task is scheduled, so
        # it never observes state left over from a previous close().
        self.monitor.close_called.clear()
        self.monitor.receiver_stopped.clear()
        self.loop.create_task(self.receiver())
        log.info("Driver connected to juju %s", url)
        return self

    async def close(self):
        """Stop the pinger and receiver tasks, then close the websocket.

        Does nothing if the websocket is not open. ``self.ws`` is reset to
        None even if closing the socket raises.
        """
        ws = self.ws
        if not ws:
            return
        self.monitor.close_called.set()
        await self.monitor.pinger_stopped.wait()
        await self.monitor.receiver_stopped.wait()
        try:
            await ws.close()
        finally:
            self.ws = None

    async def recv(self, request_id):
        """Wait for and return the response message for *request_id*.

        :raises websockets.exceptions.ConnectionClosed: if the connection
            is not currently open.
        """
        if not self.is_open:
            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
        return await self.messages.get(request_id)

    async def receiver(self):
        """Read messages off the websocket and route them by request-id.

        Runs until close() is requested or the socket fails. On an unclean
        close it fails all pending waiters and schedules reconnect(). Always
        sets ``monitor.receiver_stopped`` on exit.
        """
        try:
            while self.is_open:
                result = await utils.run_with_interrupt(
                    self.ws.recv(),
                    self.monitor.close_called,
                    loop=self.loop)
                if self.monitor.close_called.is_set():
                    break
                if result is not None:
                    result = json.loads(result)
                    await self.messages.put(result['request-id'], result)
        except CancelledError:
            pass
        except websockets.ConnectionClosed as e:
            log.warning('Receiver: Connection closed, reconnecting')
            await self.messages.put_all(e)
            # the reconnect has to be done as a task because the receiver will
            # be cancelled by the reconnect and we don't want the reconnect
            # to be aborted half-way through
            self.loop.create_task(self.reconnect())
            return
        except Exception as e:
            log.exception("Error in receiver")
            # make pending listeners aware of the error
            await self.messages.put_all(e)
            raise
        finally:
            self.monitor.receiver_stopped.set()

    async def pinger(self):
        '''
        A Controller can time us out if we are silent for too long. This
        is especially true in JaaS, which has a fairly strict timeout.

        To prevent timing out, we send a ping every ten seconds.

        Runs until ``monitor.close_called`` is set, and always sets
        ``monitor.pinger_stopped`` on exit.
        '''
        async def _do_ping():
            try:
                await pinger_facade.Ping()
                await asyncio.sleep(10, loop=self.loop)
            except CancelledError:
                # Presumably raised when run_with_interrupt aborts the ping
                # because close_called was set; the loop below checks it.
                pass

        pinger_facade = client.PingerFacade.from_connection(self)
        try:
            while True:
                await utils.run_with_interrupt(
                    _do_ping(),
                    self.monitor.close_called,
                    loop=self.loop)
                if self.monitor.close_called.is_set():
                    break
        finally:
            self.monitor.pinger_stopped.set()

    @staticmethod
    def _error_message(obj):
        """Return ``obj['error']['message']`` when present and non-empty.

        Tolerates ``obj`` or ``obj['error']`` being missing, None or not a
        dict (e.g. JSON ``"error": null``), returning None in those cases.
        """
        if not isinstance(obj, dict):
            return None
        error = obj.get('error')
        if not isinstance(error, dict):
            return None
        return error.get('message') or None

    async def rpc(self, msg, encoder=None):
        """Send an API request and return the decoded response message.

        Fills in ``request-id``, a default empty ``params`` and, when
        absent, the facade ``version`` negotiated at login. A send on a
        closed socket triggers a reconnect and is retried (3 attempts).

        :param dict msg: Request with at least ``type`` and ``request``.
        :param encoder: Optional ``json.JSONEncoder`` subclass.
        :raises JujuAPIError: if the response carries a top-level error.
        :raises JujuError: if the response, or any entry of its
            ``results`` list, carries an error message.
        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        if 'params' not in msg:
            msg['params'] = {}
        if "version" not in msg:
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        for attempt in range(3):
            try:
                await self.ws.send(outgoing)
                break
            except websockets.ConnectionClosed:
                if attempt == 2:
                    raise
                log.warning('RPC: Connection closed, reconnecting')
                # the reconnect has to be done in a separate task because,
                # if it is triggered by the pinger, then this RPC call will
                # be cancelled when the pinger is cancelled by the reconnect,
                # and we don't want the reconnect to be aborted halfway
                # through. The task is created explicitly rather than
                # handing asyncio.wait a bare coroutine.
                reconnect_task = self.loop.create_task(self.reconnect())
                await asyncio.wait([reconnect_task], loop=self.loop)
        result = await self.recv(msg['request-id'])

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        response = result['response']
        if not isinstance(response, dict):
            # e.g. "response": null -- nothing further to inspect
            return result

        if 'results' in response:
            # Check for errors in a result list ("results" may be null).
            errors = []
            for res in response['results'] or []:
                message = self._error_message(res)
                if message:
                    errors.append(message)
            if errors:
                raise JujuError(errors)

        else:
            message = self._error_message(response)
            if message:
                raise JujuError(message)

        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        Only Basic auth from username/password is supported; an empty dict
        is returned when there is no username (e.g. macaroon auth).

        :return: Dictionary of headers

        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        The endpoint may be ``host:port`` or ``[ipv6]:port``, optionally
        followed by ``/path`` (which is ignored).

        """
        # Drop any path suffix, leaving "host:port" / "[v6addr]:port".
        hostport = self.endpoint.split('/', 1)[0]
        # Let http.client split host and port: unlike a split on the first
        # ':', it understands bracketed IPv6 literals.
        conn = HTTPSConnection(
            hostport,
            context=self._get_ssl(self.cacert),
        )

        path = (
            "/model/{}".format(self.uuid)
            if self.uuid else ""
        )
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.

        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
            self.max_frame_size,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint

        Uses the same credentials, loop and frame size as this connection.

        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
            self.max_frame_size,
        )

    async def _try_endpoint(self, endpoint, cacert):
        """Open *endpoint* and attempt to log in.

        The connection is closed again unless login succeeds.

        :return: ``(success, login_result, new_endpoints)`` where
            *new_endpoints* lists ``(endpoint, cacert)`` pairs to try next
            when the controller asked for a redirect.
        """
        success = False
        result = None
        new_endpoints = []

        self.endpoint = endpoint
        self.cacert = cacert
        await self.open()
        try:
            result = await self.login()
            if 'discharge-required-error' in result['response']:
                log.info('Macaroon discharge required, disconnecting')
            else:
                # successful login!
                log.info('Authenticated')
                success = True
        except JujuAPIError as e:
            if e.error_code != 'redirection required':
                raise
            log.info('Controller requested redirect')
            redirect_info = await self.redirect_info()
            redir_cacert = redirect_info['ca-cert']
            new_endpoints = [
                ("{value}:{port}".format(**s), redir_cacert)
                for servers in redirect_info['servers']
                for s in servers if s["scope"] == 'public'
            ]
        finally:
            if not success:
                await self.close()
        return success, result, new_endpoints

    async def reconnect(self):
        """ Force a reconnection.

        No-op if a reconnect is already in progress or close() has been
        called since the last open().
        """
        monitor = self.monitor
        if monitor.reconnecting.locked() or monitor.close_called.is_set():
            return
        async with monitor.reconnecting:
            await self.close()
            await self._connect()

    async def _connect(self):
        """Log in, starting from the original endpoint and following redirects.

        On success stores the login response in ``self.info``, records the
        available facade versions and starts the pinger task.

        :raises JujuConnectionError: if every endpoint fails to authenticate.
        """
        endpoints = [(self._endpoint, self._cacert)]
        while endpoints:
            _endpoint, _cacert = endpoints.pop(0)
            success, result, new_endpoints = await self._try_endpoint(
                _endpoint, _cacert)
            if success:
                break
            endpoints.extend(new_endpoints)
        else:
            # ran out of endpoints without a successful login
            raise JujuConnectionError("Couldn't authenticate to {}".format(
                self._endpoint))

        response = result['response']
        self.info = response.copy()
        self.build_facades(response.get('facades', {}))
        self.loop.create_task(self.pinger())
        self.monitor.pinger_stopped.clear()

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None, max_frame_size=None):
        """Connect to the websocket.

        If uuid is None, the connection will be to the controller. Otherwise it
        will be to the model.

        """
        # NOTE(review): max_frame_size defaults to None here, which __init__
        # passes through unchanged as the websocket max_size, rather than
        # selecting MAX_FRAME_SIZE. Confirm whether that is intended.
        client = cls(endpoint, uuid, username, password, cacert, macaroons,
                     loop, max_frame_size)
        await client._connect()
        return client

    @classmethod
    async def connect_current(cls, loop=None, max_frame_size=None):
        """Connect to the currently active model.

        :raises JujuConnectionError: if there is no current controller.
        """
        jujudata = JujuData()

        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        model_name = jujudata.current_model()

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)

    @classmethod
    async def connect_current_controller(cls, loop=None, max_frame_size=None):
        """Connect to the currently active controller.

        :raises JujuConnectionError: if there is no current controller.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name, loop,
                                            max_frame_size)

    @classmethod
    async def connect_controller(cls, controller_name, loop=None,
                                 max_frame_size=None):
        """Connect to a controller by name.

        Uses the first API endpoint and the account stored locally for that
        controller; macaroons are loaded when no password is stored.
        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons(controller_name) if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons, loop,
            max_frame_size)

    @classmethod
    async def connect_model(cls, model, loop=None, max_frame_size=None):
        """Connect to a model by name.

        :param str model: [<controller>:]<model>. Without a controller
            prefix the current controller is used; a model name without a
            ``user/`` prefix gets the account's user prepended.
        :raises JujuConnectionError: if no controller is given and there is
            no current controller.

        """
        jujudata = JujuData()

        if ':' in model:
            # explicit controller given
            controller_name, model_name = model.split(':', 1)
        else:
            # use the current controller if one isn't explicitly given
            controller_name = jujudata.current_controller()
            model_name = model
        if not controller_name:
            # Same error as connect_current(), instead of a bare KeyError
            # from the lookups below.
            raise JujuConnectionError('No current controller')

        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        # model name must include a user prefix, so add it if it doesn't
        if '/' not in model_name:
            model_name = '{}/{}'.format(username, model_name)

        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons(controller_name) if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons, loop,
            max_frame_size)

    def build_facades(self, facades):
        """Record the last listed version of each facade from a login reply.

        :param facades: Iterable of ``{'name': ..., 'versions': [...]}``.
        """
        self.facades.clear()
        for facade in facades:
            self.facades[facade['name']] = facade['versions'][-1]

    async def login(self):
        """Send an Admin.Login (v3) request and return the raw result.

        The user name is given a ``user-`` tag prefix when it lacks one.
        """
        username = self.username
        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": self.password,
                "nonce": "".join(random.sample(string.printable, 12)),
                "macaroons": self.macaroons
            }})
        return result

    async def redirect_info(self):
        """Return the Admin.RedirectInfo response, or None if not redirected."""
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
557
558
class JujuData:
    """Read access to the local Juju client data directory.

    The directory is ``$JUJU_DATA`` when that is set and non-empty,
    otherwise ``~/.local/share/juju``.
    """

    def __init__(self):
        data_dir = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(data_dir))

    def current_controller(self):
        """Return the current controller name as reported by the juju CLI.

        Runs ``juju list-controllers --format yaml``; returns '' when the
        output has no ``current-controller`` entry (or is empty).
        """
        cmd = shlex.split('juju list-controllers --format yaml')
        output = subprocess.check_output(cmd)
        output = yaml.safe_load(output) or {}
        return output.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the current model of *controller_name*.

        :param str controller_name: Controller to look in; defaults to the
            current controller.
        :raises JujuError: if the controller has no current model.
        """
        if not controller_name:
            controller_name = self.current_controller()
        models = self.models()[controller_name]
        if 'current-model' not in models:
            raise JujuError('No current model')
        return models['current-model']

    def controllers(self):
        """Return the ``controllers`` mapping from controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the ``controllers`` mapping from models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the ``controllers`` mapping from accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def credentials(self):
        """Return the ``credentials`` mapping from credentials.yaml."""
        return self._load_yaml('credentials.yaml', 'credentials')

    def load_credential(self, cloud, name=None):
        """Load a local credential.

        :param str cloud: Name of cloud to load credentials from (a
            ``cloud-`` tag prefix is accepted).
        :param str name: Name of credential. If None, the cloud's
            ``default-credential`` is used, or the only credential if there
            is exactly one.
        :returns: A ``(name, CloudCredential)`` tuple, or ``(None, None)``
            if no matching credential is found.
        """
        try:
            cloud = tag.untag('cloud-', cloud)
            creds_data = self.credentials()[cloud]
            if not name:
                # Remove the non-credential keys so that only credential
                # entries remain for the single-credential fallback below.
                default_credential = creds_data.pop('default-credential', None)
                creds_data.pop('default-region', None)
                if default_credential:
                    # The key was just popped, so use the saved value rather
                    # than looking it up again (which raised KeyError).
                    name = default_credential
                elif len(creds_data) == 1:
                    name = list(creds_data)[0]
                else:
                    return None, None
            cred_data = creds_data[name]
            auth_type = cred_data.pop('auth-type')
            return name, client.CloudCredential(
                auth_type=auth_type,
                attrs=cred_data,
            )
        except (KeyError, FileNotFoundError):
            return None, None

    def _load_yaml(self, filename, key):
        """Parse ``self.path/filename`` with yaml.safe_load and return [key]."""
        filepath = os.path.join(self.path, filename)
        with io.open(filepath, 'rt') as f:
            return yaml.safe_load(f)[key]
623
624
def get_macaroons(controller_name=None):
    """Decode and return macaroons from the local cookie jar.

    Candidate files, in order (the first one that exists is used):

    1. ``$JUJU_DATA/cookies/<controller_name>.json`` -- only when
       *controller_name* is given. ``$JUJU_DATA`` defaults to
       ``~/.local/share/juju`` when unset or empty.
    2. ``~/.go-cookies``

    :param str controller_name: Controller whose cookie file is preferred.
    :return: List of JSON-decoded values of the non-empty ``macaroon-*``
        cookies, or ``[]`` if no cookie file could be loaded.
    """
    # Honour JUJU_DATA the same way the rest of the client data lookup does.
    juju_data = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
    cookie_files = []
    if controller_name:
        cookie_files.append(os.path.join(
            juju_data, 'cookies', '{}.json'.format(controller_name)))
    cookie_files.append('~/.go-cookies')
    for cookie_file in cookie_files:
        cookie_path = Path(cookie_file).expanduser()
        if cookie_path.exists():
            try:
                cookies = json.loads(cookie_path.read_text())
                break
            except (OSError, ValueError):
                log.warning("Couldn't load macaroons from %s", cookie_path)
                return []
    else:
        log.warning("Couldn't load macaroons from %s", ' or '.join(cookie_files))
        return []

    # A cookie file containing JSON null is treated as an empty jar.
    base64_macaroons = [
        c['Value'] for c in cookies or []
        if c['Name'].startswith('macaroon-') and c['Value']
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]