Improve handling of closed connections (#148)
[osm/N2VC.git] / juju / client / connection.py
index 4a9766d..7457391 100644 (file)
@@ -8,15 +8,17 @@ import shlex
 import ssl
 import string
 import subprocess
 import ssl
 import string
 import subprocess
+import weakref
 import websockets
 import websockets
+from concurrent.futures import CancelledError
 from http.client import HTTPSConnection
 from http.client import HTTPSConnection
+from pathlib import Path
 
 import asyncio
 import yaml
 
 
 import asyncio
 import yaml
 
-from juju import tag
+from juju import tag, utils
 from juju.client import client
 from juju.client import client
-from juju.client.version_map import VERSION_MAP
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
 from juju.utils import IdQueue
 
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
 from juju.utils import IdQueue
 
@@ -39,12 +41,17 @@ class Monitor:
     """
     ERROR = 'error'
     CONNECTED = 'connected'
     """
     ERROR = 'error'
     CONNECTED = 'connected'
+    DISCONNECTING = 'disconnecting'
     DISCONNECTED = 'disconnected'
     DISCONNECTED = 'disconnected'
-    UNKNOWN = 'unknown'
 
     def __init__(self, connection):
 
     def __init__(self, connection):
-        self.connection = connection
-        self.receiver = None
+        self.connection = weakref.ref(connection)
+        self.reconnecting = asyncio.Lock(loop=connection.loop)
+        self.close_called = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped = asyncio.Event(loop=connection.loop)
+        self.pinger_stopped = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped.set()
+        self.pinger_stopped.set()
 
     @property
     def status(self):
 
     @property
     def status(self):
@@ -58,46 +65,27 @@ class Monitor:
         isn't usable until that receiver has been started.
 
         """
         isn't usable until that receiver has been started.
 
         """
+        connection = self.connection()
 
 
-        # DISCONNECTED: connection not yet open
-        if not self.connection.ws:
+        # the connection instance was destroyed but someone kept
+        # a separate reference to the monitor for some reason
+        if not connection:
             return self.DISCONNECTED
             return self.DISCONNECTED
-        if not self.receiver:
-            return self.DISCONNECTED
-
-        # ERROR: Connection closed (or errored), but we didn't call
-        # connection.close
-        if not self.connection.close_called and self.receiver_exceptions():
-            return self.ERROR
-        if not self.connection.close_called and not self.connection.ws.open:
-            # The check for self.receiver existing above guards against the
-            # case where we're not open because we simply haven't
-            # setup the connection yet.
-            return self.ERROR
 
 
-        # DISCONNECTED: cleanly disconnected.
-        if self.connection.close_called and not self.connection.ws.open:
+        # connection cleanly disconnected or not yet opened
+        if not connection.ws:
             return self.DISCONNECTED
 
             return self.DISCONNECTED
 
-        # CONNECTED: everything is fine!
-        if self.connection.ws.open:
-            return self.CONNECTED
+        # close called but not yet complete
+        if self.close_called.is_set():
+            return self.DISCONNECTING
 
 
-        # UNKNOWN: We should never hit this state -- if we do,
-        # something went wrong with the logic above, and we do not
-        # know what state the connection is in.
-        return self.UNKNOWN
+        # connection closed uncleanly (we didn't call connection.close)
+        if self.receiver_stopped.is_set() or not connection.ws.open:
+            return self.ERROR
 
 
-    def receiver_exceptions(self):
-        """
-        Return exceptions in the receiver, if any.
-
-        """
-        if not self.receiver:
-            return None
-        if not self.receiver.done():
-            return None
-        return self.receiver.exception()
+        # everything is fine!
+        return self.CONNECTED
 
 
 class Connection:
 
 
 class Connection:
@@ -117,15 +105,27 @@ class Connection:
     Note: Any connection method or constructor can accept an optional `loop`
     argument to override the default event loop from `asyncio.get_event_loop`.
     """
     Note: Any connection method or constructor can accept an optional `loop`
     argument to override the default event loop from `asyncio.get_event_loop`.
     """
+
+    DEFAULT_FRAME_SIZE = 'default_frame_size'
+    MAX_FRAME_SIZE = 2**22
+    "Maximum size for a single frame.  Defaults to 4MB."
+
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None):
+            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
         self.endpoint = endpoint
         self.endpoint = endpoint
+        self._endpoint = endpoint
         self.uuid = uuid
         self.uuid = uuid
-        self.username = username
-        self.password = password
-        self.macaroons = macaroons
+        if macaroons:
+            self.macaroons = macaroons
+            self.username = ''
+            self.password = ''
+        else:
+            self.macaroons = []
+            self.username = username
+            self.password = password
         self.cacert = cacert
         self.cacert = cacert
+        self._cacert = cacert
         self.loop = loop or asyncio.get_event_loop()
 
         self.__request_id__ = 0
         self.loop = loop or asyncio.get_event_loop()
 
         self.__request_id__ = 0
@@ -133,14 +133,14 @@ class Connection:
         self.ws = None
         self.facades = {}
         self.messages = IdQueue(loop=self.loop)
         self.ws = None
         self.facades = {}
         self.messages = IdQueue(loop=self.loop)
-        self.close_called = False
         self.monitor = Monitor(connection=self)
         self.monitor = Monitor(connection=self)
+        if max_frame_size is self.DEFAULT_FRAME_SIZE:
+            max_frame_size = self.MAX_FRAME_SIZE
+        self.max_frame_size = max_frame_size
 
     @property
     def is_open(self):
 
     @property
     def is_open(self):
-        if self.ws:
-            return self.ws.open
-        return False
+        return self.monitor.status == Monitor.CONNECTED
 
     def _get_ssl(self, cert=None):
         return ssl.create_default_context(
 
     def _get_ssl(self, cert=None):
         return ssl.create_default_context(
@@ -155,15 +155,23 @@ class Connection:
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
         kw['loop'] = self.loop
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
         kw['loop'] = self.loop
+        kw['max_size'] = self.max_frame_size
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
-        self.monitor.receiver = self.loop.create_task(self.receiver())
+        self.loop.create_task(self.receiver())
+        self.monitor.receiver_stopped.clear()
         log.info("Driver connected to juju %s", url)
         log.info("Driver connected to juju %s", url)
+        self.monitor.close_called.clear()
         return self
 
     async def close(self):
         return self
 
     async def close(self):
-        self.close_called = True
+        if not self.ws:
+            return
+        self.monitor.close_called.set()
+        await self.monitor.pinger_stopped.wait()
+        await self.monitor.receiver_stopped.wait()
         await self.ws.close()
         await self.ws.close()
+        self.ws = None
 
     async def recv(self, request_id):
         if not self.is_open:
 
     async def recv(self, request_id):
         if not self.is_open:
@@ -171,19 +179,34 @@ class Connection:
         return await self.messages.get(request_id)
 
     async def receiver(self):
         return await self.messages.get(request_id)
 
     async def receiver(self):
-        while self.is_open:
-            try:
-                result = await self.ws.recv()
+        try:
+            while self.is_open:
+                result = await utils.run_with_interrupt(
+                    self.ws.recv(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
                 if result is not None:
                     result = json.loads(result)
                     await self.messages.put(result['request-id'], result)
                 if result is not None:
                     result = json.loads(result)
                     await self.messages.put(result['request-id'], result)
-            except Exception as e:
-                await self.messages.put_all(e)
-                if isinstance(e, websockets.ConnectionClosed):
-                    # ConnectionClosed is not really exceptional for us,
-                    # but it may be for any pending message listeners
-                    return
-                raise
+        except CancelledError:
+            pass
+        except websockets.ConnectionClosed as e:
+            log.warning('Receiver: Connection closed, reconnecting')
+            await self.messages.put_all(e)
+            # the reconnect has to be done as a task because the receiver will
+            # be cancelled by the reconnect and we don't want the reconnect
+            # to be aborted half-way through
+            self.loop.create_task(self.reconnect())
+            return
+        except Exception as e:
+            log.exception("Error in receiver")
+            # make pending listeners aware of the error
+            await self.messages.put_all(e)
+            raise
+        finally:
+            self.monitor.receiver_stopped.set()
 
     async def pinger(self):
         '''
 
     async def pinger(self):
         '''
@@ -193,10 +216,25 @@ class Connection:
         To prevent timing out, we send a ping every ten seconds.
 
         '''
         To prevent timing out, we send a ping every ten seconds.
 
         '''
+        async def _do_ping():
+            try:
+                await pinger_facade.Ping()
+                await asyncio.sleep(10, loop=self.loop)
+            except CancelledError:
+                pass
+
         pinger_facade = client.PingerFacade.from_connection(self)
         pinger_facade = client.PingerFacade.from_connection(self)
-        while self.is_open:
-            await pinger_facade.Ping()
-            await asyncio.sleep(10)
+        try:
+            while True:
+                await utils.run_with_interrupt(
+                    _do_ping(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+        finally:
+            self.monitor.pinger_stopped.set()
+            return
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -206,7 +244,19 @@ class Connection:
         if "version" not in msg:
             msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
         if "version" not in msg:
             msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
-        await self.ws.send(outgoing)
+        for attempt in range(3):
+            try:
+                await self.ws.send(outgoing)
+                break
+            except websockets.ConnectionClosed:
+                if attempt == 2:
+                    raise
+                log.warning('RPC: Connection closed, reconnecting')
+                # the reconnect has to be done in a separate task because,
+                # if it is triggered by the pinger, then this RPC call will
+                # be cancelled when the pinger is cancelled by the reconnect,
+                # and we don't want the reconnect to be aborted halfway through
+                await asyncio.wait([self.reconnect()], loop=self.loop)
         result = await self.recv(msg['request-id'])
 
         if not result:
         result = await self.recv(msg['request-id'])
 
         if not result:
@@ -293,6 +343,7 @@ class Connection:
             self.cacert,
             self.macaroons,
             self.loop,
             self.cacert,
             self.macaroons,
             self.loop,
+            self.max_frame_size,
         )
 
     async def controller(self):
         )
 
     async def controller(self):
@@ -309,10 +360,72 @@ class Connection:
             self.loop,
         )
 
             self.loop,
         )
 
+    async def _try_endpoint(self, endpoint, cacert):
+        success = False
+        result = None
+        new_endpoints = []
+
+        self.endpoint = endpoint
+        self.cacert = cacert
+        await self.open()
+        try:
+            result = await self.login()
+            if 'discharge-required-error' in result['response']:
+                log.info('Macaroon discharge required, disconnecting')
+            else:
+                # successful login!
+                log.info('Authenticated')
+                success = True
+        except JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            redirect_info = await self.redirect_info()
+            redir_cacert = redirect_info['ca-cert']
+            new_endpoints = [
+                ("{value}:{port}".format(**s), redir_cacert)
+                for servers in redirect_info['servers']
+                for s in servers if s["scope"] == 'public'
+            ]
+        finally:
+            if not success:
+                await self.close()
+        return success, result, new_endpoints
+
+    async def reconnect(self):
+        """ Force a reconnection.
+        """
+        monitor = self.monitor
+        if monitor.reconnecting.locked() or monitor.close_called.is_set():
+            return
+        async with monitor.reconnecting:
+            await self.close()
+            await self._connect()
+
+    async def _connect(self):
+        endpoints = [(self._endpoint, self._cacert)]
+        while endpoints:
+            _endpoint, _cacert = endpoints.pop(0)
+            success, result, new_endpoints = await self._try_endpoint(
+                _endpoint, _cacert)
+            if success:
+                break
+            endpoints.extend(new_endpoints)
+        else:
+            # ran out of endpoints without a successful login
+            raise Exception("Couldn't authenticate to {}".format(
+                self._endpoint))
+
+        response = result['response']
+        self.info = response.copy()
+        self.build_facades(response.get('facades', {}))
+        self.loop.create_task(self.pinger())
+        self.monitor.pinger_stopped.clear()
+
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None):
+            macaroons=None, loop=None, max_frame_size=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
@@ -320,50 +433,28 @@ class Connection:
 
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
 
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
-                     loop)
-        await client.open()
-
-        redirect_info = await client.redirect_info()
-        if not redirect_info:
-            await client.login(username, password, macaroons)
-            return client
-
-        await client.close()
-        servers = [
-            s for servers in redirect_info['servers']
-            for s in servers if s["scope"] == 'public'
-        ]
-        for server in servers:
-            client = cls(
-                "{value}:{port}".format(**server), uuid, username,
-                password, redirect_info['ca-cert'], macaroons)
-            await client.open()
-            try:
-                result = await client.login(username, password, macaroons)
-                if 'discharge-required-error' in result:
-                    continue
-                return client
-            except Exception as e:
-                await client.close()
-                log.exception(e)
-
-        raise Exception(
-            "Couldn't authenticate to %s", endpoint)
+                     loop, max_frame_size)
+        await client._connect()
+        return client
 
     @classmethod
 
     @classmethod
-    async def connect_current(cls, loop=None):
+    async def connect_current(cls, loop=None, max_frame_size=None):
         """Connect to the currently active model.
 
         """
         jujudata = JujuData()
         """Connect to the currently active model.
 
         """
         jujudata = JujuData()
+
         controller_name = jujudata.current_controller()
         controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
         model_name = jujudata.current_model()
 
         return await cls.connect_model(
         model_name = jujudata.current_model()
 
         return await cls.connect_model(
-            '{}:{}'.format(controller_name, model_name), loop)
+            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
 
     @classmethod
 
     @classmethod
-    async def connect_current_controller(cls, loop=None):
+    async def connect_current_controller(cls, loop=None, max_frame_size=None):
         """Connect to the currently active controller.
 
         """
         """Connect to the currently active controller.
 
         """
@@ -372,10 +463,12 @@ class Connection:
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
-        return await cls.connect_controller(controller_name, loop)
+        return await cls.connect_controller(controller_name, loop,
+                                            max_frame_size)
 
     @classmethod
 
     @classmethod
-    async def connect_controller(cls, controller_name, loop=None):
+    async def connect_controller(cls, controller_name, loop=None,
+                                 max_frame_size=None):
         """Connect to a controller by name.
 
         """
         """Connect to a controller by name.
 
         """
@@ -386,13 +479,14 @@ class Connection:
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
         password = accounts.get('password')
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
         password = accounts.get('password')
-        macaroons = get_macaroons() if not password else None
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
 
         return await cls.connect(
-            endpoint, None, username, password, cacert, macaroons, loop)
+            endpoint, None, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
     @classmethod
 
     @classmethod
-    async def connect_model(cls, model, loop=None):
+    async def connect_model(cls, model, loop=None, max_frame_size=None):
         """Connect to a model by name.
 
         :param str model: [<controller>:]<model>
         """Connect to a model by name.
 
         :param str model: [<controller>:]<model>
@@ -420,35 +514,19 @@ class Connection:
         password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
         password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
-        macaroons = get_macaroons() if not password else None
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert, macaroons, loop)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
     def build_facades(self, facades):
         self.facades.clear()
 
     def build_facades(self, facades):
         self.facades.clear()
