Configurable and larger max message size (#146)
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from concurrent.futures import CancelledError
13 from http.client import HTTPSConnection
14 from pathlib import Path
15
16 import asyncio
17 import yaml
18
19 from juju import tag, utils
20 from juju.client import client
21 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
22 from juju.utils import IdQueue
23
# Module logger. It is named "websocket" rather than after the module, so
# logging configuration must target that name to affect this output.
log = logging.getLogger(name="websocket")
25
26
class Monitor:
    """
    Monitor helper class for our Connection class.

    Holds a reference to an instantiated Connection, plus three asyncio
    events that the Connection sets and clears as its receiver and pinger
    tasks start/stop and as close() is requested. Upon inspection of
    these objects, this class determines whether the connection is in
    an 'error', 'connected' or 'disconnected' state.

    Use this class to stay up to date on the health of a connection,
    and take appropriate action if the connection errors out due to
    network issues or other unexpected circumstances.

    """
    ERROR = 'error'
    CONNECTED = 'connected'
    DISCONNECTED = 'disconnected'
    # Kept for backwards compatibility; status no longer returns it because
    # the checks below cover every combination of states.
    UNKNOWN = 'unknown'

    def __init__(self, connection):
        """
        :param connection: the Connection to watch. Its ``loop`` is used
            to create the events, and its ``ws`` attribute is inspected by
            :attr:`status`.
        """
        self.connection = connection
        # Set by Connection.close() before tearing down; cleared by open().
        self.close_called = asyncio.Event(loop=self.connection.loop)
        # Set whenever the receiver / pinger task is not running. Both
        # start out set, because nothing runs until the connection opens.
        self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
        self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
        self.receiver_stopped.set()
        self.pinger_stopped.set()

    @property
    def status(self):
        """
        Determine the status of the connection and receiver, and return
        ERROR, CONNECTED, or DISCONNECTED as appropriate.

        For simplicity, we only consider ourselves to be connected
        after the Connection class has setup a receiver task. This
        only happens after the websocket is open, and the connection
        isn't usable until that receiver has been started.

        * no websocket yet                       -> DISCONNECTED
        * receiver stopped or socket closed,
          and close() was requested              -> DISCONNECTED
        * receiver stopped or socket closed,
          but close() was NOT requested          -> ERROR
        * otherwise (socket open, receiver up)   -> CONNECTED

        """
        ws = self.connection.ws

        # DISCONNECTED: connection not yet open. This must be judged from
        # the websocket alone: receiver_stopped is also set before the
        # first open(), so treating it as "not started" would hide the
        # ERROR case below.
        if not ws:
            return self.DISCONNECTED

        if self.receiver_stopped.is_set() or not ws.open:
            # DISCONNECTED: we asked for this via connection.close().
            if self.close_called.is_set():
                return self.DISCONNECTED
            # ERROR: connection closed (or receiver died) without us
            # calling connection.close.
            return self.ERROR

        # CONNECTED: socket open and receiver running.
        return self.CONNECTED
95
96
class Connection:
    """
    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    Note: Any connection method or constructor can accept an optional `loop`
    argument to override the default event loop from `asyncio.get_event_loop`.

    They also accept an optional `max_frame_size`, the largest incoming
    websocket frame to accept. Leave it out (or pass `DEFAULT_FRAME_SIZE`)
    to use `MAX_FRAME_SIZE`; any other value, including None, is handed to
    websockets unchanged as its `max_size`.
    """

    # Sentinel meaning "use MAX_FRAME_SIZE". It is deliberately distinct
    # from None so an explicit None can still be passed to websockets.
    DEFAULT_FRAME_SIZE = 'default_frame_size'
    # Maximum size for a single frame. Defaults to 4MB.
    MAX_FRAME_SIZE = 2**22

    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
        """
        :param str endpoint: "host:port" of the API server.
        :param str uuid: model UUID, or None to talk to the controller.
        :param str username: user to log in as (ignored if macaroons given).
        :param str password: password (ignored if macaroons given).
        :param str cacert: PEM CA certificate used to verify the server.
        :param list macaroons: macaroons to authenticate with instead of
            username/password.
        :param loop: asyncio event loop; defaults to get_event_loop().
        :param max_frame_size: see the class docstring.
        """
        self.endpoint = endpoint
        self.uuid = uuid
        if macaroons:
            # Macaroon auth replaces user/password credentials entirely.
            self.macaroons = macaroons
            self.username = ''
            self.password = ''
        else:
            self.macaroons = []
            self.username = username
            self.password = password
        self.cacert = cacert
        self.loop = loop or asyncio.get_event_loop()

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        self.facades = {}
        self.messages = IdQueue(loop=self.loop)
        self.monitor = Monitor(connection=self)
        if max_frame_size is self.DEFAULT_FRAME_SIZE:
            max_frame_size = self.MAX_FRAME_SIZE
        self.max_frame_size = max_frame_size

    @property
    def is_open(self):
        """True if a websocket exists and reports itself open."""
        if self.ws:
            return self.ws.open
        return False

    def _get_ssl(self, cert=None):
        """Return an SSL context for connecting *to* the API server.

        The server certificate is verified: against ``cert`` (PEM CA data)
        when given, otherwise against the system's default CA store.
        """
        # SERVER_AUTH is the purpose for client-side sockets; CLIENT_AUTH
        # builds a server-side context that does not verify the peer, so
        # the supplied CA would never actually be enforced.
        context = ssl.create_default_context(
            purpose=ssl.Purpose.SERVER_AUTH, cadata=cert)
        if cert:
            # The certificate is still checked against the controller's own
            # CA, but we dial whatever address the endpoint lists (which may
            # be a bare IP), so name matching is not reliable here.
            context.check_hostname = False
        return context

    async def open(self):
        """Open the websocket and start the receiver task.

        :return: self, to allow ``conn = await Connection(...).open()``.
        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        kw = dict()
        kw['ssl'] = self._get_ssl(self.cacert)
        kw['loop'] = self.loop
        kw['max_size'] = self.max_frame_size
        self.addr = url
        self.ws = await websockets.connect(url, **kw)
        self.loop.create_task(self.receiver())
        self.monitor.receiver_stopped.clear()
        log.info("Driver connected to juju %s", url)
        self.monitor.close_called.clear()
        return self

    async def close(self):
        """Stop the pinger and receiver, then close the websocket.

        Does nothing if the websocket is not open.
        """
        if not self.is_open:
            return
        # Signal the background tasks first and wait for them to finish,
        # so neither touches the socket while it is being closed.
        self.monitor.close_called.set()
        await self.monitor.pinger_stopped.wait()
        await self.monitor.receiver_stopped.wait()
        await self.ws.close()

    async def recv(self, request_id):
        """Wait for the response to ``request_id``.

        :raises websockets.exceptions.ConnectionClosed: if not open.
        """
        if not self.is_open:
            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
        return await self.messages.get(request_id)

    async def receiver(self):
        """Read responses off the websocket and queue them by request id.

        Runs until the socket closes or close() is requested. An exception
        is passed to ``self.messages.put_all`` (presumably waking every
        pending recv()); ConnectionClosed then ends the task quietly, while
        anything else is logged and re-raised. ``monitor.receiver_stopped``
        is set on every exit path.
        """
        try:
            while self.is_open:
                result = await utils.run_with_interrupt(
                    self.ws.recv(),
                    self.monitor.close_called,
                    loop=self.loop)
                if self.monitor.close_called.is_set():
                    break
                if result is not None:
                    result = json.loads(result)
                    await self.messages.put(result['request-id'], result)
        except CancelledError:
            pass
        except Exception as e:
            await self.messages.put_all(e)
            if isinstance(e, websockets.ConnectionClosed):
                # ConnectionClosed is not really exceptional for us,
                # but it may be for any pending message listeners
                return
            log.exception("Error in receiver")
            raise
        finally:
            self.monitor.receiver_stopped.set()

    async def pinger(self):
        '''
        A Controller can time us out if we are silent for too long. This
        is especially true in JaaS, which has a fairly strict timeout.

        To prevent timing out, we send a ping every ten seconds.
        ``monitor.pinger_stopped`` is set when this task exits.

        '''
        async def _do_ping():
            try:
                await pinger_facade.Ping()
                await asyncio.sleep(10, loop=self.loop)
            except CancelledError:
                pass

        pinger_facade = client.PingerFacade.from_connection(self)
        try:
            while self.is_open:
                await utils.run_with_interrupt(
                    _do_ping(),
                    self.monitor.close_called,
                    loop=self.loop)
                if self.monitor.close_called.is_set():
                    break
        finally:
            self.monitor.pinger_stopped.set()

    async def rpc(self, msg, encoder=None):
        """Send ``msg`` as an API request and wait for its response.

        Fills in ``request-id``, an empty ``params`` dict and the facade
        ``version`` recorded by build_facades() when they are missing.

        :param dict msg: request with at least "type" and "request" keys.
        :param encoder: optional json.JSONEncoder subclass for serialising.
        :return: the decoded response message (falsy results as-is).
        :raises JujuAPIError: if the response has a top-level "error".
        :raises JujuError: if the response, or any entry in its "results"
            list, carries an error message.
        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        if 'params' not in msg:
            msg['params'] = {}
        if "version" not in msg:
            # Only looked up when no version is pinned, so requests made
            # before build_facades() (such as login) still work.
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        await self.ws.send(outgoing)
        result = await self.recv(msg['request-id'])

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        if 'results' in result['response']:
            # Check for errors in a result list.
            errors = []
            for res in result['response']['results']:
                if res.get('error', {}).get('message'):
                    errors.append(res['error']['message'])
            if errors:
                raise JujuError(errors)

        elif result['response'].get('error', {}).get('message'):
            raise JujuError(result['response']['error']['message'])

        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        Uses HTTP Basic auth built from the user tag and password; returns
        an empty dict when there is no username (e.g. macaroon auth).

        :return: Dictionary of headers

        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        :raises ValueError: if the endpoint has no ":port" part.
        """
        # Drop any trailing path, then split "host:port" on the LAST colon
        # so bracketed IPv6 hosts such as "[::1]:17070" stay intact.
        hostport = self.endpoint.split('/', 1)[0]
        host, sep, port = hostport.rpartition(':')
        if not sep or not host:
            raise ValueError(
                "Invalid endpoint (expected host:port): {}".format(
                    self.endpoint))
        if host.startswith('[') and host.endswith(']'):
            host = host[1:-1]

        conn = HTTPSConnection(
            host, int(port),
            context=self._get_ssl(self.cacert),
        )

        path = (
            "/model/{}".format(self.uuid)
            if self.uuid else ""
        )
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.

        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
            self.max_frame_size,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint

        The new connection reuses this one's credentials, loop and frame
        size limit.

        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
            self.max_frame_size,
        )

    async def _try_endpoint(self, endpoint, cacert):
        """Open ``endpoint`` and attempt to log in.

        :return: ``(success, login_result, new_endpoints)``. The third item
            lists ``(endpoint, cacert)`` pairs to try next when the
            controller asked us to redirect. The connection is closed
            again unless login succeeded.
        """
        success = False
        result = None
        new_endpoints = []

        self.endpoint = endpoint
        self.cacert = cacert
        await self.open()
        try:
            result = await self.login()
            if 'discharge-required-error' in result['response']:
                log.info('Macaroon discharge required, disconnecting')
            else:
                # successful login!
                log.info('Authenticated')
                success = True
        except JujuAPIError as e:
            if e.error_code != 'redirection required':
                raise
            log.info('Controller requested redirect')
            redirect_info = await self.redirect_info()
            redir_cacert = redirect_info['ca-cert']
            new_endpoints = [
                ("{value}:{port}".format(**s), redir_cacert)
                for servers in redirect_info['servers']
                for s in servers if s["scope"] == 'public'
            ]
        finally:
            if not success:
                await self.close()
        return success, result, new_endpoints

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
        """Connect to the websocket.

        If uuid is None, the connection will be to the controller. Otherwise it
        will be to the model. Redirects requested by the controller are
        followed. On success the facade versions are recorded and the
        pinger task is started.

        :raises JujuConnectionError: if no endpoint accepted the login.

        """
        client = cls(endpoint, uuid, username, password, cacert, macaroons,
                     loop, max_frame_size)
        endpoints = [(endpoint, cacert)]
        while endpoints:
            _endpoint, _cacert = endpoints.pop(0)
            success, result, new_endpoints = await client._try_endpoint(
                _endpoint, _cacert)
            if success:
                break
            endpoints.extend(new_endpoints)
        else:
            # ran out of endpoints without a successful login
            raise JujuConnectionError(
                "Couldn't authenticate to {}".format(endpoint))

        response = result['response']
        client.info = response.copy()
        client.build_facades(response.get('facades', {}))
        client.loop.create_task(client.pinger())
        client.monitor.pinger_stopped.clear()

        return client

    @classmethod
    async def connect_current(cls, loop=None,
                              max_frame_size=DEFAULT_FRAME_SIZE):
        """Connect to the currently active model.

        :raises JujuConnectionError: if there is no current controller.

        """
        jujudata = JujuData()

        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        model_name = jujudata.current_model()

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)

    @classmethod
    async def connect_current_controller(cls, loop=None,
                                         max_frame_size=DEFAULT_FRAME_SIZE):
        """Connect to the currently active controller.

        :raises JujuConnectionError: if there is no current controller.

        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name, loop,
                                            max_frame_size)

    @classmethod
    async def connect_controller(cls, controller_name, loop=None,
                                 max_frame_size=DEFAULT_FRAME_SIZE):
        """Connect to a controller by name.

        Endpoint, CA and credentials come from the local juju data files;
        macaroons are used when no password is stored.

        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons(controller_name) if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons, loop,
            max_frame_size)

    @classmethod
    async def connect_model(cls, model, loop=None,
                            max_frame_size=DEFAULT_FRAME_SIZE):
        """Connect to a model by name.

        :param str model: [<controller>:]<model>

        """
        jujudata = JujuData()

        if ':' in model:
            # explicit controller given
            controller_name, model_name = model.split(':')
        else:
            # use the current controller if one isn't explicitly given
            controller_name = jujudata.current_controller()
            model_name = model

        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        # model name must include a user prefix, so add it if it doesn't
        if '/' not in model_name:
            model_name = '{}/{}'.format(username, model_name)

        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons(controller_name) if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons, loop,
            max_frame_size)

    def build_facades(self, facades):
        """Record the newest advertised version of each facade."""
        self.facades.clear()
        for facade in facades:
            self.facades[facade['name']] = facade['versions'][-1]

    async def login(self):
        """Send an Admin.Login request with this connection's credentials."""
        username = self.username
        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": self.password,
                "nonce": "".join(random.sample(string.printable, 12)),
                "macaroons": self.macaroons
            }})
        return result

    async def redirect_info(self):
        """Return the Admin.RedirectInfo response, or None if not redirected.
        """
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
530 return result['response']
531
532
class JujuData:
    """Accessor for the local Juju client data directory.

    The directory comes from the ``JUJU_DATA`` environment variable when it
    is set and non-empty, and is ``~/.local/share/juju`` otherwise. The
    resolved absolute path is stored in ``self.path``.
    """

    def __init__(self):
        data_dir = os.environ.get('JUJU_DATA')
        if not data_dir:
            data_dir = '~/.local/share/juju'
        expanded = os.path.expanduser(data_dir)
        self.path = os.path.abspath(expanded)

    def current_controller(self):
        """Return the current controller name reported by the juju CLI,
        or '' when it reports none."""
        argv = shlex.split('juju list-controllers --format yaml')
        listing = yaml.safe_load(subprocess.check_output(argv))
        return listing.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the current model of ``controller_name`` (default: the
        current controller).

        :raises JujuError: if that controller has no current model.
        """
        name = controller_name if controller_name else self.current_controller()
        controller_models = self.models()[name]
        if 'current-model' in controller_models:
            return controller_models['current-model']
        raise JujuError('No current model')

    def controllers(self):
        """Return the 'controllers' mapping from controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the 'controllers' mapping from models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the 'controllers' mapping from accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Parse ``filename`` inside the data directory; return its ``key``."""
        with io.open(os.path.join(self.path, filename), 'rt') as handle:
            document = yaml.safe_load(handle)
        return document[key]
565
566
def get_macaroons(controller_name=None):
    """Decode and return the macaroons stored in a Juju cookie jar.

    Candidate cookie files, tried in order (the first that exists wins):

    1. ``<data dir>/cookies/<controller_name>.json`` -- only when
       ``controller_name`` is given. ``<data dir>`` is ``$JUJU_DATA`` if
       set and non-empty, else ``~/.local/share/juju``.
    2. ``~/.go-cookies``

    Every cookie whose name starts with ``macaroon-`` and whose value is
    non-empty is base64-decoded and parsed as JSON.

    :param str controller_name: controller whose cookie jar to prefer.
    :return: list of decoded macaroon values; an empty list (with a logged
        warning) if no cookie file exists or the first existing one cannot
        be read or parsed.
    """
    cookie_files = []
    if controller_name:
        # Resolve the data directory the same way JujuData does, so a
        # relocated juju data directory also finds its cookie jar.
        data_dir = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        cookie_files.append(os.path.join(
            data_dir, 'cookies', '{}.json'.format(controller_name)))
    cookie_files.append('~/.go-cookies')

    for cookie_file in cookie_files:
        cookie_path = Path(cookie_file).expanduser()
        if cookie_path.exists():
            try:
                # JSON text is UTF-8; don't depend on the locale encoding.
                cookies = json.loads(cookie_path.read_text(encoding='utf-8'))
                break
            except (OSError, ValueError):
                log.warning("Couldn't load macaroons from %s", cookie_path)
                return []
    else:
        log.warning("Couldn't load macaroons from %s",
                    ' or '.join(cookie_files))
        return []

    base64_macaroons = [
        c['Value'] for c in cookies
        if (c.get('Name') or '').startswith('macaroon-') and c.get('Value')
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]