Fix issue where we do not check to make sure that we are receiving the correct response.
[osm/N2VC.git] / juju / client / connection.py
index 1a9e8e9..486a57f 100644 (file)
@@ -1,15 +1,22 @@
-import asyncio
+import base64
 import io
 import json
 import logging
 import os
 import random
+import shlex
 import ssl
 import string
+import subprocess
 import websockets
+from http.client import HTTPSConnection
 
+import asyncio
 import yaml
 
+from juju import tag
+from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+
 log = logging.getLogger("websocket")
 
 
@@ -27,120 +34,337 @@ class Connection:
         # Connect to the currently active model
         client = await Connection.connect_current()
 
+    Note: Any connection method or constructor can accept an optional `loop`
+    argument to override the default event loop from `asyncio.get_event_loop`.
     """
-    def __init__(self):
+    def __init__(
+            self, endpoint, uuid, username, password, cacert=None,
+            macaroons=None, loop=None):
+        self.endpoint = endpoint
+        self.uuid = uuid
+        self.username = username
+        self.password = password
+        self.macaroons = macaroons
+        self.cacert = cacert
+        self.loop = loop or asyncio.get_event_loop()
+
         self.__request_id__ = 0
         self.addr = None
         self.ws = None
         self.facades = {}
+        self.messages = {}
+
+    @property
+    def is_open(self):
+        if self.ws:
+            return self.ws.open
+        return False
 
-    def _get_ssl(self, cert):
+    def _get_ssl(self, cert=None):
         return ssl.create_default_context(
             purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
 
-    async def open(self, addr, cert=None):
+    async def open(self):
+        if self.uuid:
+            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
+        else:
+            url = "wss://{}/api".format(self.endpoint)
+
         kw = dict()
-        if cert:
-            kw['ssl'] = self._get_ssl(cert)
-        self.addr = addr
-        self.ws = await websockets.connect(addr, **kw)
+        kw['ssl'] = self._get_ssl(self.cacert)
+        kw['loop'] = self.loop
+        self.addr = url
+        self.ws = await websockets.connect(url, **kw)
+        log.info("Driver connected to juju %s", url)
         return self
 
     async def close(self):
         await self.ws.close()
 
-    async def recv(self):
-        result = await self.ws.recv()
-        if result is not None:
-            result = json.loads(result)
+    async def recv(self, request_id):
+        while not self.messages.get(request_id):
+            await asyncio.sleep(0)
+
+        result = self.messages[request_id]
+
+        del self.messages[request_id]
         return result
 
+    async def receiver(self):
+        while self.is_open:
+            result = await self.ws.recv()
+            if result is not None:
+                result = json.loads(result)
+                self.messages[result['request-id']] = result
+
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
-        msg['RequestId'] = self.__request_id__
-        if'Params' not in msg:
-            msg['Params'] = {}
-        if "Version" not in msg:
-            msg['Version'] = self.facades[msg['Type']]
+        msg['request-id'] = self.__request_id__
+        if'params' not in msg:
+            msg['params'] = {}
+        if "version" not in msg:
+            msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
         await self.ws.send(outgoing)
-        result = await self.recv()
-        log.debug("send %s got %s", msg, result)
-        if result and 'Error' in result:
-            raise RuntimeError(result)
+        result = await self.recv(msg['request-id'])
+
+        if not result:
+            return result
+
+        if 'error' in result:
+            # API Error Response
+            raise JujuAPIError(result)
+
+        if not 'response' in result:
+            # This may never happen
+            return result
+
+        if 'results' in result['response']:
+            # Check for errors in a result list.
+            errors = []
+            for res in result['response']['results']:
+                if res.get('error', {}).get('message'):
+                    errors.append(res['error']['message'])
+            if errors:
+                raise JujuError(errors)
+
+        elif result['response'].get('error', {}).get('message'):
+            raise JujuError(result['response']['error']['message'])
+
         return result
 
+    def http_headers(self):
+        """Return dictionary of http headers necessary for making an http
+        connection to the endpoint of this Connection.
+
+        :return: Dictionary of headers
+
+        """
+        if not self.username:
+            return {}
+
+        creds = u'{}:{}'.format(
+            tag.user(self.username),
+            self.password or ''
+        )
+        token = base64.b64encode(creds.encode())
+        return {
+            'Authorization': 'Basic {}'.format(token.decode())
+        }
+
+    def https_connection(self):
+        """Return an https connection to this Connection's endpoint.
+
+        Returns a 3-tuple containing::
+
+            1. The :class:`HTTPSConnection` instance
+            2. Dictionary of auth headers to be used with the connection
+            3. The root url path (str) to be used for requests.
+
+        """
+        endpoint = self.endpoint
+        host, remainder = endpoint.split(':', 1)
+        port = remainder
+        if '/' in remainder:
+            port, _ = remainder.split('/', 1)
+
+        conn = HTTPSConnection(
+            host, int(port),
+            context=self._get_ssl(self.cacert),
+        )
+
+        path = (
+            "/model/{}".format(self.uuid)
+            if self.uuid else ""
+        )
+        return conn, self.http_headers(), path
+
+    async def clone(self):
+        """Return a new Connection, connected to the same websocket endpoint
+        as this one.
+
+        """
+        return await Connection.connect(
+            self.endpoint,
+            self.uuid,
+            self.username,
+            self.password,
+            self.cacert,
+            self.macaroons,
+            self.loop,
+        )
+
+    async def controller(self):
+        """Return a Connection to the controller at self.endpoint
+
+        """
+        return await Connection.connect(
+            self.endpoint,
+            None,
+            self.username,
+            self.password,
+            self.cacert,
+            self.macaroons,
+            self.loop,
+        )
+
     @classmethod
-    async def connect(cls, endpoint, uuid, username, password, cacert=None):
-        url = "wss://{}/model/{}/api".format(endpoint, uuid)
-        client = cls()
-        await client.open(url, cacert)
-        server_info = await client.login(username, password)
-        client.build_facades(server_info['facades'])
-        log.info("Driver connected to juju %s", endpoint)
-        return client
+    async def connect(
+            cls, endpoint, uuid, username, password, cacert=None,
+            macaroons=None, loop=None):
+        """Connect to the websocket.
+
+        If uuid is None, the connection will be to the controller. Otherwise it
+        will be to the model.
+
+        """
+        client = cls(endpoint, uuid, username, password, cacert, macaroons,
+                     loop)
+        await client.open()
+        self.loop.create_task(self.receiver)
+
+        redirect_info = await client.redirect_info()
+        if not redirect_info:
+            await client.login(username, password, macaroons)
+            return client
+
+        await client.close()
+        servers = [
+            s for servers in redirect_info['servers']
+            for s in servers if s["scope"] == 'public'
+        ]
+        for server in servers:
+            client = cls(
+                "{value}:{port}".format(**server), uuid, username,
+                password, redirect_info['ca-cert'], macaroons)
+            await client.open()
+            try:
+                result = await client.login(username, password, macaroons)
+                if 'discharge-required-error' in result:
+                    continue
+                return client
+            except Exception as e:
+                await client.close()
+                log.exception(e)
+
+        raise Exception(
+            "Couldn't authenticate to %s", endpoint)
 
     @classmethod
-    async def connect_current(cls):
+    async def connect_current(cls, loop=None):
         """Connect to the currently active model.
 
         """
         jujudata = JujuData()
         controller_name = jujudata.current_controller()
+        model_name = jujudata.current_model()
+
+        return await cls.connect_model(
+            '{}:{}'.format(controller_name, model_name), loop)
+
+    @classmethod
+    async def connect_current_controller(cls, loop=None):
+        """Connect to the currently active controller.
+
+        """
+        jujudata = JujuData()
+        controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
+        return await cls.connect_controller(controller_name, loop)
+
+    @classmethod
+    async def connect_controller(cls, controller_name, loop=None):
+        """Connect to a controller by name.
+
+        """
+        jujudata = JujuData()
         controller = jujudata.controllers()[controller_name]
         endpoint = controller['api-endpoints'][0]
         cacert = controller.get('ca-cert')
         accounts = jujudata.accounts()[controller_name]
-        username = accounts['current-account']
-        password = accounts['accounts'][username]['password']
-        models = jujudata.models()[controller_name]['accounts'][username]
-        model_name = models['current-model']
-        model_uuid = models['models'][model_name]['uuid']
+        username = accounts['user']
+        password = accounts.get('password')
+        macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert)
+            endpoint, None, username, password, cacert, macaroons, loop)
 
     @classmethod
-    async def connect_model(cls, model):
+    async def connect_model(cls, model, loop=None):
         """Connect to a model by name.
 
