Refactor connection task management to avoid cancels (#117)
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from concurrent.futures import CancelledError
13 from http.client import HTTPSConnection
14
15 import asyncio
16 import yaml
17
18 from juju import tag, utils
19 from juju.client import client
20 from juju.client.version_map import VERSION_MAP
21 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
22 from juju.utils import IdQueue
23
# Shared logger for this module.  Note it is registered under the name
# "websocket" rather than this module's dotted path.
log = logging.getLogger(name="websocket")
25
26
class Monitor:
    """
    Monitor helper class for our Connection class.

    Holds a reference to an instantiated Connection plus three asyncio
    Events that the Connection sets and clears as it opens, closes and runs
    its background tasks.  Upon inspection of these objects, this class
    determines whether the connection is in an 'error', 'connected' or
    'disconnected' state.

    Use this class to stay up to date on the health of a connection,
    and take appropriate action if the connection errors out due to
    network issues or other unexpected circumstances.

    """
    ERROR = 'error'
    CONNECTED = 'connected'
    DISCONNECTED = 'disconnected'
    # Kept for backwards compatibility; ``status`` no longer returns it
    # because every combination of states is now classified explicitly.
    UNKNOWN = 'unknown'

    def __init__(self, connection):
        self.connection = connection
        # Set by Connection.close() to mark a deliberate shutdown; cleared
        # again by Connection.open().
        self.close_called = asyncio.Event(loop=self.connection.loop)
        # The "stopped" events are set whenever the corresponding background
        # task is not running.  No task exists yet, so both start out set.
        self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
        self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
        self.receiver_stopped.set()
        self.pinger_stopped.set()

    @property
    def status(self):
        """
        Determine the status of the connection and receiver, and return
        ERROR, CONNECTED, or DISCONNECTED as appropriate.

        * DISCONNECTED -- no websocket yet, or the receiver has stopped /
          the socket has closed and ``close()`` was called.
        * ERROR -- the receiver has stopped or the socket has closed
          *without* ``close()`` having been called (e.g. the server dropped
          the connection or the receiver crashed).
        * CONNECTED -- the socket is open and the receiver is running.

        We only consider ourselves connected once the Connection has started
        its receiver task, because the connection is not usable before that.

        """
        ws = self.connection.ws
        if not ws:
            # Connection not yet opened.
            return self.DISCONNECTED

        if self.receiver_stopped.is_set() or not ws.open:
            # The connection is unusable.  Whether that is an error depends
            # on whether we asked for it.  A stopped receiver must not
            # short-circuit to DISCONNECTED, or an unexpected drop would
            # never be reported as ERROR.
            if self.close_called.is_set():
                return self.DISCONNECTED
            return self.ERROR

        # Socket open and receiver running: everything is fine!
        return self.CONNECTED
95
96
97 class Connection:
98 """
99 Usage::
100
101 # Connect to an arbitrary api server
102 client = await Connection.connect(
103 api_endpoint, model_uuid, username, password, cacert)
104
105 # Connect using a controller/model name
106 client = await Connection.connect_model('local.local:default')
107
108 # Connect to the currently active model
109 client = await Connection.connect_current()
110
111 Note: Any connection method or constructor can accept an optional `loop`
112 argument to override the default event loop from `asyncio.get_event_loop`.
113 """
114 def __init__(
115 self, endpoint, uuid, username, password, cacert=None,
116 macaroons=None, loop=None):
117 self.endpoint = endpoint
118 self.uuid = uuid
119 if macaroons:
120 self.macaroons = macaroons
121 self.username = ''
122 self.password = ''
123 else:
124 self.macaroons = []
125 self.username = username
126 self.password = password
127 self.cacert = cacert
128 self.loop = loop or asyncio.get_event_loop()
129
130 self.__request_id__ = 0
131 self.addr = None
132 self.ws = None
133 self.facades = {}
134 self.messages = IdQueue(loop=self.loop)
135 self.monitor = Monitor(connection=self)
136
137 @property
138 def is_open(self):
139 if self.ws:
140 return self.ws.open
141 return False
142
143 def _get_ssl(self, cert=None):
144 return ssl.create_default_context(
145 purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
146
147 async def open(self):
148 if self.uuid:
149 url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
150 else:
151 url = "wss://{}/api".format(self.endpoint)
152
153 kw = dict()
154 kw['ssl'] = self._get_ssl(self.cacert)
155 kw['loop'] = self.loop
156 self.addr = url
157 self.ws = await websockets.connect(url, **kw)
158 self.loop.create_task(self.receiver())
159 self.monitor.receiver_stopped.clear()
160 log.info("Driver connected to juju %s", url)
161 self.monitor.close_called.clear()
162 return self
163
164 async def close(self):
165 if not self.is_open:
166 return
167 self.monitor.close_called.set()
168 await self.monitor.pinger_stopped.wait()
169 await self.monitor.receiver_stopped.wait()
170 await self.ws.close()
171
172 async def recv(self, request_id):
173 if not self.is_open:
174 raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
175 return await self.messages.get(request_id)
176
177 async def receiver(self):
178 try:
179 while self.is_open:
180 result = await utils.run_with_interrupt(
181 self.ws.recv(),
182 self.monitor.close_called,
183 loop=self.loop)
184 if self.monitor.close_called.is_set():
185 break
186 if result is not None:
187 result = json.loads(result)
188 await self.messages.put(result['request-id'], result)
189 except CancelledError:
190 pass
191 except Exception as e:
192 await self.messages.put_all(e)
193 if isinstance(e, websockets.ConnectionClosed):
194 # ConnectionClosed is not really exceptional for us,
195 # but it may be for any pending message listeners
196 return
197 log.exception("Error in receiver")
198 raise
199 finally:
200 self.monitor.receiver_stopped.set()
201
202 async def pinger(self):
203 '''
204 A Controller can time us out if we are silent for too long. This
205 is especially true in JaaS, which has a fairly strict timeout.
206
207 To prevent timing out, we send a ping every ten seconds.
208
209 '''
210 async def _do_ping():
211 try:
212 await pinger_facade.Ping()
213 await asyncio.sleep(10, loop=self.loop)
214 except CancelledError:
215 pass
216
217 pinger_facade = client.PingerFacade.from_connection(self)
218 try:
219 while self.is_open:
220 await utils.run_with_interrupt(
221 _do_ping(),
222 self.monitor.close_called,
223 loop=self.loop)
224 if self.monitor.close_called.is_set():
225 break
226 finally:
227 self.monitor.pinger_stopped.set()
228
229 async def rpc(self, msg, encoder=None):
230 self.__request_id__ += 1
231 msg['request-id'] = self.__request_id__
232 if'params' not in msg:
233 msg['params'] = {}
234 if "version" not in msg:
235 msg['version'] = self.facades[msg['type']]
236 outgoing = json.dumps(msg, indent=2, cls=encoder)
237 await self.ws.send(outgoing)
238 result = await self.recv(msg['request-id'])
239
240 if not result:
241 return result
242
243 if 'error' in result:
244 # API Error Response
245 raise JujuAPIError(result)
246
247 if 'response' not in result:
248 # This may never happen
249 return result
250
251 if 'results' in result['response']:
252 # Check for errors in a result list.
253 errors = []
254 for res in result['response']['results']:
255 if res.get('error', {}).get('message'):
256 errors.append(res['error']['message'])
257 if errors:
258 raise JujuError(errors)
259
260 elif result['response'].get('error', {}).get('message'):
261 raise JujuError(result['response']['error']['message'])
262
263 return result
264
265 def http_headers(self):
266 """Return dictionary of http headers necessary for making an http
267 connection to the endpoint of this Connection.
268
269 :return: Dictionary of headers
270
271 """
272 if not self.username:
273 return {}
274
275 creds = u'{}:{}'.format(
276 tag.user(self.username),
277 self.password or ''
278 )
279 token = base64.b64encode(creds.encode())
280 return {
281 'Authorization': 'Basic {}'.format(token.decode())
282 }
283
284 def https_connection(self):
285 """Return an https connection to this Connection's endpoint.
286
287 Returns a 3-tuple containing::
288
289 1. The :class:`HTTPSConnection` instance
290 2. Dictionary of auth headers to be used with the connection
291 3. The root url path (str) to be used for requests.
292
293 """
294 endpoint = self.endpoint
295 host, remainder = endpoint.split(':', 1)
296 port = remainder
297 if '/' in remainder:
298 port, _ = remainder.split('/', 1)
299
300 conn = HTTPSConnection(
301 host, int(port),
302 context=self._get_ssl(self.cacert),
303 )
304
305 path = (
306 "/model/{}".format(self.uuid)
307 if self.uuid else ""
308 )
309 return conn, self.http_headers(), path
310
311 async def clone(self):
312 """Return a new Connection, connected to the same websocket endpoint
313 as this one.
314
315 """
316 return await Connection.connect(
317 self.endpoint,
318 self.uuid,
319 self.username,
320 self.password,
321 self.cacert,
322 self.macaroons,
323 self.loop,
324 )
325
326 async def controller(self):
327 """Return a Connection to the controller at self.endpoint
328
329 """
330 return await Connection.connect(
331 self.endpoint,
332 None,
333 self.username,
334 self.password,
335 self.cacert,
336 self.macaroons,
337 self.loop,
338 )
339
340 async def _try_endpoint(self, endpoint, cacert):
341 success = False
342 result = None
343 new_endpoints = []
344
345 self.endpoint = endpoint
346 self.cacert = cacert
347 await self.open()
348 try:
349 result = await self.login()
350 if 'discharge-required-error' in result['response']:
351 log.info('Macaroon discharge required, disconnecting')
352 else:
353 # successful login!
354 log.info('Authenticated')
355 success = True
356 except JujuAPIError as e:
357 if e.error_code != 'redirection required':
358 raise
359 log.info('Controller requested redirect')
360 redirect_info = await self.redirect_info()
361 redir_cacert = redirect_info['ca-cert']
362 new_endpoints = [
363 ("{value}:{port}".format(**s), redir_cacert)
364 for servers in redirect_info['servers']
365 for s in servers if s["scope"] == 'public'
366 ]
367 finally:
368 if not success:
369 await self.close()
370 return success, result, new_endpoints
371
372 @classmethod
373 async def connect(
374 cls, endpoint, uuid, username, password, cacert=None,
375 macaroons=None, loop=None):
376 """Connect to the websocket.
377
378 If uuid is None, the connection will be to the controller. Otherwise it
379 will be to the model.
380
381 """
382 client = cls(endpoint, uuid, username, password, cacert, macaroons,
383 loop)
384 endpoints = [(endpoint, cacert)]
385 while endpoints:
386 _endpoint, _cacert = endpoints.pop(0)
387 success, result, new_endpoints = await client._try_endpoint(
388 _endpoint, _cacert)
389 if success:
390 break
391 endpoints.extend(new_endpoints)
392 else:
393 # ran out of endpoints without a successful login
394 raise Exception("Couldn't authenticate to {}".format(endpoint))
395
396 response = result['response']
397 client.info = response.copy()
398 client.build_facades(response.get('facades', {}))
399 client.loop.create_task(client.pinger())
400 client.monitor.pinger_stopped.clear()
401
402 return client
403
404 @classmethod
405 async def connect_current(cls, loop=None):
406 """Connect to the currently active model.
407
408 """
409 jujudata = JujuData()
410 controller_name = jujudata.current_controller()
411 model_name = jujudata.current_model()
412
413 return await cls.connect_model(
414 '{}:{}'.format(controller_name, model_name), loop)
415
416 @classmethod
417 async def connect_current_controller(cls, loop=None):
418 """Connect to the currently active controller.
419
420 """
421 jujudata = JujuData()
422 controller_name = jujudata.current_controller()
423 if not controller_name:
424 raise JujuConnectionError('No current controller')
425
426 return await cls.connect_controller(controller_name, loop)
427
428 @classmethod
429 async def connect_controller(cls, controller_name, loop=None):
430 """Connect to a controller by name.
431
432 """
433 jujudata = JujuData()
434 controller = jujudata.controllers()[controller_name]
435 endpoint = controller['api-endpoints'][0]
436 cacert = controller.get('ca-cert')
437 accounts = jujudata.accounts()[controller_name]
438 username = accounts['user']
439 password = accounts.get('password')
440 macaroons = get_macaroons() if not password else None
441
442 return await cls.connect(
443 endpoint, None, username, password, cacert, macaroons, loop)
444
445 @classmethod
446 async def connect_model(cls, model, loop=None):
447 """Connect to a model by name.
448
449 :param str model: [<controller>:]<model>
450
451 """
452 jujudata = JujuData()
453
454 if ':' in model:
455 # explicit controller given
456 controller_name, model_name = model.split(':')
457 else:
458 # use the current controller if one isn't explicitly given
459 controller_name = jujudata.current_controller()
460 model_name = model
461
462 accounts = jujudata.accounts()[controller_name]
463 username = accounts['user']
464 # model name must include a user prefix, so add it if it doesn't
465 if '/' not in model_name:
466 model_name = '{}/{}'.format(username, model_name)
467
468 controller = jujudata.controllers()[controller_name]
469 endpoint = controller['api-endpoints'][0]
470 cacert = controller.get('ca-cert')
471 password = accounts.get('password')
472 models = jujudata.models()[controller_name]
473 model_uuid = models['models'][model_name]['uuid']
474 macaroons = get_macaroons() if not password else None
475
476 return await cls.connect(
477 endpoint, model_uuid, username, password, cacert, macaroons, loop)
478
479 def build_facades(self, facades):
480 self.facades.clear()
481 # In order to work around an issue where the juju api is not
482 # returning a complete list of facades, we simply look up the
483 # juju version in a pregenerated map, and use that info to
484 # populate our list of facades.
485
486 # TODO: if a future version of juju fixes this bug, restore
487 # the following code for that version and higher:
488 # for facade in facades:
489 # self.facades[facade['name']] = facade['versions'][-1]
490 try:
491 self.facades = VERSION_MAP[self.info['server-version']]
492 except KeyError:
493 log.warning("Could not find a set of facades for {}. Using "
494 "the latest facade set instead".format(
495 self.info['server-version']))
496 self.facades = VERSION_MAP['latest']
497
498 async def login(self):
499 username = self.username
500 if username and not username.startswith('user-'):
501 username = 'user-{}'.format(username)
502
503 result = await self.rpc({
504 "type": "Admin",
505 "request": "Login",
506 "version": 3,
507 "params": {
508 "auth-tag": username,
509 "credentials": self.password,
510 "nonce": "".join(random.sample(string.printable, 12)),
511 "macaroons": self.macaroons
512 }})
513 return result
514
515 async def redirect_info(self):
516 try:
517 result = await self.rpc({
518 "type": "Admin",
519 "request": "RedirectInfo",
520 "version": 3,
521 })
522 except JujuAPIError as e:
523 if e.message == 'not redirected':
524 return None
525 raise
526 return result['response']
527
528
class JujuData:
    """Access to the local Juju client data directory.

    The directory comes from the ``JUJU_DATA`` environment variable and falls
    back to ``~/.local/share/juju`` when that variable is unset or empty.
    The stored path is absolute with ``~`` expanded.
    """

    def __init__(self):
        raw_path = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(raw_path))

    def current_controller(self):
        """Return the active controller name as reported by the juju CLI.

        Runs ``juju list-controllers --format yaml`` and returns its
        ``current-controller`` value, or '' when that key is absent.
        """
        argv = shlex.split('juju list-controllers --format yaml')
        listing = yaml.safe_load(subprocess.check_output(argv))
        return listing.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the current model of ``controller_name``.

        Defaults to the current controller when no name is given.

        :raises JujuError: if the controller has no current model recorded.
        """
        controller_name = controller_name or self.current_controller()
        models = self.models()[controller_name]
        if 'current-model' in models:
            return models['current-model']
        raise JujuError('No current model')

    def controllers(self):
        """Return the ``controllers`` section of controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the ``controllers`` section of models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the ``controllers`` section of accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Parse ``filename`` inside the data directory and return ``[key]``."""
        with io.open(os.path.join(self.path, filename), 'rt') as stream:
            document = yaml.safe_load(stream)
        return document[key]
561
562
def get_macaroons():
    """Decode and return macaroons from the default ``~/.go-cookies`` jar.

    Every cookie whose name starts with ``macaroon-`` and whose value is
    non-empty is base64-decoded and parsed as JSON.

    :return: list of decoded macaroon objects, or ``[]`` if the cookie file
        cannot be read or does not contain valid JSON.

    """
    cookie_file = os.path.expanduser('~/.go-cookies')
    try:
        with open(cookie_file, 'r') as f:
            cookies = json.load(f)
    except (OSError, ValueError):
        # Logger.warn is a deprecated alias of Logger.warning.
        log.warning("Couldn't load macaroons from %s", cookie_file)
        return []

    base64_macaroons = [
        c['Value'] for c in cookies
        if c['Name'].startswith('macaroon-') and c['Value']
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]