Refactored login code to better handle redirects (#116)
[osm/N2VC.git] / juju / client / connection.py
index 4a9766d..2be360f 100644 (file)
@@ -45,6 +45,7 @@ class Monitor:
     def __init__(self, connection):
         self.connection = connection
         self.receiver = None
+        self.pinger = None
 
     @property
     def status(self):
@@ -122,9 +123,14 @@ class Connection:
             macaroons=None, loop=None):
         self.endpoint = endpoint
         self.uuid = uuid
-        self.username = username
-        self.password = password
-        self.macaroons = macaroons
+        if macaroons:
+            self.macaroons = macaroons
+            self.username = ''
+            self.password = ''
+        else:
+            self.macaroons = []
+            self.username = username
+            self.password = password
         self.cacert = cacert
         self.loop = loop or asyncio.get_event_loop()
 
@@ -162,7 +168,14 @@ class Connection:
         return self
 
     async def close(self):
+        if not self.is_open:
+            return
         self.close_called = True
+        if self.monitor.pinger:
+            # might be closing due to login failure,
+            # in which case we won't have a pinger yet
+            self.monitor.pinger.cancel()
+        self.monitor.receiver.cancel()
         await self.ws.close()
 
     async def recv(self, request_id):
@@ -196,7 +209,7 @@ class Connection:
         pinger_facade = client.PingerFacade.from_connection(self)
         while self.is_open:
             await pinger_facade.Ping()
-            await asyncio.sleep(10)
+            await asyncio.sleep(10, loop=self.loop)
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -309,6 +322,38 @@ class Connection:
             self.loop,
         )
 
+    async def _try_endpoint(self, endpoint, cacert):
+        success = False
+        result = None
+        new_endpoints = []
+
+        self.endpoint = endpoint
+        self.cacert = cacert
+        await self.open()
+        try:
+            result = await self.login()
+            if 'discharge-required-error' in result['response']:
+                log.info('Macaroon discharge required, disconnecting')
+            else:
+                # successful login!
+                log.info('Authenticated')
+                success = True
+        except JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            redirect_info = await self.redirect_info()
+            redir_cacert = redirect_info['ca-cert']
+            new_endpoints = [
+                ("{value}:{port}".format(**s), redir_cacert)
+                for servers in redirect_info['servers']
+                for s in servers if s["scope"] == 'public'
+            ]
+        finally:
+            if not success:
+                await self.close()
+        return success, result, new_endpoints
+
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
@@ -321,34 +366,24 @@ class Connection:
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
                      loop)
-        await client.open()
-
-        redirect_info = await client.redirect_info()
-        if not redirect_info:
-            await client.login(username, password, macaroons)
-            return client
-
-        await client.close()
-        servers = [
-            s for servers in redirect_info['servers']
-            for s in servers if s["scope"] == 'public'
-        ]
-        for server in servers:
-            client = cls(
-                "{value}:{port}".format(**server), uuid, username,
-                password, redirect_info['ca-cert'], macaroons)
-            await client.open()
-            try:
-                result = await client.login(username, password, macaroons)
-                if 'discharge-required-error' in result:
-                    continue
-                return client
-            except Exception as e:
-                await client.close()
-                log.exception(e)
+        endpoints = [(endpoint, cacert)]
+        while endpoints:
+            _endpoint, _cacert = endpoints.pop(0)
+            success, result, new_endpoints = await client._try_endpoint(
+                _endpoint, _cacert)
+            if success:
+                break
+            endpoints.extend(new_endpoints)
+        else:
+            # ran out of endpoints without a successful login
+            raise Exception("Couldn't authenticate to {}".format(endpoint))
+
+        response = result['response']
+        client.info = response.copy()
+        client.build_facades(response.get('facades', {}))
+        client.monitor.pinger = client.loop.create_task(client.pinger())
 
-        raise Exception(
-            "Couldn't authenticate to %s", endpoint)
+        return client
 
     @classmethod
     async def connect_current(cls, loop=None):
@@ -444,11 +479,8 @@ class Connection:
                             self.info['server-version']))
             self.facades = VERSION_MAP['latest']
 
-    async def login(self, username, password, macaroons=None):
-        if macaroons:
-            username = ''
-            password = ''
-
+    async def login(self):
+        username = self.username
         if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
@@ -458,17 +490,11 @@ class Connection:
             "version": 3,
             "params": {
                 "auth-tag": username,
-                "credentials": password,
+                "credentials": self.password,
                 "nonce": "".join(random.sample(string.printable, 12)),
-                "macaroons": macaroons or []
+                "macaroons": self.macaroons
             }})
-        response = result['response']
-        self.info = response.copy()
-        self.build_facades(response.get('facades', {}))
-        # Create a pinger to keep the connection alive (needed for
-        # JaaS; harmless elsewhere).
-        self.loop.create_task(self.pinger())
-        return response
+        return result
 
     async def redirect_info(self):
         try: