Update README to point to RTD for docs instead of PythonHosted.org
[osm/N2VC.git] / juju / client / connection.py
index 6c31ab6..7457391 100644 (file)
@@ -8,13 +8,17 @@ import shlex
 import ssl
 import string
 import subprocess
+import weakref
 import websockets
+from concurrent.futures import CancelledError
 from http.client import HTTPSConnection
+from pathlib import Path
 
 import asyncio
 import yaml
 
-from juju import tag
+from juju import tag, utils
+from juju.client import client
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
 from juju.utils import IdQueue
 
@@ -37,12 +41,17 @@ class Monitor:
     """
     ERROR = 'error'
     CONNECTED = 'connected'
+    DISCONNECTING = 'disconnecting'
     DISCONNECTED = 'disconnected'
-    UNKNOWN = 'unknown'
 
     def __init__(self, connection):
-        self.connection = connection
-        self.receiver = None
+        self.connection = weakref.ref(connection)
+        self.reconnecting = asyncio.Lock(loop=connection.loop)
+        self.close_called = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped = asyncio.Event(loop=connection.loop)
+        self.pinger_stopped = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped.set()
+        self.pinger_stopped.set()
 
     @property
     def status(self):
@@ -56,46 +65,27 @@ class Monitor:
         isn't usable until that receiver has been started.
 
         """
+        connection = self.connection()
 
-        # DISCONNECTED: connection not yet open
-        if not self.connection.ws:
+        # the connection instance was destroyed but someone kept
+        # a separate reference to the monitor for some reason
+        if not connection:
             return self.DISCONNECTED
-        if not self.receiver:
-            return self.DISCONNECTED
-
-        # ERROR: Connection closed (or errored), but we didn't call
-        # connection.close
-        if not self.connection.close_called and self.receiver_exceptions():
-            return self.ERROR
-        if not self.connection.close_called and not self.connection.ws.open:
-            # The check for self.receiver existing above guards against the
-            # case where we're not open because we simply haven't
-            # setup the connection yet.
-            return self.ERROR
 
-        # DISCONNECTED: cleanly disconnected.
-        if self.connection.close_called and not self.connection.ws.open:
+        # connection cleanly disconnected or not yet opened
+        if not connection.ws:
             return self.DISCONNECTED
 
-        # CONNECTED: everything is fine!
-        if self.connection.ws.open:
-            return self.CONNECTED
+        # close called but not yet complete
+        if self.close_called.is_set():
+            return self.DISCONNECTING
 
-        # UNKNOWN: We should never hit this state -- if we do,
-        # something went wrong with the logic above, and we do not
-        # know what state the connection is in.
-        return self.UNKNOWN
+        # connection closed uncleanly (we didn't call connection.close)
+        if self.receiver_stopped.is_set() or not connection.ws.open:
+            return self.ERROR
 
-    def receiver_exceptions(self):
-        """
-        Return exceptions in the receiver, if any.
-
-        """
-        if not self.receiver:
-            return None
-        if not self.receiver.done():
-            return None
-        return self.receiver.exception()
+        # everything is fine!
+        return self.CONNECTED
 
 
 class Connection:
@@ -115,15 +105,27 @@ class Connection:
     Note: Any connection method or constructor can accept an optional `loop`
     argument to override the default event loop from `asyncio.get_event_loop`.
     """
+
+    DEFAULT_FRAME_SIZE = 'default_frame_size'
+    MAX_FRAME_SIZE = 2**22
+    "Maximum size for a single frame.  Defaults to 4MB."
+
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None):
+            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
         self.endpoint = endpoint
+        self._endpoint = endpoint
         self.uuid = uuid
-        self.username = username
-        self.password = password
-        self.macaroons = macaroons
+        if macaroons:
+            self.macaroons = macaroons
+            self.username = ''
+            self.password = ''
+        else:
+            self.macaroons = []
+            self.username = username
+            self.password = password
         self.cacert = cacert
+        self._cacert = cacert
         self.loop = loop or asyncio.get_event_loop()
 
         self.__request_id__ = 0
@@ -131,14 +133,14 @@ class Connection:
         self.ws = None
         self.facades = {}
         self.messages = IdQueue(loop=self.loop)
-        self.close_called = False
         self.monitor = Monitor(connection=self)
+        if max_frame_size is self.DEFAULT_FRAME_SIZE:
+            max_frame_size = self.MAX_FRAME_SIZE
+        self.max_frame_size = max_frame_size
 
     @property
     def is_open(self):
-        if self.ws:
-            return self.ws.open
-        return False
+        return self.monitor.status == Monitor.CONNECTED
 
     def _get_ssl(self, cert=None):
         return ssl.create_default_context(
@@ -153,15 +155,23 @@ class Connection:
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
         kw['loop'] = self.loop
+        kw['max_size'] = self.max_frame_size
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
-        self.monitor.receiver = self.loop.create_task(self.receiver())
+        self.loop.create_task(self.receiver())
+        self.monitor.receiver_stopped.clear()
         log.info("Driver connected to juju %s", url)
+        self.monitor.close_called.clear()
         return self
 
     async def close(self):
-        self.close_called = True
+        if not self.ws:
+            return
+        self.monitor.close_called.set()
+        await self.monitor.pinger_stopped.wait()
+        await self.monitor.receiver_stopped.wait()
         await self.ws.close()
+        self.ws = None
 
     async def recv(self, request_id):
         if not self.is_open:
@@ -169,19 +179,62 @@ class Connection:
         return await self.messages.get(request_id)
 
     async def receiver(self):
-        while self.is_open:
-            try:
-                result = await self.ws.recv()
+        try:
+            while self.is_open:
+                result = await utils.run_with_interrupt(
+                    self.ws.recv(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
                 if result is not None:
                     result = json.loads(result)
                     await self.messages.put(result['request-id'], result)
-            except Exception as e:
-                await self.messages.put_all(e)
-                if isinstance(e, websockets.ConnectionClosed):
-                    # ConnectionClosed is not really exceptional for us,
-                    # but it may be for any pending message listeners
-                    return
-                raise
+        except CancelledError:
+            pass
+        except websockets.ConnectionClosed as e:
+            log.warning('Receiver: Connection closed, reconnecting')
+            await self.messages.put_all(e)
+            # the reconnect has to be done as a task because the receiver will
+            # be cancelled by the reconnect and we don't want the reconnect
+            # to be aborted half-way through
+            self.loop.create_task(self.reconnect())
+            return
+        except Exception as e:
+            log.exception("Error in receiver")
+            # make pending listeners aware of the error
+            await self.messages.put_all(e)
+            raise
+        finally:
+            self.monitor.receiver_stopped.set()
+
+    async def pinger(self):
+        '''
+        A Controller can time us out if we are silent for too long. This
+        is especially true in JaaS, which has a fairly strict timeout.
+
+        To prevent timing out, we send a ping every ten seconds.
+
+        '''
+        async def _do_ping():
+            try:
+                await pinger_facade.Ping()
+                await asyncio.sleep(10, loop=self.loop)
+            except CancelledError:
+                pass
+
+        pinger_facade = client.PingerFacade.from_connection(self)
+        try:
+            while True:
+                await utils.run_with_interrupt(
+                    _do_ping(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+        finally:
+            self.monitor.pinger_stopped.set()
+            return
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -191,7 +244,19 @@ class Connection:
         if "version" not in msg:
             msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
-        await self.ws.send(outgoing)
+        for attempt in range(3):
+            try:
+                await self.ws.send(outgoing)
+                break
+            except websockets.ConnectionClosed:
+                if attempt == 2:
+                    raise
+                log.warning('RPC: Connection closed, reconnecting')
+                # the reconnect has to be done in a separate task because,
+                # if it is triggered by the pinger, then this RPC call will
+                # be cancelled when the pinger is cancelled by the reconnect,
+                # and we don't want the reconnect to be aborted halfway through
+                await asyncio.wait([self.reconnect()], loop=self.loop)
         result = await self.recv(msg['request-id'])
 
         if not result:
@@ -278,6 +343,7 @@ class Connection:
             self.cacert,
             self.macaroons,
             self.loop,
+            self.max_frame_size,
         )
 
     async def controller(self):
@@ -294,10 +360,72 @@ class Connection:
             self.loop,
         )
 
+    async def _try_endpoint(self, endpoint, cacert):
+        success = False
+        result = None
+        new_endpoints = []
+
+        self.endpoint = endpoint
+        self.cacert = cacert
+        await self.open()
+        try:
+            result = await self.login()
+            if 'discharge-required-error' in result['response']:
+                log.info('Macaroon discharge required, disconnecting')
+            else:
+                # successful login!
+                log.info('Authenticated')
+                success = True
+        except JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            redirect_info = await self.redirect_info()
+            redir_cacert = redirect_info['ca-cert']
+            new_endpoints = [
+                ("{value}:{port}".format(**s), redir_cacert)
+                for servers in redirect_info['servers']
+                for s in servers if s["scope"] == 'public'
+            ]
+        finally:
+            if not success:
+                await self.close()
+        return success, result, new_endpoints
+
+    async def reconnect(self):
+        """ Force a reconnection.
+        """
+        monitor = self.monitor
+        if monitor.reconnecting.locked() or monitor.close_called.is_set():
+            return
+        async with monitor.reconnecting:
+            await self.close()
+            await self._connect()
+
+    async def _connect(self):
+        endpoints = [(self._endpoint, self._cacert)]
+        while endpoints:
+            _endpoint, _cacert = endpoints.pop(0)
+            success, result, new_endpoints = await self._try_endpoint(
+                _endpoint, _cacert)
+            if success:
+                break
+            endpoints.extend(new_endpoints)
+        else:
+            # ran out of endpoints without a successful login
+            raise Exception("Couldn't authenticate to {}".format(
+                self._endpoint))
+
+        response = result['response']
+        self.info = response.copy()
+        self.build_facades(response.get('facades', {}))
+        self.loop.create_task(self.pinger())
+        self.monitor.pinger_stopped.clear()
+
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None):
+            macaroons=None, loop=None, max_frame_size=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
@@ -305,50 +433,28 @@ class Connection:
 
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
-                     loop)
-        await client.open()
-
-        redirect_info = await client.redirect_info()
-        if not redirect_info:
-            await client.login(username, password, macaroons)
-            return client
-
-        await client.close()
-        servers = [
-            s for servers in redirect_info['servers']
-            for s in servers if s["scope"] == 'public'
-        ]
-        for server in servers:
-            client = cls(
-                "{value}:{port}".format(**server), uuid, username,
-                password, redirect_info['ca-cert'], macaroons)
-            await client.open()
-            try:
-                result = await client.login(username, password, macaroons)
-                if 'discharge-required-error' in result:
-                    continue
-                return client
-            except Exception as e:
-                await client.close()
-                log.exception(e)
-
-        raise Exception(
-            "Couldn't authenticate to %s", endpoint)
+                     loop, max_frame_size)
+        await client._connect()
+        return client
 
     @classmethod
-    async def connect_current(cls, loop=None):
+    async def connect_current(cls, loop=None, max_frame_size=None):
         """Connect to the currently active model.
 
         """
         jujudata = JujuData()
+
         controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
         model_name = jujudata.current_model()
 
         return await cls.connect_model(
-            '{}:{}'.format(controller_name, model_name), loop)
+            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
 
     @classmethod
-    async def connect_current_controller(cls, loop=None):
+    async def connect_current_controller(cls, loop=None, max_frame_size=None):
         """Connect to the currently active controller.
 
         """
@@ -357,10 +463,12 @@ class Connection:
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
-        return await cls.connect_controller(controller_name, loop)
+        return await cls.connect_controller(controller_name, loop,
+                                            max_frame_size)
 
     @classmethod
-    async def connect_controller(cls, controller_name, loop=None):
+    async def connect_controller(cls, controller_name, loop=None,
+                                 max_frame_size=None):
         """Connect to a controller by name.
 
         """
@@ -371,13 +479,14 @@ class Connection:
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
         password = accounts.get('password')
-        macaroons = get_macaroons() if not password else None
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
-            endpoint, None, username, password, cacert, macaroons, loop)
+            endpoint, None, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
     @classmethod
-    async def connect_model(cls, model, loop=None):
+    async def connect_model(cls, model, loop=None, max_frame_size=None):
         """Connect to a model by name.
 
         :param str model: [<controller>:]<model>
@@ -405,21 +514,19 @@ class Connection:
         password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
-        macaroons = get_macaroons() if not password else None
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert, macaroons, loop)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
-    def build_facades(self, info):
+    def build_facades(self, facades):
         self.facades.clear()
-        for facade in info:
+        for facade in facades:
             self.facades[facade['name']] = facade['versions'][-1]
 
-    async def login(self, username, password, macaroons=None):
-        if macaroons:
-            username = ''
-            password = ''
-
+    async def login(self):
+        username = self.username
         if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
@@ -429,14 +536,11 @@ class Connection:
             "version": 3,
             "params": {
                 "auth-tag": username,
-                "credentials": password,
+                "credentials": self.password,
                 "nonce": "".join(random.sample(string.printable, 12)),
-                "macaroons": macaroons or []
+                "macaroons": self.macaroons
             }})
-        response = result['response']
-        self.build_facades(response.get('facades', {}))
-        self.info = response.copy()
-        return response
+        return result
 
     async def redirect_info(self):
         try:
@@ -486,16 +590,26 @@ class JujuData:
             return yaml.safe_load(f)[key]
 
 
-def get_macaroons():
+def get_macaroons(controller_name=None):
     """Decode and return macaroons from default ~/.go-cookies
 
     """
-    try:
-        cookie_file = os.path.expanduser('~/.go-cookies')
-        with open(cookie_file, 'r') as f:
-            cookies = json.load(f)
-    except (OSError, ValueError):
-        log.warn("Couldn't load macaroons from %s", cookie_file)
+    cookie_files = []
+    if controller_name:
+        cookie_files.append('~/.local/share/juju/cookies/{}.json'.format(
+            controller_name))
+    cookie_files.append('~/.go-cookies')
+    for cookie_file in cookie_files:
+        cookie_file = Path(cookie_file).expanduser()
+        if cookie_file.exists():
+            try:
+                cookies = json.loads(cookie_file.read_text())
+                break
+            except (OSError, ValueError):
+                log.warn("Couldn't load macaroons from %s", cookie_file)
+                return []
+    else:
+        log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files))
         return []
 
     base64_macaroons = [