X-Git-Url: https://osm.etsi.org/gitweb/?a=blobdiff_plain;f=juju%2Fclient%2Fconnection.py;h=68517079f5cdbdbeea7390d312f460f2a4945b93;hb=13d73a3515008c5797f83bcaef0adfb6627395d3;hp=4a9766d62da5fa44dfdffef899c7608ec95adca2;hpb=1a3cee44420e79fda92943edf636eaddb393145e;p=osm%2FN2VC.git diff --git a/juju/client/connection.py b/juju/client/connection.py index 4a9766d..6851707 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -9,14 +9,15 @@ import ssl import string import subprocess import websockets +from concurrent.futures import CancelledError from http.client import HTTPSConnection +from pathlib import Path import asyncio import yaml -from juju import tag +from juju import tag, utils from juju.client import client -from juju.client.version_map import VERSION_MAP from juju.errors import JujuError, JujuAPIError, JujuConnectionError from juju.utils import IdQueue @@ -44,7 +45,11 @@ class Monitor: def __init__(self, connection): self.connection = connection - self.receiver = None + self.close_called = asyncio.Event(loop=self.connection.loop) + self.receiver_stopped = asyncio.Event(loop=self.connection.loop) + self.pinger_stopped = asyncio.Event(loop=self.connection.loop) + self.receiver_stopped.set() + self.pinger_stopped.set() @property def status(self): @@ -62,21 +67,21 @@ class Monitor: # DISCONNECTED: connection not yet open if not self.connection.ws: return self.DISCONNECTED - if not self.receiver: + if self.receiver_stopped.is_set(): return self.DISCONNECTED # ERROR: Connection closed (or errored), but we didn't call # connection.close - if not self.connection.close_called and self.receiver_exceptions(): + if not self.close_called.is_set() and self.receiver_stopped.is_set(): return self.ERROR - if not self.connection.close_called and not self.connection.ws.open: - # The check for self.receiver existing above guards against the - # case where we're not open because we simply haven't - # setup the connection yet. + if not self.close_called.is_set() and not self.connection.ws.open: + # The check for self.receiver_stopped existing above guards + # against the case where we're not open because we simply + # haven't setup the connection yet. return self.ERROR # DISCONNECTED: cleanly disconnected. - if self.connection.close_called and not self.connection.ws.open: + if self.close_called.is_set() and not self.connection.ws.open: return self.DISCONNECTED # CONNECTED: everything is fine! @@ -88,17 +93,6 @@ class Monitor: # know what state the connection is in. return self.UNKNOWN - def receiver_exceptions(self): - """ - Return exceptions in the receiver, if any. - - """ - if not self.receiver: - return None - if not self.receiver.done(): - return None - return self.receiver.exception() - class Connection: """ @@ -122,9 +116,14 @@ class Connection: macaroons=None, loop=None): self.endpoint = endpoint self.uuid = uuid - self.username = username - self.password = password - self.macaroons = macaroons + if macaroons: + self.macaroons = macaroons + self.username = '' + self.password = '' + else: + self.macaroons = [] + self.username = username + self.password = password self.cacert = cacert self.loop = loop or asyncio.get_event_loop() @@ -133,7 +132,6 @@ class Connection: self.ws = None self.facades = {} self.messages = IdQueue(loop=self.loop) - self.close_called = False self.monitor = Monitor(connection=self) @property @@ -157,12 +155,18 @@ class Connection: kw['loop'] = self.loop self.addr = url self.ws = await websockets.connect(url, **kw) - self.monitor.receiver = self.loop.create_task(self.receiver()) + self.loop.create_task(self.receiver()) + self.monitor.receiver_stopped.clear() log.info("Driver connected to juju %s", url) + self.monitor.close_called.clear() return self async def close(self): - self.close_called = True + if not self.is_open: + return + self.monitor.close_called.set() + await self.monitor.pinger_stopped.wait() + await self.monitor.receiver_stopped.wait() await self.ws.close() async def recv(self, request_id): @@ -171,19 +175,29 @@ class Connection: return await self.messages.get(request_id) async def receiver(self): - while self.is_open: - try: - result = await self.ws.recv() + try: + while self.is_open: + result = await utils.run_with_interrupt( + self.ws.recv(), + self.monitor.close_called, + loop=self.loop) + if self.monitor.close_called.is_set(): + break if result is not None: result = json.loads(result) await self.messages.put(result['request-id'], result) - except Exception as e: - await self.messages.put_all(e) - if isinstance(e, websockets.ConnectionClosed): - # ConnectionClosed is not really exceptional for us, - # but it may be for any pending message listeners - return - raise + except CancelledError: + pass + except Exception as e: + await self.messages.put_all(e) + if isinstance(e, websockets.ConnectionClosed): + # ConnectionClosed is not really exceptional for us, + # but it may be for any pending message listeners + return + log.exception("Error in receiver") + raise + finally: + self.monitor.receiver_stopped.set() async def pinger(self): ''' @@ -193,10 +207,24 @@ class Connection: To prevent timing out, we send a ping every ten seconds. ''' + async def _do_ping(): + try: + await pinger_facade.Ping() + await asyncio.sleep(10, loop=self.loop) + except CancelledError: + pass + pinger_facade = client.PingerFacade.from_connection(self) - while self.is_open: - await pinger_facade.Ping() - await asyncio.sleep(10) + try: + while self.is_open: + await utils.run_with_interrupt( + _do_ping(), + self.monitor.close_called, + loop=self.loop) + if self.monitor.close_called.is_set(): + break + finally: + self.monitor.pinger_stopped.set() async def rpc(self, msg, encoder=None): self.__request_id__ += 1 @@ -309,6 +337,38 @@ class Connection: self.loop, ) + async def _try_endpoint(self, endpoint, cacert): + success = False + result = None + new_endpoints = [] + + self.endpoint = endpoint + self.cacert = cacert + await self.open() + try: + result = await self.login() + if 'discharge-required-error' in result['response']: + log.info('Macaroon discharge required, disconnecting') + else: + # successful login! + log.info('Authenticated') + success = True + except JujuAPIError as e: + if e.error_code != 'redirection required': + raise + log.info('Controller requested redirect') + redirect_info = await self.redirect_info() + redir_cacert = redirect_info['ca-cert'] + new_endpoints = [ + ("{value}:{port}".format(**s), redir_cacert) + for servers in redirect_info['servers'] + for s in servers if s["scope"] == 'public' + ] + finally: + if not success: + await self.close() + return success, result, new_endpoints + @classmethod async def connect( cls, endpoint, uuid, username, password, cacert=None, @@ -321,34 +381,25 @@ class Connection: """ client = cls(endpoint, uuid, username, password, cacert, macaroons, loop) - await client.open() - - redirect_info = await client.redirect_info() - if not redirect_info: - await client.login(username, password, macaroons) - return client - - await client.close() - servers = [ - s for servers in redirect_info['servers'] - for s in servers if s["scope"] == 'public' - ] - for server in servers: - client = cls( - "{value}:{port}".format(**server), uuid, username, - password, redirect_info['ca-cert'], macaroons) - await client.open() - try: - result = await client.login(username, password, macaroons) - if 'discharge-required-error' in result: - continue - return client - except Exception as e: - await client.close() - log.exception(e) + endpoints = [(endpoint, cacert)] + while endpoints: + _endpoint, _cacert = endpoints.pop(0) + success, result, new_endpoints = await client._try_endpoint( + _endpoint, _cacert) + if success: + break + endpoints.extend(new_endpoints) + else: + # ran out of endpoints without a successful login + raise Exception("Couldn't authenticate to {}".format(endpoint)) + + response = result['response'] + client.info = response.copy() + client.build_facades(response.get('facades', {})) + client.loop.create_task(client.pinger()) + client.monitor.pinger_stopped.clear() - raise Exception( - "Couldn't authenticate to %s", endpoint) + return client @classmethod async def connect_current(cls, loop=None): @@ -356,7 +407,11 @@ class Connection: """ jujudata = JujuData() + controller_name = jujudata.current_controller() + if not controller_name: + raise JujuConnectionError('No current controller') + model_name = jujudata.current_model() return await cls.connect_model( @@ -386,7 +441,7 @@ class Connection: accounts = jujudata.accounts()[controller_name] username = accounts['user'] password = accounts.get('password') - macaroons = get_macaroons() if not password else None + macaroons = get_macaroons(controller_name) if not password else None return await cls.connect( endpoint, None, username, password, cacert, macaroons, loop) @@ -420,35 +475,18 @@ class Connection: password = accounts.get('password') models = jujudata.models()[controller_name] model_uuid = models['models'][model_name]['uuid'] - macaroons = get_macaroons() if not password else None + macaroons = get_macaroons(controller_name) if not password else None return await cls.connect( endpoint, model_uuid, username, password, cacert, macaroons, loop) def build_facades(self, facades): self.facades.clear() - # In order to work around an issue where the juju api is not - # returning a complete list of facades, we simply look up the - # juju version in a pregenerated map, and use that info to - # populate our list of facades. - - # TODO: if a future version of juju fixes this bug, restore - # the following code for that version and higher: - # for facade in facades: - # self.facades[facade['name']] = facade['versions'][-1] - try: - self.facades = VERSION_MAP[self.info['server-version']] - except KeyError: - log.warning("Could not find a set of facades for {}. Using " - "the latest facade set instead".format( - self.info['server-version'])) - self.facades = VERSION_MAP['latest'] - - async def login(self, username, password, macaroons=None): - if macaroons: - username = '' - password = '' + for facade in facades: + self.facades[facade['name']] = facade['versions'][-1] + async def login(self): + username = self.username if username and not username.startswith('user-'): username = 'user-{}'.format(username) @@ -458,17 +496,11 @@ class Connection: "version": 3, "params": { "auth-tag": username, - "credentials": password, + "credentials": self.password, "nonce": "".join(random.sample(string.printable, 12)), - "macaroons": macaroons or [] + "macaroons": self.macaroons }}) - response = result['response'] - self.info = response.copy() - self.build_facades(response.get('facades', {})) - # Create a pinger to keep the connection alive (needed for - # JaaS; harmless elsewhere). - self.loop.create_task(self.pinger()) - return response + return result async def redirect_info(self): try: @@ -518,16 +550,26 @@ class JujuData: return yaml.safe_load(f)[key] -def get_macaroons(): +def get_macaroons(controller_name=None): """Decode and return macaroons from default ~/.go-cookies """ - try: - cookie_file = os.path.expanduser('~/.go-cookies') - with open(cookie_file, 'r') as f: - cookies = json.load(f) - except (OSError, ValueError): - log.warn("Couldn't load macaroons from %s", cookie_file) + cookie_files = [] + if controller_name: + cookie_files.append('~/.local/share/juju/cookies/{}.json'.format( + controller_name)) + cookie_files.append('~/.go-cookies') + for cookie_file in cookie_files: + cookie_file = Path(cookie_file).expanduser() + if cookie_file.exists(): + try: + cookies = json.loads(cookie_file.read_text()) + break + except (OSError, ValueError): + log.warn("Couldn't load macaroons from %s", cookie_file) + return [] + else: + log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files)) return [] base64_macaroons = [