From: Cory Johns Date: Wed, 26 Apr 2017 22:23:13 +0000 (-0400) Subject: Refactored login code to better handle redirects (#116) X-Git-Tag: 0.4.1~3 X-Git-Url: https://osm.etsi.org/gitweb/?a=commitdiff_plain;h=7c2a530853c95b8a3518f6db0870f94858f87c27;p=osm%2FN2VC.git Refactored login code to better handle redirects (#116) Fixes #114: build_facades not handling discharge-required results Fixes #115: JAAS update broke controller connections Also ensures that the receiver and pinger tasks get cleaned up properly when a connection is closed. Also makes the model AllWatcher share the model connection to reduce the number of open connections required. The independent connection is no longer needed since the websocket responses are properly paired with the requests. --- diff --git a/examples/add_model.py b/examples/add_model.py index 259771b..3e46490 100644 --- a/examples/add_model.py +++ b/examples/add_model.py @@ -51,13 +51,14 @@ async def main(): print("Destroying model") await controller.destroy_model(model.info.uuid) - except Exception as e: + except Exception: LOG.exception( "Test failed! Model {} may not be cleaned up".format(model_name)) finally: print('Disconnecting from controller') - await model.disconnect() + if model: + await model.disconnect() await controller.disconnect() diff --git a/juju/client/connection.py b/juju/client/connection.py index 4a9766d..2be360f 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -45,6 +45,7 @@ class Monitor: def __init__(self, connection): self.connection = connection self.receiver = None + self.pinger = None @property def status(self): @@ -122,9 +123,14 @@ class Connection: macaroons=None, loop=None): self.endpoint = endpoint self.uuid = uuid - self.username = username - self.password = password - self.macaroons = macaroons + if macaroons: + self.macaroons = macaroons + self.username = '' + self.password = '' + else: + self.macaroons = [] + self.username = username + self.password = password self.cacert = cacert self.loop = loop or asyncio.get_event_loop() @@ -162,7 +168,14 @@ class Connection: return self async def close(self): + if not self.is_open: + return self.close_called = True + if self.monitor.pinger: + # might be closing due to login failure, + # in which case we won't have a pinger yet + self.monitor.pinger.cancel() + self.monitor.receiver.cancel() await self.ws.close() async def recv(self, request_id): @@ -196,7 +209,7 @@ class Connection: pinger_facade = client.PingerFacade.from_connection(self) while self.is_open: await pinger_facade.Ping() - await asyncio.sleep(10) + await asyncio.sleep(10, loop=self.loop) async def rpc(self, msg, encoder=None): self.__request_id__ += 1 @@ -309,6 +322,38 @@ class Connection: self.loop, ) + async def _try_endpoint(self, endpoint, cacert): + success = False + result = None + new_endpoints = [] + + self.endpoint = endpoint + self.cacert = cacert + await self.open() + try: + result = await self.login() + if 'discharge-required-error' in result['response']: + log.info('Macaroon discharge required, disconnecting') + else: + # successful login! + log.info('Authenticated') + success = True + except JujuAPIError as e: + if e.error_code != 'redirection required': + raise + log.info('Controller requested redirect') + redirect_info = await self.redirect_info() + redir_cacert = redirect_info['ca-cert'] + new_endpoints = [ + ("{value}:{port}".format(**s), redir_cacert) + for servers in redirect_info['servers'] + for s in servers if s["scope"] == 'public' + ] + finally: + if not success: + await self.close() + return success, result, new_endpoints + @classmethod async def connect( cls, endpoint, uuid, username, password, cacert=None, @@ -321,34 +366,24 @@ class Connection: """ client = cls(endpoint, uuid, username, password, cacert, macaroons, loop) - await client.open() - - redirect_info = await client.redirect_info() - if not redirect_info: - await client.login(username, password, macaroons) - return client - - await client.close() - servers = [ - s for servers in redirect_info['servers'] - for s in servers if s["scope"] == 'public' - ] - for server in servers: - client = cls( - "{value}:{port}".format(**server), uuid, username, - password, redirect_info['ca-cert'], macaroons) - await client.open() - try: - result = await client.login(username, password, macaroons) - if 'discharge-required-error' in result: - continue - return client - except Exception as e: - await client.close() - log.exception(e) + endpoints = [(endpoint, cacert)] + while endpoints: + _endpoint, _cacert = endpoints.pop(0) + success, result, new_endpoints = await client._try_endpoint( + _endpoint, _cacert) + if success: + break + endpoints.extend(new_endpoints) + else: + # ran out of endpoints without a successful login + raise Exception("Couldn't authenticate to {}".format(endpoint)) + + response = result['response'] + client.info = response.copy() + client.build_facades(response.get('facades', {})) + client.monitor.pinger = client.loop.create_task(client.pinger()) - raise Exception( - "Couldn't authenticate to %s", endpoint) + return client @classmethod async def connect_current(cls, loop=None): @@ -444,11 +479,8 @@ class Connection: self.info['server-version'])) self.facades = VERSION_MAP['latest'] - async def login(self, username, password, macaroons=None): - if macaroons: - username = '' - password = '' - + async def login(self): + username = self.username if username and not username.startswith('user-'): username = 'user-{}'.format(username) @@ -458,17 +490,11 @@ class Connection: "version": 3, "params": { "auth-tag": username, - "credentials": password, + "credentials": self.password, "nonce": "".join(random.sample(string.printable, 12)), - "macaroons": macaroons or [] + "macaroons": self.macaroons }}) - response = result['response'] - self.info = response.copy() - self.build_facades(response.get('facades', {})) - # Create a pinger to keep the connection alive (needed for - # JaaS; harmless elsewhere). - self.loop.create_task(self.pinger()) - return response + return result async def redirect_info(self): try: diff --git a/juju/errors.py b/juju/errors.py index 71a3215..de52174 100644 --- a/juju/errors.py +++ b/juju/errors.py @@ -4,7 +4,9 @@ class JujuError(Exception): class JujuAPIError(JujuError): def __init__(self, result): + self.result = result self.message = result['error'] + self.error_code = result.get('error-code') self.response = result['response'] self.request_id = result['request-id'] super().__init__(self.message) diff --git a/juju/model.py b/juju/model.py index f162c7e..3ed8fa7 100644 --- a/juju/model.py +++ b/juju/model.py @@ -621,9 +621,8 @@ class Model(object): async def _start_watch(): self._watch_shutdown.clear() try: - self._watch_conn = await self.connection.clone() allwatcher = client.AllWatcherFacade.from_connection( - self._watch_conn) + self.connection) while True: results = await allwatcher.Next() for delta in results.deltas: @@ -640,11 +639,8 @@ class Model(object): loop=self.loop) self._watch_received.set() except CancelledError: - log.debug('Closing watcher connection') - await self._watch_conn.close() self._watch_shutdown.set() - self._watch_conn = None - except Exception as e: + except Exception: log.exception('Error in watcher') raise diff --git a/tests/integration/test_controller.py b/tests/integration/test_controller.py index d3a687f..f3840cc 100644 --- a/tests/integration/test_controller.py +++ b/tests/integration/test_controller.py @@ -1,5 +1,3 @@ -import asyncio -from concurrent.futures import ThreadPoolExecutor import pytest import uuid @@ -43,9 +41,11 @@ async def test_change_user_password(event_loop): await controller.add_user(username) await controller.change_user_password(username, 'password') try: - con = await controller.connect( + new_controller = Controller() + await new_controller.connect( controller.connection.endpoint, username, 'password') result = True + await new_controller.disconnect() except JujuAPIError: result = False assert result is True @@ -59,11 +59,13 @@ async def test_grant(event_loop): await controller.add_user(username) await controller.grant(username, 'superuser') result = await controller.get_user(username) - result = result.serialize()['results'][0].serialize()['result'].serialize() + result = result.serialize()['results'][0].serialize()['result']\ + .serialize() assert result['access'] == 'superuser' await controller.grant(username, 'login') result = await controller.get_user(username) - result = result.serialize()['results'][0].serialize()['result'].serialize() + result = result.serialize()['results'][0].serialize()['result']\ + .serialize() assert result['access'] == 'login'