Remove VERSION_MAP and rely on facade list from controller (#118)
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from concurrent.futures import CancelledError
13 from http.client import HTTPSConnection
14
15 import asyncio
16 import yaml
17
18 from juju import tag, utils
19 from juju.client import client
20 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
21 from juju.utils import IdQueue
22
# Logger shared by everything in this module; it is registered under the
# fixed name "websocket" rather than the module's own dotted name.
log = logging.getLogger("websocket")
24
25
class Monitor:
    """
    Monitor helper class for our Connection class.

    Contains a reference to an instantiated Connection, along with three
    events that the Connection updates as it runs: ``close_called``,
    ``receiver_stopped`` and ``pinger_stopped``. Upon inspection of
    these objects, this class determines whether the connection is in
    an 'error', 'connected' or 'disconnected' state.

    Use this class to stay up to date on the health of a connection,
    and take appropriate action if the connection errors out due to
    network issues or other unexpected circumstances.

    """
    ERROR = 'error'
    CONNECTED = 'connected'
    DISCONNECTED = 'disconnected'
    # Kept for callers that compare against it; `status` no longer has a
    # code path that returns it.
    UNKNOWN = 'unknown'

    def __init__(self, connection):
        """
        :param connection: the Connection to watch; its ``loop`` and
            ``ws`` attributes are read by this class.
        """
        self.connection = connection
        # NOTE(review): the `loop=` keyword of asyncio.Event was removed in
        # Python 3.10 -- confirm which interpreter versions are supported.
        self.close_called = asyncio.Event(loop=self.connection.loop)
        self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
        self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
        # Nothing is running yet, so both background tasks start out
        # flagged as stopped.
        self.receiver_stopped.set()
        self.pinger_stopped.set()

    @property
    def status(self):
        """
        Determine the status of the connection and receiver, and return
        ERROR, CONNECTED, or DISCONNECTED as appropriate.

        For simplicity, we only consider ourselves to be connected
        while the websocket is open and the receiver has not stopped.

        """
        ws = self.connection.ws

        # DISCONNECTED: connection not yet open.
        if not ws:
            return self.DISCONNECTED

        if self.receiver_stopped.is_set() or not ws.open:
            # The receiver stopped or the socket is no longer open.
            # DISCONNECTED if we asked for that via connection.close,
            # ERROR if it happened behind our back. (Previously a
            # stopped receiver was always reported as DISCONNECTED,
            # which made the ERROR case for it unreachable.)
            if self.close_called.is_set():
                return self.DISCONNECTED
            return self.ERROR

        # CONNECTED: socket open and receiver running.
        return self.CONNECTED
94
95
class Connection:
    """
    Websocket connection to a Juju API endpoint.

    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    Note: Any connection method or constructor can accept an optional `loop`
    argument to override the default event loop from `asyncio.get_event_loop`.
    """
    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        """Store connection parameters; no network activity happens here.

        :param endpoint: ``host:port`` of the API server.
        :param uuid: model UUID, or a falsy value for a controller
            connection.
        :param username: user name; ignored when macaroons are given.
        :param password: password; ignored when macaroons are given.
        :param cacert: CA certificate data passed to the SSL context.
        :param macaroons: list of macaroons used instead of
            username/password when non-empty.
        :param loop: event loop; defaults to `asyncio.get_event_loop()`.
        """
        self.endpoint = endpoint
        self.uuid = uuid
        if macaroons:
            # Macaroon auth replaces username/password entirely.
            self.macaroons = macaroons
            self.username = ''
            self.password = ''
        else:
            self.macaroons = []
            self.username = username
            self.password = password
        self.cacert = cacert
        self.loop = loop or asyncio.get_event_loop()

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        # Maps facade name -> version to use when a request omits one.
        self.facades = {}
        self.messages = IdQueue(loop=self.loop)
        self.monitor = Monitor(connection=self)

    @property
    def is_open(self):
        """True when a websocket exists and reports itself open."""
        if self.ws:
            return self.ws.open
        return False

    def _get_ssl(self, cert=None):
        """Build the SSL context used for outgoing connections.

        :param cert: optional CA certificate data to trust.
        :return: an :class:`ssl.SSLContext` that verifies the server.
        """
        # We are the client authenticating a *server*, so the purpose must
        # be SERVER_AUTH. CLIENT_AUTH builds a server-side context that
        # does not verify the peer certificate.
        context = ssl.create_default_context(
            purpose=ssl.Purpose.SERVER_AUTH, cadata=cert)
        if cert:
            # With an explicit CA the chain is still verified against that
            # CA, but endpoints are commonly addressed by IP, so hostname
            # matching is relaxed in this case only (as it effectively was
            # before).
            context.check_hostname = False
        return context

    async def open(self):
        """Open the websocket and start the receiver task.

        :return: this Connection.
        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        kw = {
            'ssl': self._get_ssl(self.cacert),
            'loop': self.loop,
        }
        self.addr = url
        self.ws = await websockets.connect(url, **kw)
        self.loop.create_task(self.receiver())
        self.monitor.receiver_stopped.clear()
        log.info("Driver connected to juju %s", url)
        self.monitor.close_called.clear()
        return self

    async def close(self):
        """Signal the background tasks to stop, wait for them, then close
        the websocket. Does nothing if the connection is not open.
        """
        if not self.is_open:
            return
        self.monitor.close_called.set()
        await self.monitor.pinger_stopped.wait()
        await self.monitor.receiver_stopped.wait()
        await self.ws.close()

    async def recv(self, request_id):
        """Wait for the message queued under ``request_id``.

        :raises websockets.exceptions.ConnectionClosed: if the connection
            is not open when called.
        """
        if not self.is_open:
            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
        return await self.messages.get(request_id)

    async def receiver(self):
        """Background task: read JSON messages from the websocket and queue
        each one under its ``request-id`` until the connection closes or
        close() is requested.
        """
        try:
            while self.is_open:
                # presumably yields None when close_called fires first --
                # confirm against utils.run_with_interrupt
                result = await utils.run_with_interrupt(
                    self.ws.recv(),
                    self.monitor.close_called,
                    loop=self.loop)
                if self.monitor.close_called.is_set():
                    break
                if result is not None:
                    result = json.loads(result)
                    await self.messages.put(result['request-id'], result)
        except asyncio.CancelledError:
            # asyncio raises its own CancelledError; since Python 3.8 that
            # is no longer the concurrent.futures class.
            pass
        except Exception as e:
            await self.messages.put_all(e)
            if isinstance(e, websockets.ConnectionClosed):
                # ConnectionClosed is not really exceptional for us,
                # but it may be for any pending message listeners
                return
            log.exception("Error in receiver")
            raise
        finally:
            self.monitor.receiver_stopped.set()

    async def pinger(self):
        '''
        A Controller can time us out if we are silent for too long. This
        is especially true in JaaS, which has a fairly strict timeout.

        To prevent timing out, we send a ping every ten seconds.

        '''
        async def _do_ping():
            try:
                await pinger_facade.Ping()
                # NOTE(review): the `loop=` keyword of asyncio.sleep was
                # removed in Python 3.10 -- confirm supported versions.
                await asyncio.sleep(10, loop=self.loop)
            except asyncio.CancelledError:
                # See receiver(): catch the class asyncio actually raises.
                pass

        pinger_facade = client.PingerFacade.from_connection(self)
        try:
            while self.is_open:
                await utils.run_with_interrupt(
                    _do_ping(),
                    self.monitor.close_called,
                    loop=self.loop)
                if self.monitor.close_called.is_set():
                    break
        finally:
            self.monitor.pinger_stopped.set()

    async def rpc(self, msg, encoder=None):
        """Send one request and return its reply.

        ``msg`` is modified in place: it is given a fresh ``request-id``,
        an empty ``params`` dict if it has none, and, if ``version`` is
        missing, the version recorded in ``self.facades`` for its ``type``.

        :param dict msg: the request.
        :param encoder: optional JSON encoder class for `json.dumps`.
        :return: the decoded reply.
        :raises JujuAPIError: if the reply has a top-level ``error``.
        :raises JujuError: if the response, or any entry in its
            ``results`` list, carries an error message.
        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        if 'params' not in msg:
            msg['params'] = {}
        if "version" not in msg:
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        await self.ws.send(outgoing)
        result = await self.recv(msg['request-id'])

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        response = result['response']
        if 'results' in response:
            # Check for errors in a result list. `or {}` also covers an
            # explicit null error value.
            errors = [
                res['error']['message']
                for res in response['results']
                if (res.get('error') or {}).get('message')
            ]
            if errors:
                raise JujuError(errors)

        elif (response.get('error') or {}).get('message'):
            raise JujuError(response['error']['message'])

        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        :return: Dictionary of headers; empty when no username is set.

        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        """
        endpoint = self.endpoint
        host, remainder = endpoint.split(':', 1)
        port = remainder
        if '/' in remainder:
            # Drop anything after the port.
            port, _ = remainder.split('/', 1)

        conn = HTTPSConnection(
            host, int(port),
            context=self._get_ssl(self.cacert),
        )

        path = (
            "/model/{}".format(self.uuid)
            if self.uuid else ""
        )
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.

        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint

        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    async def _try_endpoint(self, endpoint, cacert):
        """Open ``endpoint`` and attempt to log in.

        :return: ``(success, result, new_endpoints)`` where ``result`` is
            the login reply (or None) and ``new_endpoints`` is a list of
            ``(endpoint, cacert)`` pairs to try when the controller
            answered with a redirect.
        :raises JujuAPIError: for any login error other than a usable
            redirect.
        """
        success = False
        result = None
        new_endpoints = []

        self.endpoint = endpoint
        self.cacert = cacert
        await self.open()
        try:
            result = await self.login()
            if 'discharge-required-error' in result['response']:
                log.info('Macaroon discharge required, disconnecting')
            else:
                # successful login!
                log.info('Authenticated')
                success = True
        except JujuAPIError as e:
            if e.error_code != 'redirection required':
                raise
            log.info('Controller requested redirect')
            redirect_info = await self.redirect_info()
            if redirect_info is None:
                # redirect_info() returns None for 'not redirected'; with
                # no targets to follow, surface the login error instead of
                # failing on a None subscript.
                raise
            redir_cacert = redirect_info['ca-cert']
            new_endpoints = [
                ("{value}:{port}".format(**s), redir_cacert)
                for servers in redirect_info['servers']
                for s in servers if s["scope"] == 'public'
            ]
        finally:
            if not success:
                await self.close()
        return success, result, new_endpoints

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        """Connect to the websocket.

        If uuid is None, the connection will be to the controller. Otherwise it
        will be to the model.

        Endpoints returned by a redirect are tried in turn; the first
        successful login wins, after which the facade table is built and
        the pinger task is started.

        :raises Exception: if no endpoint accepted the login.
        """
        conn = cls(endpoint, uuid, username, password, cacert, macaroons,
                   loop)
        endpoints = [(endpoint, cacert)]
        while endpoints:
            _endpoint, _cacert = endpoints.pop(0)
            success, result, new_endpoints = await conn._try_endpoint(
                _endpoint, _cacert)
            if success:
                break
            endpoints.extend(new_endpoints)
        else:
            # ran out of endpoints without a successful login
            raise Exception("Couldn't authenticate to {}".format(endpoint))

        response = result['response']
        conn.info = response.copy()
        conn.build_facades(response.get('facades', {}))
        conn.loop.create_task(conn.pinger())
        conn.monitor.pinger_stopped.clear()

        return conn

    @classmethod
    async def connect_current(cls, loop=None):
        """Connect to the currently active model.

        :raises JujuConnectionError: if there is no current controller.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            # Same failure mode as connect_current_controller().
            raise JujuConnectionError('No current controller')
        # Pass the name along so current_model() does not have to look the
        # controller up a second time.
        model_name = jujudata.current_model(controller_name)

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name), loop)

    @classmethod
    async def connect_current_controller(cls, loop=None):
        """Connect to the currently active controller.

        :raises JujuConnectionError: if there is no current controller.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name, loop)

    @classmethod
    async def connect_controller(cls, controller_name, loop=None):
        """Connect to a controller by name.

        Uses the first listed API endpoint; macaroons are loaded only when
        the stored account has no password.
        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons, loop)

    @classmethod
    async def connect_model(cls, model, loop=None):
        """Connect to a model by name.

        :param str model: [<controller>:]<model>
        :raises JujuConnectionError: if no controller is given and there
            is no current controller.
        """
        jujudata = JujuData()

        if ':' in model:
            # explicit controller given; split only on the first colon
            controller_name, model_name = model.split(':', 1)
        else:
            # use the current controller if one isn't explicitly given
            controller_name = jujudata.current_controller()
            if not controller_name:
                raise JujuConnectionError('No current controller')
            model_name = model

        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        # model name must include a user prefix, so add it if it doesn't
        if '/' not in model_name:
            model_name = '{}/{}'.format(username, model_name)

        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons, loop)

    def build_facades(self, facades):
        """Rebuild ``self.facades`` from a list of facade descriptions.

        :param facades: iterable of dicts with ``name`` and ``versions``;
            the last listed version of each facade is recorded.
        """
        self.facades.clear()
        for facade in facades:
            self.facades[facade['name']] = facade['versions'][-1]

    async def login(self):
        """Send an Admin v3 ``Login`` request and return the reply.

        A non-empty username is sent with a ``user-`` prefix, added when
        not already present.
        """
        username = self.username
        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": self.password,
                "nonce": "".join(random.sample(string.printable, 12)),
                "macaroons": self.macaroons
            }})
        return result

    async def redirect_info(self):
        """Send an Admin v3 ``RedirectInfo`` request.

        :return: the ``response`` part of the reply, or None when the
            API answers with the error message 'not redirected'.
        """
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
512
513
class JujuData:
    """Read-only view of the local Juju client data directory.

    The directory is taken from the ``JUJU_DATA`` environment variable,
    falling back to ``~/.local/share/juju``.
    """

    def __init__(self):
        data_dir = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        # Expand "~" and store an absolute path.
        self.path = os.path.abspath(os.path.expanduser(data_dir))

    def current_controller(self):
        """Return the ``current-controller`` value reported by
        ``juju list-controllers --format yaml``, or '' when it is absent.
        """
        argv = shlex.split('juju list-controllers --format yaml')
        listing = yaml.safe_load(subprocess.check_output(argv))
        return listing.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the ``current-model`` entry for a controller.

        :param controller_name: controller to look at; the current
            controller is used when this is falsy.
        :raises JujuError: if that controller has no current model.
        """
        name = controller_name or self.current_controller()
        controller_models = self.models()[name]
        if 'current-model' in controller_models:
            return controller_models['current-model']
        raise JujuError('No current model')

    def controllers(self):
        """Return the ``controllers`` mapping from controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the ``controllers`` mapping from models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the ``controllers`` mapping from accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Parse ``filename`` inside the data directory and return the
        value stored under the top-level ``key``.
        """
        with io.open(os.path.join(self.path, filename), 'rt') as stream:
            document = yaml.safe_load(stream)
        return document[key]
546
547
def get_macaroons():
    """Decode and return macaroons from default ~/.go-cookies

    Every cookie whose ``Name`` starts with ``macaroon-`` and whose
    ``Value`` is non-empty is base64-decoded and parsed as JSON.

    :return: list of decoded macaroons; an empty list when the cookie
        file cannot be read or is not valid JSON.
    """
    # Resolved before the try block so the name is always bound when the
    # warning below formats it.
    cookie_file = os.path.expanduser('~/.go-cookies')
    try:
        with open(cookie_file, 'r') as f:
            cookies = json.load(f)
    except (OSError, ValueError):
        # Missing or malformed cookies are not fatal: report it and
        # continue without macaroons. (Logger.warn is deprecated.)
        log.warning("Couldn't load macaroons from %s", cookie_file)
        return []

    base64_macaroons = [
        c['Value'] for c in cookies
        if c['Name'].startswith('macaroon-') and c['Value']
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]