-        :param str model: <controller>:<model>
+        :param str model: [<controller>:]<model>
 
         """
-        controller_name, model_name = model.split(':')
-
         jujudata = JujuData()
+
+        if ':' in model:
+            # explicit controller given
+            controller_name, model_name = model.split(':')
+        else:
+            # use the current controller if one isn't explicitly given
+            controller_name = jujudata.current_controller()
+            model_name = model
+
+        accounts = jujudata.accounts()[controller_name]
+        username = accounts['user']
+        # model name must include a user prefix, so add it if it doesn't
+        if '/' not in model_name:
+            model_name = '{}/{}'.format(username, model_name)
+
         controller = jujudata.controllers()[controller_name]
         endpoint = controller['api-endpoints'][0]
         cacert = controller.get('ca-cert')
-        accounts = jujudata.accounts()[controller_name]
-        username = accounts['current-account']
-        password = accounts['accounts'][username]['password']
-        models = jujudata.models()[controller_name]['accounts'][username]
+        password = accounts.get('password')
+        models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
+        macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop)
 
     def build_facades(self, info):
         self.facades.clear()
         for facade in info:
-            self.facades[facade['Name']] = facade['Versions'][-1]
+            self.facades[facade['name']] = facade['versions'][-1]
+
+    async def login(self, username, password, macaroons=None):
+        if macaroons:
+            username = ''
+            password = ''
 
-    async def login(self, username, password):
-        if not username.startswith('user-'):
+        if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
         result = await self.rpc({
-            "Type": "Admin",
-            "Request": "Login",
-            "Version": 3,
-            "Params": {
+            "type": "Admin",
+            "request": "Login",
+            "version": 3,
+            "params": {
                 "auth-tag": username,
                 "credentials": password,
-                "Nonce": "".join(random.sample(string.printable, 12)),
+                "nonce": "".join(random.sample(string.printable, 12)),
+                "macaroons": macaroons or []
             }})
-        return result['Response']
+        response = result['response']
+        self.build_facades(response.get('facades', {}))
+        self.info = response.copy()
+        return response
+
+    async def redirect_info(self):
+        try:
+            result = await self.rpc({
+                "type": "Admin",
+                "request": "RedirectInfo",
+                "version": 3,
+            })
+        except JujuAPIError as e:
+            if e.message == 'not redirected':
+                return None
+            raise
+        return result['response']
 
 
 class JujuData:
@@ -149,13 +373,18 @@ class JujuData:
         self.path = os.path.abspath(os.path.expanduser(self.path))
 
     def current_controller(self):
-        try:
-            filepath = os.path.join(self.path, 'current-controller')
-            with io.open(filepath, 'rt') as f:
-                return f.read().strip()
-        except OSError as e:
-            log.exception(e)
-            return None
+        cmd = shlex.split('juju list-controllers --format yaml')
+        output = subprocess.check_output(cmd)
+        output = yaml.safe_load(output)
+        return output.get('current-controller', '')
+
+    def current_model(self, controller_name=None):
+        if not controller_name:
+            controller_name = self.current_controller()
+        models = self.models()[controller_name]
+        if 'current-model' not in models:
+            raise JujuError('No current model')
+        return models['current-model']
 
     def controllers(self):
         return self._load_yaml('controllers.yaml', 'controllers')
@@ -170,3 +399,26 @@ class JujuData:
         filepath = os.path.join(self.path, filename)
         with io.open(filepath, 'rt') as f:
             return yaml.safe_load(f)[key]
+
+
+def get_macaroons():
+    """Decode and return macaroons from default ~/.go-cookies
+
+    """
+    try:
+        cookie_file = os.path.expanduser('~/.go-cookies')
+        with open(cookie_file, 'r') as f:
+            cookies = json.load(f)
+    except (OSError, ValueError):
+        log.warn("Couldn't load macaroons from %s", cookie_file)
+        return []
+
+    base64_macaroons = [
+        c['Value'] for c in cookies
+        if c['Name'].startswith('macaroon-') and c['Value']
+    ]
+
+    return [
+        json.loads(base64.b64decode(value).decode('utf-8'))
+        for value in base64_macaroons
+    ]