Refactored login code to better handle redirects (#116)
authorCory Johns <johnsca@gmail.com>
Wed, 26 Apr 2017 22:23:13 +0000 (18:23 -0400)
committerGitHub <noreply@github.com>
Wed, 26 Apr 2017 22:23:13 +0000 (18:23 -0400)
Fixes #114: build_facades not handling discharge-required results
Fixes #115: JAAS update broke controller connections

Also ensures that the receiver and pinger tasks get cleaned up properly
when a connection is closed.

Also makes the model AllWatcher share the model connection to reduce the
number of open connections required.  The independent connection is no
longer needed since the websocket responses are properly paired with the
requests.

examples/add_model.py
juju/client/connection.py
juju/errors.py
juju/model.py
tests/integration/test_controller.py

index 259771b..3e46490 100644 (file)
@@ -51,13 +51,14 @@ async def main():
         print("Destroying model")
         await controller.destroy_model(model.info.uuid)
 
         print("Destroying model")
         await controller.destroy_model(model.info.uuid)
 
-    except Exception as e:
+    except Exception:
         LOG.exception(
             "Test failed! Model {} may not be cleaned up".format(model_name))
 
     finally:
         print('Disconnecting from controller')
         LOG.exception(
             "Test failed! Model {} may not be cleaned up".format(model_name))
 
     finally:
         print('Disconnecting from controller')
-        await model.disconnect()
+        if model:
+            await model.disconnect()
         await controller.disconnect()
 
 
         await controller.disconnect()
 
 
index 4a9766d..2be360f 100644 (file)
@@ -45,6 +45,7 @@ class Monitor:
     def __init__(self, connection):
         self.connection = connection
         self.receiver = None
     def __init__(self, connection):
         self.connection = connection
         self.receiver = None
+        self.pinger = None
 
     @property
     def status(self):
 
     @property
     def status(self):
@@ -122,9 +123,14 @@ class Connection:
             macaroons=None, loop=None):
         self.endpoint = endpoint
         self.uuid = uuid
             macaroons=None, loop=None):
         self.endpoint = endpoint
         self.uuid = uuid
-        self.username = username
-        self.password = password
-        self.macaroons = macaroons
+        if macaroons:
+            self.macaroons = macaroons
+            self.username = ''
+            self.password = ''
+        else:
+            self.macaroons = []
+            self.username = username
+            self.password = password
         self.cacert = cacert
         self.loop = loop or asyncio.get_event_loop()
 
         self.cacert = cacert
         self.loop = loop or asyncio.get_event_loop()
 
@@ -162,7 +168,14 @@ class Connection:
         return self
 
     async def close(self):
         return self
 
     async def close(self):
+        if not self.is_open:
+            return
         self.close_called = True
         self.close_called = True
+        if self.monitor.pinger:
+            # might be closing due to login failure,
+            # in which case we won't have a pinger yet
+            self.monitor.pinger.cancel()
+        self.monitor.receiver.cancel()
         await self.ws.close()
 
     async def recv(self, request_id):
         await self.ws.close()
 
     async def recv(self, request_id):
@@ -196,7 +209,7 @@ class Connection:
         pinger_facade = client.PingerFacade.from_connection(self)
         while self.is_open:
             await pinger_facade.Ping()
         pinger_facade = client.PingerFacade.from_connection(self)
         while self.is_open:
             await pinger_facade.Ping()
-            await asyncio.sleep(10)
+            await asyncio.sleep(10, loop=self.loop)
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -309,6 +322,38 @@ class Connection:
             self.loop,
         )
 
             self.loop,
         )
 
+    async def _try_endpoint(self, endpoint, cacert):
+        success = False
+        result = None
+        new_endpoints = []
+
+        self.endpoint = endpoint
+        self.cacert = cacert
+        await self.open()
+        try:
+            result = await self.login()
+            if 'discharge-required-error' in result['response']:
+                log.info('Macaroon discharge required, disconnecting')
+            else:
+                # successful login!
+                log.info('Authenticated')
+                success = True
+        except JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            redirect_info = await self.redirect_info()
+            redir_cacert = redirect_info['ca-cert']
+            new_endpoints = [
+                ("{value}:{port}".format(**s), redir_cacert)
+                for servers in redirect_info['servers']
+                for s in servers if s["scope"] == 'public'
+            ]
+        finally:
+            if not success:
+                await self.close()
+        return success, result, new_endpoints
+
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
@@ -321,34 +366,24 @@ class Connection:
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
                      loop)
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
                      loop)
-        await client.open()
-
-        redirect_info = await client.redirect_info()
-        if not redirect_info:
-            await client.login(username, password, macaroons)
-            return client
-
-        await client.close()
-        servers = [
-            s for servers in redirect_info['servers']
-            for s in servers if s["scope"] == 'public'
-        ]
-        for server in servers:
-            client = cls(
-                "{value}:{port}".format(**server), uuid, username,
-                password, redirect_info['ca-cert'], macaroons)
-            await client.open()
-            try:
-                result = await client.login(username, password, macaroons)
-                if 'discharge-required-error' in result:
-                    continue
-                return client
-            except Exception as e:
-                await client.close()
-                log.exception(e)
+        endpoints = [(endpoint, cacert)]
+        while endpoints:
+            _endpoint, _cacert = endpoints.pop(0)
+            success, result, new_endpoints = await client._try_endpoint(
+                _endpoint, _cacert)
+            if success:
+                break
+            endpoints.extend(new_endpoints)
+        else:
+            # ran out of endpoints without a successful login
+            raise Exception("Couldn't authenticate to {}".format(endpoint))
+
+        response = result['response']
+        client.info = response.copy()
+        client.build_facades(response.get('facades', {}))
+        client.monitor.pinger = client.loop.create_task(client.pinger())
 
 
-        raise Exception(
-            "Couldn't authenticate to %s", endpoint)
+        return client
 
     @classmethod
     async def connect_current(cls, loop=None):
 
     @classmethod
     async def connect_current(cls, loop=None):
