Refactored login code to better handle redirects (#116)
[osm/N2VC.git] / juju / client / connection.py
index 486a57f..2be360f 100644 (file)
@@ -15,11 +15,92 @@ import asyncio
 import yaml
 
 from juju import tag
+from juju.client import client
+from juju.client.version_map import VERSION_MAP
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
 
 log = logging.getLogger("websocket")
 
 
+class Monitor:
+    """
+    Monitor helper class for our Connection class.
+
+    Contains a reference to an instantiated Connection, along with a
+    reference to the Connection.receiver Future. Upon inspecttion of
+    these objects, this class determines whether the connection is in
+    an 'error', 'connected' or 'disconnected' state.
+
+    Use this class to stay up to date on the health of a connection,
+    and take appropriate action if the connection errors out due to
+    network issues or other unexpected circumstances.
+
+    """
+    ERROR = 'error'
+    CONNECTED = 'connected'
+    DISCONNECTED = 'disconnected'
+    UNKNOWN = 'unknown'
+
+    def __init__(self, connection):
+        self.connection = connection
+        self.receiver = None
+        self.pinger = None
+
+    @property
+    def status(self):
+        """
+        Determine the status of the connection and receiver, and return
+        ERROR, CONNECTED, or DISCONNECTED as appropriate.
+
+        For simplicity, we only consider ourselves to be connected
+        after the Connection class has setup a receiver task. This
+        only happens after the websocket is open, and the connection
+        isn't usable until that receiver has been started.
+
+        """
+
+        # DISCONNECTED: connection not yet open
+        if not self.connection.ws:
+            return self.DISCONNECTED
+        if not self.receiver:
+            return self.DISCONNECTED
+
+        # ERROR: Connection closed (or errored), but we didn't call
+        # connection.close
+        if not self.connection.close_called and self.receiver_exceptions():
+            return self.ERROR
+        if not self.connection.close_called and not self.connection.ws.open:
+            # The check for self.receiver existing above guards against the
+            # case where we're not open because we simply haven't
+            # setup the connection yet.
+            return self.ERROR
+
+        # DISCONNECTED: cleanly disconnected.
+        if self.connection.close_called and not self.connection.ws.open:
+            return self.DISCONNECTED
+
+        # CONNECTED: everything is fine!
+        if self.connection.ws.open:
+            return self.CONNECTED
+
+        # UNKNOWN: We should never hit this state -- if we do,
+        # something went wrong with the logic above, and we do not
+        # know what state the connection is in.
+        return self.UNKNOWN
+
+    def receiver_exceptions(self):
+        """
+        Return exceptions in the receiver, if any.
+
+        """
+        if not self.receiver:
+            return None
+        if not self.receiver.done():
+            return None
+        return self.receiver.exception()
+
+
 class Connection:
     """
     Usage::
@@ -42,9 +123,14 @@ class Connection:
             macaroons=None, loop=None):
         self.endpoint = endpoint
         self.uuid = uuid
-        self.username = username
-        self.password = password
-        self.macaroons = macaroons
+        if macaroons:
+            self.macaroons = macaroons
+            self.username = ''
+            self.password = ''
+        else:
+            self.macaroons = []
+            self.username = username
+            self.password = password
         self.cacert = cacert
         self.loop = loop or asyncio.get_event_loop()
 
@@ -52,7 +138,9 @@ class Connection:
         self.addr = None
         self.ws = None
         self.facades = {}
-        self.messages = {}
+        self.messages = IdQueue(loop=self.loop)
+        self.close_called = False
+        self.monitor = Monitor(connection=self)
 
     @property
     def is_open(self):
@@ -75,27 +163,53 @@ class Connection:
         kw['loop'] = self.loop
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
+        self.monitor.receiver = self.loop.create_task(self.receiver())
         log.info("Driver connected to juju %s", url)
         return self
 
     async def close(self):
+        if not self.is_open:
+            return
+        self.close_called = True
+        if self.monitor.pinger:
+            # might be closing due to login failure,
+            # in which case we won't have a pinger yet
+            self.monitor.pinger.cancel()
+        self.monitor.receiver.cancel()
         await self.ws.close()
 
     async def recv(self, request_id):
-        while not self.messages.get(request_id):
-            await asyncio.sleep(0)
-
-        result = self.messages[request_id]
-
-        del self.messages[request_id]
-        return result
+        if not self.is_open:
+            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+        return await self.messages.get(request_id)
 
     async def receiver(self):
         while self.is_open:
-            result = await self.ws.recv()
-            if result is not None:
-                result = json.loads(result)
-                self.messages[result['request-id']] = result
+            try:
+                result = await self.ws.recv()
+                if result is not None:
+                    result = json.loads(result)
+                    await self.messages.put(result['request-id'], result)
+            except Exception as e:
+                await self.messages.put_all(e)
+                if isinstance(e, websockets.ConnectionClosed):
+                    # ConnectionClosed is not really exceptional for us,
+                    # but it may be for any pending message listeners
+                    return
+                raise
+
+    async def pinger(self):
+        '''
+        A Controller can time us out if we are silent for too long. This
+        is especially true in JaaS, which has a fairly strict timeout.
+
+        To prevent timing out, we send a ping every ten seconds.
+
+        '''
+        pinger_facade = client.PingerFacade.from_connection(self)
+        while self.is_open:
+            await pinger_facade.Ping()
+            await asyncio.sleep(10, loop=self.loop)
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -115,7 +229,7 @@ class Connection:
             # API Error Response
             raise JujuAPIError(result)
 
-        if not 'response' in result:
+        if 'response' not in result:
             # This may never happen
             return result
 
@@ -208,6 +322,38 @@ class Connection:
             self.loop,
         )
 
+    async def _try_endpoint(self, endpoint, cacert):
+        success = False
+        result = None
+        new_endpoints = []
+
+        self.endpoint = endpoint
+        self.cacert = cacert
+        await self.open()
+        try:
+            result = await self.login()
+            if 'discharge-required-error' in result['response']:
+                log.info('Macaroon discharge required, disconnecting')
+            else:
+                # successful login!
+                log.info('Authenticated')
+                success = True
+        except JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            redirect_info = await self.redirect_info()
+            redir_cacert = redirect_info['ca-cert']
+            new_endpoints = [
+                ("{value}:{port}".format(**s), redir_cacert)
+                for servers in redirect_info['servers']
+                for s in servers if s["scope"] == 'public'
+            ]
+        finally:
+            if not success:
+                await self.close()
+        return success, result, new_endpoints
+
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
@@ -220,35 +366,24 @@ class Connection:
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
                      loop)
-        await client.open()
-        self.loop.create_task(self.receiver)
-
-        redirect_info = await client.redirect_info()
-        if not redirect_info:
-            await client.login(username, password, macaroons)
-            return client
-
-        await client.close()
-        servers = [
-            s for servers in redirect_info['servers']
-            for s in servers if s["scope"] == 'public'
-        ]
-        for server in servers:
-            client = cls(
-                "{value}:{port}".format(**server), uuid, username,
-                password, redirect_info['ca-cert'], macaroons)
-            await client.open()
-            try:
-                result = await client.login(username, password, macaroons)
-                if 'discharge-required-error' in result:
-                    continue
-                return client
-            except Exception as e:
-                await client.close()
-                log.exception(e)
+        endpoints = [(endpoint, cacert)]
+        while endpoints:
+            _endpoint, _cacert = endpoints.pop(0)
+            success, result, new_endpoints = await client._try_endpoint(
+                _endpoint, _cacert)
+            if success:
+                break
+            endpoints.extend(new_endpoints)
+        else:
+            # ran out of endpoints without a successful login
+            raise Exception("Couldn't authenticate to {}".format(endpoint))
 
-        raise Exception(
-            "Couldn't authenticate to %s", endpoint)
+        response = result['response']
+        client.info = response.copy()
+        client.build_facades(response.get('facades', {}))
+        client.monitor.pinger = client.loop.create_task(client.pinger())
+
+        return client
 
     @classmethod
     async def connect_current(cls, loop=None):
@@ -325,16 +460,27 @@ class Connection:
         return await cls.connect(
             endpoint, model_uuid, username, password, cacert, macaroons, loop)
 
-    def build_facades(self, info):
+    def build_facades(self, facades):
         self.facades.clear()
-        for facade in info:
-            self.facades[facade['name']] = facade['versions'][-1]
-
-    async def login(self, username, password, macaroons=None):
-        if macaroons:
-            username = ''
-            password = ''
-
+        # In order to work around an issue where the juju api is not
+        # returning a complete list of facades, we simply look up the
+        # juju version in a pregenerated map, and use that info to
+        # populate our list of facades.
+
+        # TODO: if a future version of juju fixes this bug, restore
+        # the following code for that version and higher:
+        # for facade in facades:
+        #     self.facades[facade['name']] = facade['versions'][-1]
+        try:
+            self.facades = VERSION_MAP[self.info['server-version']]
+        except KeyError:
+            log.warning("Could not find a set of facades for {}. Using "
+                        "the latest facade set instead".format(
+                            self.info['server-version']))
+            self.facades = VERSION_MAP['latest']
+
+    async def login(self):
+        username = self.username
         if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
@@ -344,14 +490,11 @@ class Connection:
             "version": 3,
             "params": {
                 "auth-tag": username,
-                "credentials": password,
+                "credentials": self.password,
                 "nonce": "".join(random.sample(string.printable, 12)),
-                "macaroons": macaroons or []
+                "macaroons": self.macaroons
             }})
-        response = result['response']
-        self.build_facades(response.get('facades', {}))
-        self.info = response.copy()
-        return response
+        return result
 
     async def redirect_info(self):
         try: