bdd1c3f3e2a6a92287a546f3cd8075652f2e3b15
[osm/N2VC.git] / modules / libjuju / juju / client / connection.py
1 import asyncio
2 import base64
3 import json
4 import logging
5 import ssl
6 import urllib.request
7 import weakref
8 from concurrent.futures import CancelledError
9 from http.client import HTTPSConnection
10
11 import macaroonbakery.httpbakery as httpbakery
12 import macaroonbakery.bakery as bakery
13 import websockets
14 from juju import errors, tag, utils
15 from juju.client import client
16 from juju.utils import IdQueue
17
# Module logger; Connection reports connects, reconnect warnings and
# (at debug level) RPC traffic through it.
log = logging.getLogger('juju.client.connection')
19
20
class Monitor:
    """
    Health monitor for a Connection.

    Holds a weak reference to a Connection and, by inspecting that
    connection's websocket and receiver task, reports whether it is
    'connected', 'disconnecting', 'disconnected' or in an 'error' state.

    Use it to keep track of a connection's health and to react when the
    connection fails because of network problems or other unexpected
    circumstances.

    """
    ERROR = 'error'
    CONNECTED = 'connected'
    DISCONNECTING = 'disconnecting'
    DISCONNECTED = 'disconnected'

    def __init__(self, connection):
        # Weak reference: the monitor must never keep its Connection alive.
        self.connection = weakref.ref(connection)
        loop = connection.loop
        # Held by Connection.reconnect while a reconnect is in progress.
        self.reconnecting = asyncio.Lock(loop=loop)
        # Set by Connection.close; tells the pinger and receiver to stop.
        self.close_called = asyncio.Event(loop=loop)

    @property
    def status(self):
        """
        Current state of the connection: ERROR, CONNECTED, DISCONNECTING
        or DISCONNECTED.

        The connection counts as CONNECTED only once it has a websocket
        and its receiver task is running, because it is not usable before
        that. It is DISCONNECTED when the Connection is gone, or when it
        has no websocket (never opened or cleanly closed). It is
        DISCONNECTING when close has been requested but the websocket is
        still set. A stopped receiver or a websocket that is no longer
        open, without close having been requested, means ERROR.

        """
        conn = self.connection()

        # Either the Connection was garbage-collected while someone kept a
        # reference to this monitor, or there is simply no websocket.
        if not conn or not conn.ws:
            return self.DISCONNECTED

        if self.close_called.is_set():
            return self.DISCONNECTING

        # The socket went away without close() having been requested.
        receiver_done = conn._receiver_task.stopped.is_set()
        if receiver_done or not conn.ws.open:
            return self.ERROR
        return self.CONNECTED
79
80
81 class Connection:
82 """
83 Usage::
84
85 # Connect to an arbitrary api server
86 client = await Connection.connect(
87 api_endpoint, model_uuid, username, password, cacert)
88
89 Note: Any connection method or constructor can accept an optional `loop`
90 argument to override the default event loop from `asyncio.get_event_loop`.
91 """
92
93 MAX_FRAME_SIZE = 2**22
94 "Maximum size for a single frame. Defaults to 4MB."
95
96 @classmethod
97 async def connect(
98 cls,
99 endpoint=None,
100 uuid=None,
101 username=None,
102 password=None,
103 cacert=None,
104 bakery_client=None,
105 loop=None,
106 max_frame_size=None,
107 ):
108 """Connect to the websocket.
109
110 If uuid is None, the connection will be to the controller. Otherwise it
111 will be to the model.
112 :param str endpoint The hostname:port of the controller to connect to.
113 :param str uuid The model UUID to connect to (None for a
114 controller-only connection).
115 :param str username The username for controller-local users (or None
116 to use macaroon-based login.)
117 :param str password The password for controller-local users.
118 :param str cacert The CA certificate of the controller (PEM formatted).
119 :param httpbakery.Client bakery_client The macaroon bakery client to
120 to use when performing macaroon-based login. Macaroon tokens
121 acquired when logging will be saved to bakery_client.cookies.
122 If this is None, a default bakery_client will be used.
123 :param loop asyncio.BaseEventLoop The event loop to use for async
124 operations.
125 :param max_frame_size The maximum websocket frame size to allow.
126 """
127 self = cls()
128 if endpoint is None:
129 raise ValueError('no endpoint provided')
130 self.uuid = uuid
131 if bakery_client is None:
132 bakery_client = httpbakery.Client()
133 self.bakery_client = bakery_client
134 if username and '@' in username and not username.endswith('@local'):
135 # We're trying to log in as an external user - we need to use
136 # macaroon authentication with no username or password.
137 if password is not None:
138 raise errors.JujuAuthError('cannot log in as external '
139 'user with a password')
140 username = None
141 self.usertag = tag.user(username)
142 self.password = password
143 self.loop = loop or asyncio.get_event_loop()
144
145 self.__request_id__ = 0
146
147 # The following instance variables are initialized by the
148 # _connect_with_redirect method, but create them here
149 # as a reminder that they will exist.
150 self.addr = None
151 self.ws = None
152 self.endpoint = None
153 self.cacert = None
154 self.info = None
155
156 # Create that _Task objects but don't start the tasks yet.
157 self._pinger_task = _Task(self._pinger, self.loop)
158 self._receiver_task = _Task(self._receiver, self.loop)
159
160 self.facades = {}
161 self.messages = IdQueue(loop=self.loop)
162 self.monitor = Monitor(connection=self)
163 if max_frame_size is None:
164 max_frame_size = self.MAX_FRAME_SIZE
165 self.max_frame_size = max_frame_size
166 await self._connect_with_redirect([(endpoint, cacert)])
167 return self
168
169 @property
170 def username(self):
171 if not self.usertag:
172 return None
173 return self.usertag[len('user-'):]
174
175 @property
176 def is_open(self):
177 return self.monitor.status == Monitor.CONNECTED
178
179 def _get_ssl(self, cert=None):
180 return ssl.create_default_context(
181 purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
182
183 async def _open(self, endpoint, cacert):
184 if self.uuid:
185 url = "wss://{}/model/{}/api".format(endpoint, self.uuid)
186 else:
187 url = "wss://{}/api".format(endpoint)
188
189 return (await websockets.connect(
190 url,
191 ssl=self._get_ssl(cacert),
192 loop=self.loop,
193 max_size=self.max_frame_size,
194 ), url, endpoint, cacert)
195
196 async def close(self):
197 if not self.ws:
198 return
199 self.monitor.close_called.set()
200 await self._pinger_task.stopped.wait()
201 await self._receiver_task.stopped.wait()
202 await self.ws.close()
203 self.ws = None
204
205 async def _recv(self, request_id):
206 if not self.is_open:
207 raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
208 return await self.messages.get(request_id)
209
210 async def _receiver(self):
211 try:
212 while self.is_open:
213 result = await utils.run_with_interrupt(
214 self.ws.recv(),
215 self.monitor.close_called,
216 loop=self.loop)
217 if self.monitor.close_called.is_set():
218 break
219 if result is not None:
220 result = json.loads(result)
221 await self.messages.put(result['request-id'], result)
222 except CancelledError:
223 pass
224 except websockets.ConnectionClosed as e:
225 log.warning('Receiver: Connection closed, reconnecting')
226 await self.messages.put_all(e)
227 # the reconnect has to be done as a task because the receiver will
228 # be cancelled by the reconnect and we don't want the reconnect
229 # to be aborted half-way through
230 self.loop.create_task(self.reconnect())
231 return
232 except Exception as e:
233 log.exception("Error in receiver")
234 # make pending listeners aware of the error
235 await self.messages.put_all(e)
236 raise
237
238 async def _pinger(self):
239 '''
240 A Controller can time us out if we are silent for too long. This
241 is especially true in JaaS, which has a fairly strict timeout.
242
243 To prevent timing out, we send a ping every ten seconds.
244
245 '''
246 async def _do_ping():
247 try:
248 await pinger_facade.Ping()
249 await asyncio.sleep(10, loop=self.loop)
250 except CancelledError:
251 pass
252
253 pinger_facade = client.PingerFacade.from_connection(self)
254 try:
255 while True:
256 await utils.run_with_interrupt(
257 _do_ping(),
258 self.monitor.close_called,
259 loop=self.loop)
260 if self.monitor.close_called.is_set():
261 break
262 except websockets.exceptions.ConnectionClosed:
263 # The connection has closed - we can't do anything
264 # more until the connection is restarted.
265 log.debug('ping failed because of closed connection')
266 pass
267
268 async def rpc(self, msg, encoder=None):
269 '''Make an RPC to the API. The message is encoded as JSON
270 using the given encoder if any.
271 :param msg: Parameters for the call (will be encoded as JSON).
272 :param encoder: Encoder to be used when encoding the message.
273 :return: The result of the call.
274 :raises JujuAPIError: When there's an error returned.
275 :raises JujuError:
276 '''
277 self.__request_id__ += 1
278 msg['request-id'] = self.__request_id__
279 if'params' not in msg:
280 msg['params'] = {}
281 if "version" not in msg:
282 msg['version'] = self.facades[msg['type']]
283 outgoing = json.dumps(msg, indent=2, cls=encoder)
284 log.debug('connection {} -> {}'.format(id(self), outgoing))
285 for attempt in range(3):
286 if self.monitor.status == Monitor.DISCONNECTED:
287 # closed cleanly; shouldn't try to reconnect
288 raise websockets.exceptions.ConnectionClosed(
289 0, 'websocket closed')
290 try:
291 await self.ws.send(outgoing)
292 break
293 except websockets.ConnectionClosed:
294 if attempt == 2:
295 raise
296 log.warning('RPC: Connection closed, reconnecting')
297 # the reconnect has to be done in a separate task because,
298 # if it is triggered by the pinger, then this RPC call will
299 # be cancelled when the pinger is cancelled by the reconnect,
300 # and we don't want the reconnect to be aborted halfway through
301 await asyncio.wait([self.reconnect()], loop=self.loop)
302 if self.monitor.status != Monitor.CONNECTED:
303 # reconnect failed; abort and shutdown
304 log.error('RPC: Automatic reconnect failed')
305 raise
306 result = await self._recv(msg['request-id'])
307 log.debug('connection {} <- {}'.format(id(self), result))
308
309 if not result:
310 return result
311
312 if 'error' in result:
313 # API Error Response
314 raise errors.JujuAPIError(result)
315
316 if 'response' not in result:
317 # This may never happen
318 return result
319
320 if 'results' in result['response']:
321 # Check for errors in a result list.
322 # TODO This loses the results that might have succeeded.
323 # Perhaps JujuError should return all the results including
324 # errors, or perhaps a keyword parameter to the rpc method
325 # could be added to trigger this behaviour.
326 err_results = []
327 for res in result['response']['results']:
328 if res.get('error', {}).get('message'):
329 err_results.append(res['error']['message'])
330 if err_results:
331 raise errors.JujuError(err_results)
332
333 elif result['response'].get('error', {}).get('message'):
334 raise errors.JujuError(result['response']['error']['message'])
335
336 return result
337
338 def _http_headers(self):
339 """Return dictionary of http headers necessary for making an http
340 connection to the endpoint of this Connection.
341
342 :return: Dictionary of headers
343
344 """
345 if not self.usertag:
346 return {}
347
348 creds = u'{}:{}'.format(
349 self.usertag,
350 self.password or ''
351 )
352 token = base64.b64encode(creds.encode())
353 return {
354 'Authorization': 'Basic {}'.format(token.decode())
355 }
356
357 def https_connection(self):
358 """Return an https connection to this Connection's endpoint.
359
360 Returns a 3-tuple containing::
361
362 1. The :class:`HTTPSConnection` instance
363 2. Dictionary of auth headers to be used with the connection
364 3. The root url path (str) to be used for requests.
365
366 """
367 endpoint = self.endpoint
368 host, remainder = endpoint.split(':', 1)
369 port = remainder
370 if '/' in remainder:
371 port, _ = remainder.split('/', 1)
372
373 conn = HTTPSConnection(
374 host, int(port),
375 context=self._get_ssl(self.cacert),
376 )
377
378 path = (
379 "/model/{}".format(self.uuid)
380 if self.uuid else ""
381 )
382 return conn, self._http_headers(), path
383
384 async def clone(self):
385 """Return a new Connection, connected to the same websocket endpoint
386 as this one.
387
388 """
389 return await Connection.connect(**self.connect_params())
390
391 def connect_params(self):
392 """Return a tuple of parameters suitable for passing to
393 Connection.connect that can be used to make a new connection
394 to the same controller (and model if specified. The first
395 element in the returned tuple holds the endpoint argument;
396 the other holds a dict of the keyword args.
397 """
398 return {
399 'endpoint': self.endpoint,
400 'uuid': self.uuid,
401 'username': self.username,
402 'password': self.password,
403 'cacert': self.cacert,
404 'bakery_client': self.bakery_client,
405 'loop': self.loop,
406 'max_frame_size': self.max_frame_size,
407 }
408
409 async def controller(self):
410 """Return a Connection to the controller at self.endpoint
411 """
412 return await Connection.connect(
413 self.endpoint,
414 username=self.username,
415 password=self.password,
416 cacert=self.cacert,
417 bakery_client=self.bakery_client,
418 loop=self.loop,
419 max_frame_size=self.max_frame_size,
420 )
421
422 async def reconnect(self):
423 """ Force a reconnection.
424 """
425 monitor = self.monitor
426 if monitor.reconnecting.locked() or monitor.close_called.is_set():
427 return
428 async with monitor.reconnecting:
429 await self.close()
430 await self._connect_with_login([(self.endpoint, self.cacert)])
431
432 async def _connect(self, endpoints):
433 if len(endpoints) == 0:
434 raise errors.JujuConnectionError('no endpoints to connect to')
435
436 async def _try_endpoint(endpoint, cacert, delay):
437 if delay:
438 await asyncio.sleep(delay)
439 return await self._open(endpoint, cacert)
440
441 # Try all endpoints in parallel, with slight increasing delay (+100ms
442 # for each subsequent endpoint); the delay allows us to prefer the
443 # earlier endpoints over the latter. Use first successful connection.
444 tasks = [self.loop.create_task(_try_endpoint(endpoint, cacert,
445 0.1 * i))
446 for i, (endpoint, cacert) in enumerate(endpoints)]
447 for task in asyncio.as_completed(tasks, loop=self.loop):
448 try:
449 result = await task
450 break
451 except ConnectionError:
452 continue # ignore; try another endpoint
453 else:
454 raise errors.JujuConnectionError(
455 'Unable to connect to any endpoint: {}'.format(', '.join([
456 endpoint for endpoint, cacert in endpoints])))
457 for task in tasks:
458 task.cancel()
459 self.ws = result[0]
460 self.addr = result[1]
461 self.endpoint = result[2]
462 self.cacert = result[3]
463 self._receiver_task.start()
464 log.info("Driver connected to juju %s", self.addr)
465 self.monitor.close_called.clear()
466
467 async def _connect_with_login(self, endpoints):
468 """Connect to the websocket.
469
470 If uuid is None, the connection will be to the controller. Otherwise it
471 will be to the model.
472 :return: The response field of login response JSON object.
473 """
474 success = False
475 try:
476 await self._connect(endpoints)
477 # It's possible that we may get several discharge-required errors,
478 # corresponding to different levels of authentication, so retry
479 # a few times.
480 for i in range(0, 2):
481 result = (await self.login())['response']
482 macaroonJSON = result.get('discharge-required')
483 if macaroonJSON is None:
484 self.info = result
485 success = True
486 return result
487 macaroon = bakery.Macaroon.from_dict(macaroonJSON)
488 self.bakery_client.handle_error(
489 httpbakery.Error(
490 code=httpbakery.ERR_DISCHARGE_REQUIRED,
491 message=result.get('discharge-required-error'),
492 version=macaroon.version,
493 info=httpbakery.ErrorInfo(
494 macaroon=macaroon,
495 macaroon_path=result.get('macaroon-path'),
496 ),
497 ),
498 # note: remove the port number.
499 'https://' + self.endpoint + '/',
500 )
501 raise errors.JujuAuthError('failed to authenticate '
502 'after several attempts')
503 finally:
504 if not success:
505 await self.close()
506
507 async def _connect_with_redirect(self, endpoints):
508 try:
509 login_result = await self._connect_with_login(endpoints)
510 except errors.JujuRedirectException as e:
511 login_result = await self._connect_with_login(e.endpoints)
512 self._build_facades(login_result.get('facades', {}))
513 self._pinger_task.start()
514
515 def _build_facades(self, facades):
516 self.facades.clear()
517 for facade in facades:
518 self.facades[facade['name']] = facade['versions'][-1]
519
520 async def login(self):
521 params = {}
522 if self.password:
523 params['auth-tag'] = self.usertag
524 params['credentials'] = self.password
525 else:
526 macaroons = _macaroons_for_domain(self.bakery_client.cookies,
527 self.endpoint)
528 params['macaroons'] = [[bakery.macaroon_to_dict(m) for m in ms]
529 for ms in macaroons]
530
531 try:
532 return await self.rpc({
533 "type": "Admin",
534 "request": "Login",
535 "version": 3,
536 "params": params,
537 })
538 except errors.JujuAPIError as e:
539 if e.error_code != 'redirection required':
540 raise
541 log.info('Controller requested redirect')
542 # Fetch additional redirection information now so that
543 # we can safely close the connection after login
544 # fails.
545 redirect_info = (await self.rpc({
546 "type": "Admin",
547 "request": "RedirectInfo",
548 "version": 3,
549 }))['response']
550 raise errors.JujuRedirectException(redirect_info) from e
551
552
class _Task:
    """Restartable runner for a zero-argument coroutine function.

    ``stopped`` is an asyncio.Event that is set whenever no run is in
    progress: initially, and again after each run finishes (normally, by
    exception or by cancellation).
    """

    def __init__(self, task, loop):
        self.stopped = asyncio.Event(loop=loop)
        self.stopped.set()
        # Coroutine function invoked (with no arguments) by each start().
        self.task = task
        self.loop = loop
        # Most recent asyncio.Task created by start(). The event loop only
        # keeps weak references to tasks, so holding it here stops a running
        # task from being garbage-collected mid-execution.
        self._current = None

    def start(self):
        """Schedule a new run of the wrapped coroutine function.

        :return: The scheduled asyncio.Task (also kept on the instance).
        """
        async def run():
            try:
                return await self.task()
            finally:
                self.stopped.set()

        self.stopped.clear()
        self._current = self.loop.create_task(run())
        return self._current
568
569
def _macaroons_for_domain(cookies, domain):
    '''Collect the macaroons in a cookie jar that apply to a domain.

    A throwaway request for ``https://<domain>/`` is built, the jar is
    asked to attach whichever cookies match it, and the macaroons are then
    extracted from that request's headers.

    :param cookies: Cookie jar to read (e.g. a bakery client's cookies).
    :param domain: Host (optionally host:port) to match cookies against.
    :return: Whatever httpbakery.extract_macaroons returns for the request.
    '''
    probe = urllib.request.Request(''.join(('https://', domain, '/')))
    cookies.add_cookie_header(probe)
    return httpbakery.extract_macaroons(probe)