Refactored login code to better handle redirects (#116)
authorCory Johns <johnsca@gmail.com>
Wed, 26 Apr 2017 22:23:13 +0000 (18:23 -0400)
committerGitHub <noreply@github.com>
Wed, 26 Apr 2017 22:23:13 +0000 (18:23 -0400)
Fixes #114: build_facades not handling discharge-required results
Fixes #115: JAAS update broke controller connections

Also ensures that the receiver and pinger tasks get cleaned up properly
when a connection is closed.

Also makes the model AllWatcher share the model connection to reduce the
number of open connections required.  The independent connection is no
longer needed since the websocket responses are properly paired with the
requests.

examples/add_model.py
juju/client/connection.py
juju/errors.py
juju/model.py
tests/integration/test_controller.py

index 259771b..3e46490 100644 (file)
@@ -51,13 +51,14 @@ async def main():
         print("Destroying model")
         await controller.destroy_model(model.info.uuid)
 
-    except Exception as e:
+    except Exception:
         LOG.exception(
             "Test failed! Model {} may not be cleaned up".format(model_name))
 
     finally:
         print('Disconnecting from controller')
-        await model.disconnect()
+        if model:
+            await model.disconnect()
         await controller.disconnect()
 
 
index 4a9766d..2be360f 100644 (file)
@@ -45,6 +45,7 @@ class Monitor:
     def __init__(self, connection):
         self.connection = connection
         self.receiver = None
+        self.pinger = None
 
     @property
     def status(self):
@@ -122,9 +123,14 @@ class Connection:
             macaroons=None, loop=None):
         self.endpoint = endpoint
         self.uuid = uuid
-        self.username = username
-        self.password = password
-        self.macaroons = macaroons
+        if macaroons:
+            self.macaroons = macaroons
+            self.username = ''
+            self.password = ''
+        else:
+            self.macaroons = []
+            self.username = username
+            self.password = password
         self.cacert = cacert
         self.loop = loop or asyncio.get_event_loop()
 
@@ -162,7 +168,14 @@ class Connection:
         return self
 
     async def close(self):
+        if not self.is_open:
+            return
         self.close_called = True
+        if self.monitor.pinger:
+            # might be closing due to login failure,
+            # in which case we won't have a pinger yet
+            self.monitor.pinger.cancel()
+        self.monitor.receiver.cancel()
         await self.ws.close()
 
     async def recv(self, request_id):
@@ -196,7 +209,7 @@ class Connection:
         pinger_facade = client.PingerFacade.from_connection(self)
         while self.is_open:
             await pinger_facade.Ping()
-            await asyncio.sleep(10)
+            await asyncio.sleep(10, loop=self.loop)
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -309,6 +322,38 @@ class Connection:
             self.loop,
         )
 
+    async def _try_endpoint(self, endpoint, cacert):
+        success = False
+        result = None
+        new_endpoints = []
+
+        self.endpoint = endpoint
+        self.cacert = cacert
+        await self.open()
+        try:
+            result = await self.login()
+            if 'discharge-required-error' in result['response']:
+                log.info('Macaroon discharge required, disconnecting')
+            else:
+                # successful login!
+                log.info('Authenticated')
+                success = True
+        except JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            redirect_info = await self.redirect_info()
+            redir_cacert = redirect_info['ca-cert']
+            new_endpoints = [
+                ("{value}:{port}".format(**s), redir_cacert)
+                for servers in redirect_info['servers']
+                for s in servers if s["scope"] == 'public'
+            ]
+        finally:
+            if not success:
+                await self.close()
+        return success, result, new_endpoints
+
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
@@ -321,34 +366,24 @@ class Connection:
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
                      loop)
-        await client.open()
-
-        redirect_info = await client.redirect_info()
-        if not redirect_info:
-            await client.login(username, password, macaroons)
-            return client
-
-        await client.close()
-        servers = [
-            s for servers in redirect_info['servers']
-            for s in servers if s["scope"] == 'public'
-        ]
-        for server in servers:
-            client = cls(
-                "{value}:{port}".format(**server), uuid, username,
-                password, redirect_info['ca-cert'], macaroons)
-            await client.open()
-            try:
-                result = await client.login(username, password, macaroons)
-                if 'discharge-required-error' in result:
-                    continue
-                return client
-            except Exception as e:
-                await client.close()
-                log.exception(e)
+        endpoints = [(endpoint, cacert)]
+        while endpoints:
+            _endpoint, _cacert = endpoints.pop(0)
+            success, result, new_endpoints = await client._try_endpoint(
+                _endpoint, _cacert)
+            if success:
+                break
+            endpoints.extend(new_endpoints)
+        else:
+            # ran out of endpoints without a successful login
+            raise Exception("Couldn't authenticate to {}".format(endpoint))
+
+        response = result['response']
+        client.info = response.copy()
+        client.build_facades(response.get('facades', {}))
+        client.monitor.pinger = client.loop.create_task(client.pinger())
 
-        raise Exception(
-            "Couldn't authenticate to %s", endpoint)
+        return client
 
     @classmethod
     async def connect_current(cls, loop=None):
@@ -444,11 +479,8 @@ class Connection:
                             self.info['server-version']))
             self.facades = VERSION_MAP['latest']
 
-    async def login(self, username, password, macaroons=None):
-        if macaroons:
-            username = ''
-            password = ''
-
+    async def login(self):
+        username = self.username
         if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
@@ -458,17 +490,11 @@ class Connection:
             "version": 3,
             "params": {
                 "auth-tag": username,
-                "credentials": password,
+                "credentials": self.password,
                 "nonce": "".join(random.sample(string.printable, 12)),
-                "macaroons": macaroons or []
+                "macaroons": self.macaroons
             }})
-        response = result['response']
-        self.info = response.copy()
-        self.build_facades(response.get('facades', {}))
-        # Create a pinger to keep the connection alive (needed for
-        # JaaS; harmless elsewhere).
-        self.loop.create_task(self.pinger())
-        return response
+        return result
 
     async def redirect_info(self):
         try:
index 71a3215..de52174 100644 (file)
@@ -4,7 +4,9 @@ class JujuError(Exception):
 
 class JujuAPIError(JujuError):
     def __init__(self, result):
+        self.result = result
         self.message = result['error']
+        self.error_code = result.get('error-code')
         self.response = result['response']
         self.request_id = result['request-id']
         super().__init__(self.message)
index f162c7e..3ed8fa7 100644 (file)
@@ -621,9 +621,8 @@ class Model(object):
         async def _start_watch():
             self._watch_shutdown.clear()
             try:
-                self._watch_conn = await self.connection.clone()
                 allwatcher = client.AllWatcherFacade.from_connection(
-                    self._watch_conn)
+                    self.connection)
                 while True:
                     results = await allwatcher.Next()
                     for delta in results.deltas:
@@ -640,11 +639,8 @@ class Model(object):
                             loop=self.loop)
                     self._watch_received.set()
             except CancelledError:
-                log.debug('Closing watcher connection')
-                await self._watch_conn.close()
                 self._watch_shutdown.set()
-                self._watch_conn = None
-            except Exception as e:
+            except Exception:
                 log.exception('Error in watcher')
                 raise
 
index d3a687f..f3840cc 100644 (file)
@@ -1,5 +1,3 @@
-import asyncio
-from concurrent.futures import ThreadPoolExecutor
 import pytest
 import uuid
 
@@ -43,9 +41,11 @@ async def test_change_user_password(event_loop):
         await controller.add_user(username)
         await controller.change_user_password(username, 'password')
         try:
-            con = await controller.connect(
+            new_controller = Controller()
+            await new_controller.connect(
                 controller.connection.endpoint, username, 'password')
             result = True
+            await new_controller.disconnect()
         except JujuAPIError:
             result = False
         assert result is True
@@ -59,11 +59,13 @@ async def test_grant(event_loop):
         await controller.add_user(username)
         await controller.grant(username, 'superuser')
         result = await controller.get_user(username)
-        result = result.serialize()['results'][0].serialize()['result'].serialize()
+        result = result.serialize()['results'][0].serialize()['result']\
+            .serialize()
         assert result['access'] == 'superuser'
         await controller.grant(username, 'login')
         result = await controller.get_user(username)
-        result = result.serialize()['results'][0].serialize()['result'].serialize()
+        result = result.serialize()['results'][0].serialize()['result']\
+            .serialize()
         assert result['access'] == 'login'