Update README to point to RTD for docs instead of PythonHosted.org
[osm/N2VC.git] / juju / client / connection.py
index 69ac425..7457391 100644 (file)
@@ -1,4 +1,4 @@
-import asyncio
+import base64
 import io
 import json
 import logging
 import io
 import json
 import logging
@@ -8,13 +8,86 @@ import shlex
 import ssl
 import string
 import subprocess
 import ssl
 import string
 import subprocess
+import weakref
 import websockets
 import websockets
+from concurrent.futures import CancelledError
+from http.client import HTTPSConnection
+from pathlib import Path
 
 
+import asyncio
 import yaml
 
 import yaml
 
+from juju import tag, utils
+from juju.client import client
+from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
+
 log = logging.getLogger("websocket")
 
 
 log = logging.getLogger("websocket")
 
 
+class Monitor:
+    """
+    Monitor helper class for our Connection class.
+
+    Contains a reference to an instantiated Connection, along with a
+    reference to the Connection.receiver Future. Upon inspecttion of
+    these objects, this class determines whether the connection is in
+    an 'error', 'connected' or 'disconnected' state.
+
+    Use this class to stay up to date on the health of a connection,
+    and take appropriate action if the connection errors out due to
+    network issues or other unexpected circumstances.
+
+    """
+    ERROR = 'error'
+    CONNECTED = 'connected'
+    DISCONNECTING = 'disconnecting'
+    DISCONNECTED = 'disconnected'
+
+    def __init__(self, connection):
+        self.connection = weakref.ref(connection)
+        self.reconnecting = asyncio.Lock(loop=connection.loop)
+        self.close_called = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped = asyncio.Event(loop=connection.loop)
+        self.pinger_stopped = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped.set()
+        self.pinger_stopped.set()
+
+    @property
+    def status(self):
+        """
+        Determine the status of the connection and receiver, and return
+        ERROR, CONNECTED, or DISCONNECTED as appropriate.
+
+        For simplicity, we only consider ourselves to be connected
+        after the Connection class has setup a receiver task. This
+        only happens after the websocket is open, and the connection
+        isn't usable until that receiver has been started.
+
+        """
+        connection = self.connection()
+
+        # the connection instance was destroyed but someone kept
+        # a separate reference to the monitor for some reason
+        if not connection:
+            return self.DISCONNECTED
+
+        # connection cleanly disconnected or not yet opened
+        if not connection.ws:
+            return self.DISCONNECTED
+
+        # close called but not yet complete
+        if self.close_called.is_set():
+            return self.DISCONNECTING
+
+        # connection closed uncleanly (we didn't call connection.close)
+        if self.receiver_stopped.is_set() or not connection.ws.open:
+            return self.ERROR
+
+        # everything is fine!
+        return self.CONNECTED
+
+
 class Connection:
     """
     Usage::
 class Connection:
     """
     Usage::
@@ -29,39 +102,139 @@ class Connection:
         # Connect to the currently active model
         client = await Connection.connect_current()
 
         # Connect to the currently active model
         client = await Connection.connect_current()
 
+    Note: Any connection method or constructor can accept an optional `loop`
+    argument to override the default event loop from `asyncio.get_event_loop`.
     """
     """
-    def __init__(self, endpoint, uuid, username, password, cacert=None):
+
+    DEFAULT_FRAME_SIZE = 'default_frame_size'
+    MAX_FRAME_SIZE = 2**22
+    "Maximum size for a single frame.  Defaults to 4MB."
+
+    def __init__(
+            self, endpoint, uuid, username, password, cacert=None,
+            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
         self.endpoint = endpoint
         self.endpoint = endpoint
+        self._endpoint = endpoint
         self.uuid = uuid
         self.uuid = uuid
-        self.username = username
-        self.password = password
+        if macaroons:
+            self.macaroons = macaroons
+            self.username = ''
+            self.password = ''
+        else:
+            self.macaroons = []
+            self.username = username
+            self.password = password
         self.cacert = cacert
         self.cacert = cacert
+        self._cacert = cacert
+        self.loop = loop or asyncio.get_event_loop()
 
         self.__request_id__ = 0
         self.addr = None
         self.ws = None
         self.facades = {}
 
         self.__request_id__ = 0
         self.addr = None
         self.ws = None
         self.facades = {}
+        self.messages = IdQueue(loop=self.loop)
+        self.monitor = Monitor(connection=self)
+        if max_frame_size is self.DEFAULT_FRAME_SIZE:
+            max_frame_size = self.MAX_FRAME_SIZE
+        self.max_frame_size = max_frame_size
 
 
-    def _get_ssl(self, cert):
+    @property
+    def is_open(self):
+        return self.monitor.status == Monitor.CONNECTED
+
+    def _get_ssl(self, cert=None):
         return ssl.create_default_context(
             purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
 
         return ssl.create_default_context(
             purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
 
-    async def open(self, addr, cert=None):
+    async def open(self):
+        if self.uuid:
+            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
+        else:
+            url = "wss://{}/api".format(self.endpoint)
+
         kw = dict()
         kw = dict()
-        if cert:
-            kw['ssl'] = self._get_ssl(cert)
-        self.addr = addr
-        self.ws = await websockets.connect(addr, **kw)
+        kw['ssl'] = self._get_ssl(self.cacert)
+        kw['loop'] = self.loop
+        kw['max_size'] = self.max_frame_size
+        self.addr = url
+        self.ws = await websockets.connect(url, **kw)
+        self.loop.create_task(self.receiver())
+        self.monitor.receiver_stopped.clear()
+        log.info("Driver connected to juju %s", url)
+        self.monitor.close_called.clear()
         return self
 
     async def close(self):
         return self
 
     async def close(self):
+        if not self.ws:
+            return
+        self.monitor.close_called.set()
+        await self.monitor.pinger_stopped.wait()
+        await self.monitor.receiver_stopped.wait()
         await self.ws.close()
         await self.ws.close()
+        self.ws = None
 
 
-    async def recv(self):
-        result = await self.ws.recv()
-        if result is not None:
-            result = json.loads(result)
-        return result
+    async def recv(self, request_id):
+        if not self.is_open:
+            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+        return await self.messages.get(request_id)
+
+    async def receiver(self):
+        try:
+            while self.is_open:
+                result = await utils.run_with_interrupt(
+                    self.ws.recv(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+                if result is not None:
+                    result = json.loads(result)
+                    await self.messages.put(result['request-id'], result)
+        except CancelledError:
+            pass
+        except websockets.ConnectionClosed as e:
+            log.warning('Receiver: Connection closed, reconnecting')
+            await self.messages.put_all(e)
+            # the reconnect has to be done as a task because the receiver will
+            # be cancelled by the reconnect and we don't want the reconnect
+            # to be aborted half-way through
+            self.loop.create_task(self.reconnect())
+            return
+        except Exception as e:
+            log.exception("Error in receiver")
+            # make pending listeners aware of the error
+            await self.messages.put_all(e)
+            raise
+        finally:
+            self.monitor.receiver_stopped.set()
+
+    async def pinger(self):
+        '''
+        A Controller can time us out if we are silent for too long. This
+        is especially true in JaaS, which has a fairly strict timeout.
+
+        To prevent timing out, we send a ping every ten seconds.
+
+        '''
+        async def _do_ping():
+            try:
+                await pinger_facade.Ping()
+                await asyncio.sleep(10, loop=self.loop)
+            except CancelledError:
+                pass
+
+        pinger_facade = client.PingerFacade.from_connection(self)
+        try:
+            while True:
+                await utils.run_with_interrupt(
+                    _do_ping(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+        finally:
+            self.monitor.pinger_stopped.set()
+            return
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -71,14 +244,92 @@ class Connection:
         if "version" not in msg:
             msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
         if "version" not in msg:
             msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
-        await self.ws.send(outgoing)
-        result = await self.recv()
-        #log.debug("Send: %s", outgoing)
-        #log.debug("Recv: %s", result)
-        if result and 'error' in result:
-            raise RuntimeError(result)
+        for attempt in range(3):
+            try:
+                await self.ws.send(outgoing)
+                break
+            except websockets.ConnectionClosed:
+                if attempt == 2:
+                    raise
+                log.warning('RPC: Connection closed, reconnecting')
+                # the reconnect has to be done in a separate task because,
+                # if it is triggered by the pinger, then this RPC call will
+                # be cancelled when the pinger is cancelled by the reconnect,
+                # and we don't want the reconnect to be aborted halfway through
+                await asyncio.wait([self.reconnect()], loop=self.loop)
+        result = await self.recv(msg['request-id'])
+
+        if not result:
+            return result
+
+        if 'error' in result:
+            # API Error Response
+            raise JujuAPIError(result)
+
+        if 'response' not in result:
+            # This may never happen
+            return result
+
+        if 'results' in result['response']:
+            # Check for errors in a result list.
+            errors = []
+            for res in result['response']['results']:
+                if res.get('error', {}).get('message'):
+                    errors.append(res['error']['message'])
+            if errors:
+                raise JujuError(errors)
+
+        elif result['response'].get('error', {}).get('message'):
+            raise JujuError(result['response']['error']['message'])
+
         return result
 
         return result
 
+    def http_headers(self):
+        """Return dictionary of http headers necessary for making an http
+        connection to the endpoint of this Connection.
+
+        :return: Dictionary of headers
+
+        """
+        if not self.username:
+            return {}
+
+        creds = u'{}:{}'.format(
+            tag.user(self.username),
+            self.password or ''
+        )
+        token = base64.b64encode(creds.encode())
+        return {
+            'Authorization': 'Basic {}'.format(token.decode())
+        }
+
+    def https_connection(self):
+        """Return an https connection to this Connection's endpoint.
+
+        Returns a 3-tuple containing::
+
+            1. The :class:`HTTPSConnection` instance
+            2. Dictionary of auth headers to be used with the connection
+            3. The root url path (str) to be used for requests.
+
+        """
+        endpoint = self.endpoint
+        host, remainder = endpoint.split(':', 1)
+        port = remainder
+        if '/' in remainder:
+            port, _ = remainder.split('/', 1)
+
+        conn = HTTPSConnection(
+            host, int(port),
+            context=self._get_ssl(self.cacert),
+        )
+
+        path = (
+            "/model/{}".format(self.uuid)
+            if self.uuid else ""
+        )
+        return conn, self.http_headers(), path
+
     async def clone(self):
         """Return a new Connection, connected to the same websocket endpoint
         as this one.
     async def clone(self):
         """Return a new Connection, connected to the same websocket endpoint
         as this one.
@@ -90,6 +341,9 @@ class Connection:
             self.username,
             self.password,
             self.cacert,
             self.username,
             self.password,
             self.cacert,
+            self.macaroons,
+            self.loop,
+            self.max_frame_size,
         )
 
     async def controller(self):
         )
 
     async def controller(self):
@@ -102,77 +356,178 @@ class Connection:
             self.username,
             self.password,
             self.cacert,
             self.username,
             self.password,
             self.cacert,
+            self.macaroons,
+            self.loop,
         )
 
         )
 
+    async def _try_endpoint(self, endpoint, cacert):
+        success = False
+        result = None
+        new_endpoints = []
+
+        self.endpoint = endpoint
+        self.cacert = cacert
+        await self.open()
+        try:
+            result = await self.login()
+            if 'discharge-required-error' in result['response']:
+                log.info('Macaroon discharge required, disconnecting')
+            else:
+                # successful login!
+                log.info('Authenticated')
+                success = True
+        except JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            redirect_info = await self.redirect_info()
+            redir_cacert = redirect_info['ca-cert']
+            new_endpoints = [
+                ("{value}:{port}".format(**s), redir_cacert)
+                for servers in redirect_info['servers']
+                for s in servers if s["scope"] == 'public'
+            ]
+        finally:
+            if not success:
+                await self.close()
+        return success, result, new_endpoints
+
+    async def reconnect(self):
+        """ Force a reconnection.
+        """
+        monitor = self.monitor
+        if monitor.reconnecting.locked() or monitor.close_called.is_set():
+            return
+        async with monitor.reconnecting:
+            await self.close()
+            await self._connect()
+
+    async def _connect(self):
+        endpoints = [(self._endpoint, self._cacert)]
+        while endpoints:
+            _endpoint, _cacert = endpoints.pop(0)
+            success, result, new_endpoints = await self._try_endpoint(
+                _endpoint, _cacert)
+            if success:
+                break
+            endpoints.extend(new_endpoints)
+        else:
+            # ran out of endpoints without a successful login
+            raise Exception("Couldn't authenticate to {}".format(
+                self._endpoint))
+
+        response = result['response']
+        self.info = response.copy()
+        self.build_facades(response.get('facades', {}))
+        self.loop.create_task(self.pinger())
+        self.monitor.pinger_stopped.clear()
+
     @classmethod
     @classmethod
-    async def connect(cls, endpoint, uuid, username, password, cacert=None):
+    async def connect(
+            cls, endpoint, uuid, username, password, cacert=None,
+            macaroons=None, loop=None, max_frame_size=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
-        if uuid:
-            url = "wss://{}/model/{}/api".format(endpoint, uuid)
-        else:
-            url = "wss://{}/api".format(endpoint)
-        client = cls(endpoint, uuid, username, password, cacert)
-        await client.open(url, cacert)
-        server_info = await client.login(username, password)
-        client.build_facades(server_info['facades'])
-        log.info("Driver connected to juju %s", url)
-
+        client = cls(endpoint, uuid, username, password, cacert, macaroons,
+                     loop, max_frame_size)
+        await client._connect()
         return client
 
     @classmethod
         return client
 
     @classmethod
-    async def connect_current(cls):
+    async def connect_current(cls, loop=None, max_frame_size=None):
         """Connect to the currently active model.
 
         """Connect to the currently active model.
 
+        """
+        jujudata = JujuData()
+
+        controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
+        model_name = jujudata.current_model()
+
+        return await cls.connect_model(
+            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
+
+    @classmethod
+    async def connect_current_controller(cls, loop=None, max_frame_size=None):
+        """Connect to the currently active controller.
+
         """
         jujudata = JujuData()
         controller_name = jujudata.current_controller()
         """
         jujudata = JujuData()
         controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
+        return await cls.connect_controller(controller_name, loop,
+                                            max_frame_size)
+
+    @classmethod
+    async def connect_controller(cls, controller_name, loop=None,
+                                 max_frame_size=None):
+        """Connect to a controller by name.
+
+        """
+        jujudata = JujuData()
         controller = jujudata.controllers()[controller_name]
         endpoint = controller['api-endpoints'][0]
         cacert = controller.get('ca-cert')
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
         controller = jujudata.controllers()[controller_name]
         endpoint = controller['api-endpoints'][0]
         cacert = controller.get('ca-cert')
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
-        password = accounts['password']
-        models = jujudata.models()[controller_name]
-        model_name = models['current-model']
-        model_uuid = models['models'][model_name]['uuid']
+        password = accounts.get('password')
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert)
+            endpoint, None, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
     @classmethod
 
     @classmethod
-    async def connect_model(cls, model):
+    async def connect_model(cls, model, loop=None, max_frame_size=None):
         """Connect to a model by name.
 
         """Connect to a model by name.
 
-        :param str model: <controller>:<model>
+        :param str model: [<controller>:]<model>
 
         """
 
         """
-        controller_name, model_name = model.split(':')
-
         jujudata = JujuData()
         jujudata = JujuData()
+
+        if ':' in model:
+            # explicit controller given
+            controller_name, model_name = model.split(':')
+        else:
+            # use the current controller if one isn't explicitly given
+            controller_name = jujudata.current_controller()
+            model_name = model
+
+        accounts = jujudata.accounts()[controller_name]
+        username = accounts['user']
+        # model name must include a user prefix, so add it if it doesn't
+        if '/' not in model_name:
+            model_name = '{}/{}'.format(username, model_name)
+
         controller = jujudata.controllers()[controller_name]
         endpoint = controller['api-endpoints'][0]
         cacert = controller.get('ca-cert')
         controller = jujudata.controllers()[controller_name]
         endpoint = controller['api-endpoints'][0]
         cacert = controller.get('ca-cert')
-        accounts = jujudata.accounts()[controller_name]
-        username = accounts['user']
-        password = accounts['password']
+        password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
 
-    def build_facades(self, info):
+    def build_facades(self, facades):
         self.facades.clear()
         self.facades.clear()
-        for facade in info:
+        for facade in facades:
             self.facades[facade['name']] = facade['versions'][-1]
 
             self.facades[facade['name']] = facade['versions'][-1]
 
-    async def login(self, username, password):
-        if not username.startswith('user-'):
+    async def login(self):
+        username = self.username
+        if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
         result = await self.rpc({
             username = 'user-{}'.format(username)
 
         result = await self.rpc({
@@ -181,9 +536,23 @@ class Connection:
             "version": 3,
             "params": {
                 "auth-tag": username,
             "version": 3,
             "params": {
                 "auth-tag": username,
-                "credentials": password,
+                "credentials": self.password,
                 "nonce": "".join(random.sample(string.printable, 12)),
                 "nonce": "".join(random.sample(string.printable, 12)),
+                "macaroons": self.macaroons
             }})
             }})
+        return result
+
+    async def redirect_info(self):
+        try:
+            result = await self.rpc({
+                "type": "Admin",
+                "request": "RedirectInfo",
+                "version": 3,
+            })
+        except JujuAPIError as e:
+            if e.message == 'not redirected':
+                return None
+            raise
         return result['response']
 
 
         return result['response']
 
 
@@ -193,10 +562,18 @@ class JujuData:
         self.path = os.path.abspath(os.path.expanduser(self.path))
 
     def current_controller(self):
         self.path = os.path.abspath(os.path.expanduser(self.path))
 
     def current_controller(self):
-        cmd = shlex.split('juju show-controller --format yaml')
+        cmd = shlex.split('juju list-controllers --format yaml')
         output = subprocess.check_output(cmd)
         output = yaml.safe_load(output)
         output = subprocess.check_output(cmd)
         output = yaml.safe_load(output)
-        return list(output.keys())[0]
+        return output.get('current-controller', '')
+
+    def current_model(self, controller_name=None):
+        if not controller_name:
+            controller_name = self.current_controller()
+        models = self.models()[controller_name]
+        if 'current-model' not in models:
+            raise JujuError('No current model')
+        return models['current-model']
 
     def controllers(self):
         return self._load_yaml('controllers.yaml', 'controllers')
 
     def controllers(self):
         return self._load_yaml('controllers.yaml', 'controllers')
@@ -211,3 +588,36 @@ class JujuData:
         filepath = os.path.join(self.path, filename)
         with io.open(filepath, 'rt') as f:
             return yaml.safe_load(f)[key]
         filepath = os.path.join(self.path, filename)
         with io.open(filepath, 'rt') as f:
             return yaml.safe_load(f)[key]
+
+
+def get_macaroons(controller_name=None):
+    """Decode and return macaroons from default ~/.go-cookies
+
+    """
+    cookie_files = []
+    if controller_name:
+        cookie_files.append('~/.local/share/juju/cookies/{}.json'.format(
+            controller_name))
+    cookie_files.append('~/.go-cookies')
+    for cookie_file in cookie_files:
+        cookie_file = Path(cookie_file).expanduser()
+        if cookie_file.exists():
+            try:
+                cookies = json.loads(cookie_file.read_text())
+                break
+            except (OSError, ValueError):
+                log.warn("Couldn't load macaroons from %s", cookie_file)
+                return []
+    else:
+        log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files))
+        return []
+
+    base64_macaroons = [
+        c['Value'] for c in cookies
+        if c['Name'].startswith('macaroon-') and c['Value']
+    ]
+
+    return [
+        json.loads(base64.b64decode(value).decode('utf-8'))
+        for value in base64_macaroons
+    ]