Refactored login code to better handle redirects (#116)
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from http.client import HTTPSConnection
13
14 import asyncio
15 import yaml
16
17 from juju import tag
18 from juju.client import client
19 from juju.client.version_map import VERSION_MAP
20 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
21 from juju.utils import IdQueue
22
# Module-level logger used by Connection and get_macaroons below. Note that
# it is registered under the name "websocket", not under this module's name.
log = logging.getLogger(name="websocket")
24
25
class Monitor:
    """
    Monitor helper class for our Connection class.

    Contains a reference to an instantiated Connection, along with a
    reference to the Connection.receiver task. Upon inspection of
    these objects, this class determines whether the connection is in
    an 'error', 'connected' or 'disconnected' state.

    Use this class to stay up to date on the health of a connection,
    and take appropriate action if the connection errors out due to
    network issues or other unexpected circumstances.

    """
    ERROR = 'error'
    CONNECTED = 'connected'
    DISCONNECTED = 'disconnected'
    UNKNOWN = 'unknown'

    def __init__(self, connection):
        # The Connection being watched; its ``ws`` and ``close_called``
        # attributes drive the status computation.
        self.connection = connection
        # Task running Connection.receiver(); assigned by Connection.open().
        self.receiver = None
        # Task running Connection.pinger(); assigned by Connection.connect().
        self.pinger = None

    @property
    def status(self):
        """
        Determine the status of the connection and receiver, and return
        ERROR, CONNECTED or DISCONNECTED as appropriate (UNKNOWN only if
        none of the checks below match, which should never happen).

        For simplicity, we only consider ourselves to be connected
        after the Connection class has setup a receiver task. This
        only happens after the websocket is open, and the connection
        isn't usable until that receiver has been started.

        """
        connection = self.connection

        # DISCONNECTED: connection not yet open
        if not connection.ws:
            return self.DISCONNECTED
        if not self.receiver:
            return self.DISCONNECTED

        # ERROR: Connection closed (or errored), but we didn't call
        # connection.close
        if not connection.close_called:
            if self.receiver_exceptions():
                return self.ERROR
            if self.receiver.cancelled():
                # Without a running receiver no responses are ever
                # delivered, so the connection is unusable.
                return self.ERROR
            if not connection.ws.open:
                # The check for self.receiver existing above guards
                # against the case where we're not open because we
                # simply haven't setup the connection yet.
                return self.ERROR

        # DISCONNECTED: cleanly disconnected.
        if connection.close_called and not connection.ws.open:
            return self.DISCONNECTED

        # CONNECTED: everything is fine!
        if connection.ws.open:
            return self.CONNECTED

        # UNKNOWN: We should never hit this state -- if we do,
        # something went wrong with the logic above, and we do not
        # know what state the connection is in.
        return self.UNKNOWN

    def receiver_exceptions(self):
        """
        Return the exception raised by the receiver task, if any.

        Returns None when there is no receiver, when it is still
        running, when it finished normally, or when it was cancelled.
        The cancelled case is checked explicitly because
        ``Future.exception()`` raises CancelledError for a cancelled
        task instead of returning a value.

        """
        if not self.receiver:
            return None
        if not self.receiver.done():
            return None
        if self.receiver.cancelled():
            return None
        return self.receiver.exception()
102
103
class Connection:
    """
    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    Note: Any connection method or constructor can accept an optional `loop`
    argument to override the default event loop from `asyncio.get_event_loop`.
    """
    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        self.endpoint = endpoint
        self.uuid = uuid
        if macaroons:
            # Macaroon auth replaces username/password auth entirely.
            self.macaroons = macaroons
            self.username = ''
            self.password = ''
        else:
            self.macaroons = []
            self.username = username
            self.password = password
        self.cacert = cacert
        self.loop = loop or asyncio.get_event_loop()

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        self.facades = {}
        self.messages = IdQueue(loop=self.loop)
        self.close_called = False
        self.monitor = Monitor(connection=self)

    @property
    def is_open(self):
        """True when a websocket exists and reports itself open."""
        if self.ws:
            return self.ws.open
        return False

    def _get_ssl(self, cert=None):
        """Return an SSL context for connecting *to* the Juju API server.

        This is a client connection, so the context is built for
        ``ssl.Purpose.SERVER_AUTH``, which verifies the server's
        certificate. A ``Purpose.CLIENT_AUTH`` context is meant for server
        sockets: it does not verify the peer, and on Python 3.10+ it is a
        TLS_SERVER context that cannot be used for client connections.

        :param str cert: optional PEM-encoded CA certificate. When given,
            only this CA is trusted and hostname checking is turned off.
            NOTE(review): this assumes controller certificates are not
            issued for the address we dial -- confirm against Juju.
        :return: ``ssl.SSLContext``
        """
        context = ssl.create_default_context(
            purpose=ssl.Purpose.SERVER_AUTH, cadata=cert)
        if cert:
            context.check_hostname = False
        return context

    async def open(self):
        """Open the websocket and start the receiver task.

        Connects to ``wss://<endpoint>/model/<uuid>/api`` when ``self.uuid``
        is set, otherwise to ``wss://<endpoint>/api`` (the controller).

        :return: self
        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        kw = dict()
        kw['ssl'] = self._get_ssl(self.cacert)
        # NOTE(review): newer websockets/asyncio releases dropped the
        # ``loop`` argument; this code targets the older API.
        kw['loop'] = self.loop
        self.addr = url
        self.ws = await websockets.connect(url, **kw)
        self.monitor.receiver = self.loop.create_task(self.receiver())
        log.info("Driver connected to juju %s", url)
        return self

    async def close(self):
        """Close the websocket and cancel the pinger/receiver tasks.

        Does nothing when the websocket is not open. Sets
        ``close_called`` first so the Monitor reports a clean disconnect
        rather than an error.
        """
        if not self.is_open:
            return
        self.close_called = True
        if self.monitor.pinger:
            # might be closing due to login failure,
            # in which case we won't have a pinger yet
            self.monitor.pinger.cancel()
        self.monitor.receiver.cancel()
        await self.ws.close()

    async def recv(self, request_id):
        """Wait for and return the message queued for ``request_id``.

        :raises websockets.exceptions.ConnectionClosed: if the websocket is
            not open.
        """
        if not self.is_open:
            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
        return await self.messages.get(request_id)

    async def receiver(self):
        """Read messages from the websocket and queue them by request id.

        Loops while the websocket is open. Any exception is also handed to
        every pending listener through ``messages.put_all``; a
        ConnectionClosed ends the loop quietly, anything else is re-raised.
        """
        while self.is_open:
            try:
                result = await self.ws.recv()
                if result is not None:
                    result = json.loads(result)
                    await self.messages.put(result['request-id'], result)
            except Exception as e:
                await self.messages.put_all(e)
                if isinstance(e, websockets.ConnectionClosed):
                    # ConnectionClosed is not really exceptional for us,
                    # but it may be for any pending message listeners
                    return
                raise

    async def pinger(self):
        '''
        A Controller can time us out if we are silent for too long. This
        is especially true in JaaS, which has a fairly strict timeout.

        To prevent timing out, we send a ping every ten seconds.

        '''
        pinger_facade = client.PingerFacade.from_connection(self)
        while self.is_open:
            await pinger_facade.Ping()
            # NOTE(review): asyncio.sleep() no longer accepts ``loop`` on
            # Python 3.10+; kept for consistency with the rest of the file.
            await asyncio.sleep(10, loop=self.loop)

    async def rpc(self, msg, encoder=None):
        """Send ``msg`` as a JSON RPC request and return the response.

        Fills in ``request-id``, an empty ``params`` dict and the facade
        ``version`` (looked up from ``self.facades`` by ``msg['type']``)
        when they are missing. ``msg`` is mutated in place.

        :param dict msg: request with at least ``type`` and ``request``.
        :param encoder: optional ``json.JSONEncoder`` subclass.
        :return: the decoded response message.
        :raises JujuAPIError: if the response carries a top-level ``error``.
        :raises JujuError: if ``response`` (or any entry of
            ``response['results']``) carries an error message.
        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        msg.setdefault('params', {})
        # Not setdefault: the facade lookup must only run when needed.
        if 'version' not in msg:
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        await self.ws.send(outgoing)
        result = await self.recv(msg['request-id'])

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        if 'results' in result['response']:
            # Check for errors in a result list.
            errors = []
            for res in result['response']['results']:
                if res.get('error', {}).get('message'):
                    errors.append(res['error']['message'])
            if errors:
                raise JujuError(errors)

        elif result['response'].get('error', {}).get('message'):
            raise JujuError(result['response']['error']['message'])

        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        Builds an HTTP Basic ``Authorization`` header from the user tag and
        password. Returns an empty dict when there is no username (as is
        the case with macaroon auth, where __init__ blanks it).

        :return: Dictionary of headers

        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        The endpoint is expected in ``host:port`` form, optionally followed
        by ``/path`` (which is ignored here).

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        """
        endpoint = self.endpoint
        host, remainder = endpoint.split(':', 1)
        port = remainder
        if '/' in remainder:
            port, _ = remainder.split('/', 1)

        conn = HTTPSConnection(
            host, int(port),
            context=self._get_ssl(self.cacert),
        )

        path = (
            "/model/{}".format(self.uuid)
            if self.uuid else ""
        )
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.

        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint

        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    async def _try_endpoint(self, endpoint, cacert):
        """Open a connection to ``endpoint`` and attempt to log in.

        :return: ``(success, result, new_endpoints)``. ``result`` is the
            Login RPC result (None if login raised a redirect), and
            ``new_endpoints`` lists ``(endpoint, cacert)`` pairs for the
            public servers the controller redirected us to. The connection
            is closed again unless the login succeeded.
        :raises JujuAPIError: for any API error other than a redirect.
        """
        success = False
        result = None
        new_endpoints = []

        self.endpoint = endpoint
        self.cacert = cacert
        await self.open()
        try:
            result = await self.login()
            if 'discharge-required-error' in result['response']:
                log.info('Macaroon discharge required, disconnecting')
            else:
                # successful login!
                log.info('Authenticated')
                success = True
        except JujuAPIError as e:
            if e.error_code != 'redirection required':
                raise
            log.info('Controller requested redirect')
            redirect_info = await self.redirect_info()
            redir_cacert = redirect_info['ca-cert']
            new_endpoints = [
                ("{value}:{port}".format(**s), redir_cacert)
                for servers in redirect_info['servers']
                for s in servers if s["scope"] == 'public'
            ]
        finally:
            if not success:
                await self.close()
        return success, result, new_endpoints

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        """Connect to the websocket.

        If uuid is None, the connection will be to the controller. Otherwise it
        will be to the model. Redirects returned by the controller are
        followed until a login succeeds.

        :raises JujuConnectionError: if no endpoint accepted the login.
        """
        # Not named ``client``: that would shadow the imported module.
        conn = cls(endpoint, uuid, username, password, cacert, macaroons,
                   loop)
        endpoints = [(endpoint, cacert)]
        while endpoints:
            _endpoint, _cacert = endpoints.pop(0)
            success, result, new_endpoints = await conn._try_endpoint(
                _endpoint, _cacert)
            if success:
                break
            endpoints.extend(new_endpoints)
        else:
            # ran out of endpoints without a successful login
            raise JujuConnectionError(
                "Couldn't authenticate to {}".format(endpoint))

        response = result['response']
        conn.info = response.copy()
        conn.build_facades(response.get('facades', {}))
        conn.monitor.pinger = conn.loop.create_task(conn.pinger())

        return conn

    @classmethod
    async def connect_current(cls, loop=None):
        """Connect to the currently active model.

        :raises JujuConnectionError: if there is no current controller.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')
        # Pass the controller through so current_model() doesn't shell out
        # to `juju list-controllers` a second time.
        model_name = jujudata.current_model(controller_name)

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name), loop)

    @classmethod
    async def connect_current_controller(cls, loop=None):
        """Connect to the currently active controller.

        :raises JujuConnectionError: if there is no current controller.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name, loop)

    @classmethod
    async def connect_controller(cls, controller_name, loop=None):
        """Connect to a controller by name.

        Uses the first API endpoint and the credentials stored in the local
        Juju data; falls back to macaroons when no password is stored.
        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons, loop)

    @classmethod
    async def connect_model(cls, model, loop=None):
        """Connect to a model by name.

        :param str model: [<controller>:]<model>

        """
        jujudata = JujuData()

        if ':' in model:
            # explicit controller given
            controller_name, model_name = model.split(':')
        else:
            # use the current controller if one isn't explicitly given
            controller_name = jujudata.current_controller()
            model_name = model

        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        # model name must include a user prefix, so add it if it doesn't
        if '/' not in model_name:
            model_name = '{}/{}'.format(username, model_name)

        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons, loop)

    def build_facades(self, facades):
        """Populate ``self.facades`` for the connected server version.

        The ``facades`` argument (from the login response) is currently
        ignored; see the comment below.
        """
        self.facades.clear()
        # In order to work around an issue where the juju api is not
        # returning a complete list of facades, we simply look up the
        # juju version in a pregenerated map, and use that info to
        # populate our list of facades.

        # TODO: if a future version of juju fixes this bug, restore
        # the following code for that version and higher:
        # for facade in facades:
        #     self.facades[facade['name']] = facade['versions'][-1]
        try:
            self.facades = VERSION_MAP[self.info['server-version']]
        except KeyError:
            log.warning("Could not find a set of facades for {}. Using "
                        "the latest facade set instead".format(
                            self.info['server-version']))
            self.facades = VERSION_MAP['latest']

    async def login(self):
        """Send an Admin.Login (version 3) request and return its result.

        A username without the ``user-`` prefix is converted to a user tag.
        Sends the stored password and macaroons as credentials.
        """
        username = self.username
        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": self.password,
                # NOTE(review): ``random`` is not a CSPRNG; confirm the
                # nonce does not need to be unpredictable.
                "nonce": "".join(random.sample(string.printable, 12)),
                "macaroons": self.macaroons
            }})
        return result

    async def redirect_info(self):
        """Return the Admin.RedirectInfo response, or None if not redirected.

        :raises JujuAPIError: for any API error other than 'not redirected'.
        """
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
511
512
class JujuData:
    """Access the local Juju client data directory.

    The directory comes from the ``JUJU_DATA`` environment variable, or
    ``~/.local/share/juju`` when that is unset or empty; it is stored as an
    absolute, user-expanded path in ``self.path``.
    """

    def __init__(self):
        data_dir = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(data_dir))

    def current_controller(self):
        """Return the current controller name as reported by the juju CLI.

        Runs ``juju list-controllers --format yaml`` and returns its
        ``current-controller`` value, or '' when that key is absent.
        """
        argv = shlex.split('juju list-controllers --format yaml')
        listing = yaml.safe_load(subprocess.check_output(argv))
        return listing.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the current model of ``controller_name``.

        Falls back to the current controller when no (or an empty) name is
        given.

        :raises JujuError: if the controller has no current model recorded.
        """
        controller_name = controller_name or self.current_controller()
        controller_models = self.models()[controller_name]
        if 'current-model' in controller_models:
            return controller_models['current-model']
        raise JujuError('No current model')

    def controllers(self):
        """Return the ``controllers`` section of controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the ``controllers`` section of models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the ``controllers`` section of accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Parse ``filename`` inside ``self.path`` and return entry ``key``."""
        with io.open(os.path.join(self.path, filename), 'rt') as stream:
            contents = yaml.safe_load(stream)
        return contents[key]
545
546
def get_macaroons():
    """Decode and return macaroons from default ~/.go-cookies

    Reads the JSON cookie jar at ``~/.go-cookies``. For every cookie whose
    ``Name`` starts with ``macaroon-`` and whose ``Value`` is non-empty, the
    value is base64-decoded and parsed as JSON.

    :return: list of decoded macaroons; an empty list (after logging a
        warning) when the cookie file cannot be read or is not valid JSON.
    """
    # Resolved before the ``try`` so the warning below can always name it.
    cookie_file = os.path.expanduser('~/.go-cookies')
    try:
        # JSON text is UTF-8; don't depend on the locale's default encoding.
        with open(cookie_file, 'r', encoding='utf-8') as f:
            cookies = json.load(f)
    except (OSError, ValueError):
        # Logger.warn is a deprecated alias of Logger.warning.
        log.warning("Couldn't load macaroons from %s", cookie_file)
        return []

    base64_macaroons = [
        c['Value'] for c in cookies
        if c['Name'].startswith('macaroon-') and c['Value']
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]