@@ -444,11 +479,8 @@ class Connection:
                             self.info['server-version']))
             self.facades = VERSION_MAP['latest']
 
                             self.info['server-version']))
             self.facades = VERSION_MAP['latest']
 
-    async def login(self, username, password, macaroons=None):
-        if macaroons:
-            username = ''
-            password = ''
-
+    async def login(self):
+        username = self.username
         if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
         if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
@@ -458,17 +490,11 @@ class Connection:
             "version": 3,
             "params": {
                 "auth-tag": username,
             "version": 3,
             "params": {
                 "auth-tag": username,
-                "credentials": password,
+                "credentials": self.password,
                 "nonce": "".join(random.sample(string.printable, 12)),
                 "nonce": "".join(random.sample(string.printable, 12)),
-                "macaroons": macaroons or []
+                "macaroons": self.macaroons
             }})
             }})
-        response = result['response']
-        self.info = response.copy()
-        self.build_facades(response.get('facades', {}))
-        # Create a pinger to keep the connection alive (needed for
-        # JaaS; harmless elsewhere).
-        self.loop.create_task(self.pinger())
-        return response
+        return result
 
     async def redirect_info(self):
         try:
 
     async def redirect_info(self):
         try:
index 71a3215..de52174 100644 (file)
@@ -4,7 +4,9 @@ class JujuError(Exception):
 
 class JujuAPIError(JujuError):
     def __init__(self, result):
 
 class JujuAPIError(JujuError):
     def __init__(self, result):
+        self.result = result
         self.message = result['error']
         self.message = result['error']
+        self.error_code = result.get('error-code')
         self.response = result['response']
         self.request_id = result['request-id']
         super().__init__(self.message)
         self.response = result['response']
         self.request_id = result['request-id']
         super().__init__(self.message)
index f162c7e..3ed8fa7 100644 (file)
@@ -621,9 +621,8 @@ class Model(object):
         async def _start_watch():
             self._watch_shutdown.clear()
             try:
         async def _start_watch():
             self._watch_shutdown.clear()
             try:
-                self._watch_conn = await self.connection.clone()
                 allwatcher = client.AllWatcherFacade.from_connection(
                 allwatcher = client.AllWatcherFacade.from_connection(
-                    self._watch_conn)
+                    self.connection)
                 while True:
                     results = await allwatcher.Next()
                     for delta in results.deltas:
                 while True:
                     results = await allwatcher.Next()
                     for delta in results.deltas:
@@ -640,11 +639,8 @@ class Model(object):
                             loop=self.loop)
                     self._watch_received.set()
             except CancelledError:
                             loop=self.loop)
                     self._watch_received.set()
             except CancelledError:
-                log.debug('Closing watcher connection')
-                await self._watch_conn.close()
                 self._watch_shutdown.set()
                 self._watch_shutdown.set()
-                self._watch_conn = None
-            except Exception as e:
+            except Exception:
                 log.exception('Error in watcher')
                 raise
 
                 log.exception('Error in watcher')
                 raise
 
index d3a687f..f3840cc 100644 (file)
@@ -1,5 +1,3 @@
-import asyncio
-from concurrent.futures import ThreadPoolExecutor
 import pytest
 import uuid
 
 import pytest
 import uuid
 
@@ -43,9 +41,11 @@ async def test_change_user_password(event_loop):
         await controller.add_user(username)
         await controller.change_user_password(username, 'password')
         try:
         await controller.add_user(username)
         await controller.change_user_password(username, 'password')
         try:
-            con = await controller.connect(
+            new_controller = Controller()
+            await new_controller.connect(
                 controller.connection.endpoint, username, 'password')
             result = True
                 controller.connection.endpoint, username, 'password')
             result = True
+            await new_controller.disconnect()
         except JujuAPIError:
             result = False
         assert result is True
         except JujuAPIError:
             result = False
         assert result is True
@@ -59,11 +59,13 @@ async def test_grant(event_loop):
         await controller.add_user(username)
         await controller.grant(username, 'superuser')
         result = await controller.get_user(username)
         await controller.add_user(username)
         await controller.grant(username, 'superuser')
         result = await controller.get_user(username)
-        result = result.serialize()['results'][0].serialize()['result'].serialize()
+        result = result.serialize()['results'][0].serialize()['result']\
+            .serialize()
         assert result['access'] == 'superuser'
         await controller.grant(username, 'login')
         result = await controller.get_user(username)
         assert result['access'] == 'superuser'
         await controller.grant(username, 'login')
         result = await controller.get_user(username)
-        result = result.serialize()['results'][0].serialize()['result'].serialize()
+        result = result.serialize()['results'][0].serialize()['result']\
+            .serialize()
         assert result['access'] == 'login'
 
 
         assert result['access'] == 'login'