-        # In order to work around an issue where the juju api is not
-        # returning a complete list of facades, we simply look up the
-        # juju version in a pregenerated map, and use that info to
-        # populate our list of facades.
-
-        # TODO: if a future version of juju fixes this bug, restore
-        # the following code for that version and higher:
-        # for facade in facades:
-        #     self.facades[facade['name']] = facade['versions'][-1]
-        try:
-            self.facades = VERSION_MAP[self.info['server-version']]
-        except KeyError:
-            log.warning("Could not find a set of facades for {}. Using "
-                        "the latest facade set instead".format(
-                            self.info['server-version']))
-            self.facades = VERSION_MAP['latest']
-
-    async def login(self, username, password, macaroons=None):
-        if macaroons:
-            username = ''
-            password = ''
+        for facade in facades:
+            self.facades[facade['name']] = facade['versions'][-1]
 
 
+    async def login(self):
+        username = self.username
         if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
         if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
@@ -458,17 +536,11 @@ class Connection:
             "version": 3,
             "params": {
                 "auth-tag": username,
             "version": 3,
             "params": {
                 "auth-tag": username,
-                "credentials": password,
+                "credentials": self.password,
                 "nonce": "".join(random.sample(string.printable, 12)),
                 "nonce": "".join(random.sample(string.printable, 12)),
-                "macaroons": macaroons or []
+                "macaroons": self.macaroons
             }})
             }})
-        response = result['response']
-        self.info = response.copy()
-        self.build_facades(response.get('facades', {}))
-        # Create a pinger to keep the connection alive (needed for
-        # JaaS; harmless elsewhere).
-        self.loop.create_task(self.pinger())
-        return response
+        return result
 
     async def redirect_info(self):
         try:
 
     async def redirect_info(self):
         try:
@@ -518,16 +590,26 @@ class JujuData:
             return yaml.safe_load(f)[key]
 
 
             return yaml.safe_load(f)[key]
 
 
-def get_macaroons():
+def get_macaroons(controller_name=None):
     """Decode and return macaroons from default ~/.go-cookies
 
     """
     """Decode and return macaroons from default ~/.go-cookies
 
     """
-    try:
-        cookie_file = os.path.expanduser('~/.go-cookies')
-        with open(cookie_file, 'r') as f:
-            cookies = json.load(f)
-    except (OSError, ValueError):
-        log.warn("Couldn't load macaroons from %s", cookie_file)
+    cookie_files = []
+    if controller_name:
+        cookie_files.append('~/.local/share/juju/cookies/{}.json'.format(
+            controller_name))
+    cookie_files.append('~/.go-cookies')
+    for cookie_file in cookie_files:
+        cookie_file = Path(cookie_file).expanduser()
+        if cookie_file.exists():
+            try:
+                cookies = json.loads(cookie_file.read_text())
+                break
+            except (OSError, ValueError):
+                log.warn("Couldn't load macaroons from %s", cookie_file)
+                return []
+    else:
+        log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files))
         return []
 
     base64_macaroons = [
         return []
 
     base64_macaroons = [