Allow auth with macaroons
authorTim Van Steenburgh <tvansteenburgh@gmail.com>
Thu, 17 Nov 2016 14:49:34 +0000 (09:49 -0500)
committerTim Van Steenburgh <tvansteenburgh@gmail.com>
Thu, 17 Nov 2016 14:49:34 +0000 (09:49 -0500)
If there's no password, attempt macaroon auth using macaroons
in ~/.go-cookies.

juju/client/connection.py

index cdd93d9..18111ce 100644 (file)
@@ -1,3 +1,4 @@
+import base64
 import io
 import json
 import logging
@@ -31,11 +32,14 @@ class Connection:
         client = await Connection.connect_current()
 
     """
-    def __init__(self, endpoint, uuid, username, password, cacert=None):
+    def __init__(
+            self, endpoint, uuid, username, password, cacert=None,
+            macaroons=None):
         self.endpoint = endpoint
         self.uuid = uuid
         self.username = username
         self.password = password
+        self.macaroons = macaroons
         self.cacert = cacert
 
         self.__request_id__ = 0
@@ -49,16 +53,21 @@ class Connection:
             return self.ws.open
         return False
 
-    def _get_ssl(self, cert):
+    def _get_ssl(self, cert=None):
         return ssl.create_default_context(
             purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
 
-    async def open(self, addr, cert=None):
+    async def open(self):
+        if self.uuid:
+            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
+        else:
+            url = "wss://{}/api".format(self.endpoint)
+
         kw = dict()
-        if cert:
-            kw['ssl'] = self._get_ssl(cert)
-        self.addr = addr
-        self.ws = await websockets.connect(addr, **kw)
+        kw['ssl'] = self._get_ssl(self.cacert)
+        self.addr = url
+        self.ws = await websockets.connect(url, **kw)
+        log.info("Driver connected to juju %s", url)
         return self
 
     async def close(self):
@@ -80,8 +89,6 @@ class Connection:
         outgoing = json.dumps(msg, indent=2, cls=encoder)
         await self.ws.send(outgoing)
         result = await self.recv()
-        #log.debug("Send: %s", outgoing)
-        #log.debug("Recv: %s", result)
         if result and 'error' in result:
             raise JujuAPIError(result)
         return result
@@ -97,6 +104,7 @@ class Connection:
             self.username,
             self.password,
             self.cacert,
+            self.macaroons,
         )
 
     async def controller(self):
@@ -109,27 +117,50 @@ class Connection:
             self.username,
             self.password,
             self.cacert,
+            self.macaroons,
         )
 
     @classmethod
-    async def connect(cls, endpoint, uuid, username, password, cacert=None):
+    async def connect(
+            cls, endpoint, uuid, username, password, cacert=None,
+            macaroons=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
-        if uuid:
-            url = "wss://{}/model/{}/api".format(endpoint, uuid)
-        else:
-            url = "wss://{}/api".format(endpoint)
-        client = cls(endpoint, uuid, username, password, cacert)
-        await client.open(url, cacert)
-        server_info = await client.login(username, password)
-        client.build_facades(server_info['facades'])
-        log.info("Driver connected to juju %s", url)
-
-        return client
+        client = cls(endpoint, uuid, username, password, cacert, macaroons)
+        await client.open()
+
+        redirect_info = await client.redirect_info()
+        if not redirect_info:
+            server_info = await client.login(username, password, macaroons)
+            client.build_facades(server_info['facades'])
+            return client
+
+        await client.close()
+        servers = [
+            s for servers in redirect_info['servers']
+            for s in servers if s["scope"] == 'public'
+        ]
+        for server in servers:
+            client = cls(
+                "{value}:{port}".format(**server), uuid, username,
+                password, redirect_info['ca-cert'], macaroons)
+            await client.open()
+            try:
+                result = await client.login(username, password, macaroons)
+                if 'discharge-required-error' in result:
+                    continue
+                client.build_facades(result['facades'])
+                return client
+            except Exception as e:
+                await client.close()
+                log.exception(e)
+
+        raise Exception(
+            "Couldn't authenticate to %s", endpoint)
 
     @classmethod
     async def connect_current(cls):
@@ -138,18 +169,11 @@ class Connection:
         """
         jujudata = JujuData()
         controller_name = jujudata.current_controller()
-        controller = jujudata.controllers()[controller_name]
-        endpoint = controller['api-endpoints'][0]
-        cacert = controller.get('ca-cert')
-        accounts = jujudata.accounts()[controller_name]
-        username = accounts['user']
-        password = accounts['password']
         models = jujudata.models()[controller_name]
         model_name = models['current-model']
-        model_uuid = models['models'][model_name]['uuid']
 
-        return await cls.connect(
-            endpoint, model_uuid, username, password, cacert)
+        return await cls.connect_model(
+            '{}:{}'.format(controller_name, model_name))
 
     @classmethod
     async def connect_model(cls, model):
@@ -166,20 +190,25 @@ class Connection:
         cacert = controller.get('ca-cert')
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
-        password = accounts['password']
+        password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
+        macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert)
+            endpoint, model_uuid, username, password, cacert, macaroons)
 
     def build_facades(self, info):
         self.facades.clear()
         for facade in info:
             self.facades[facade['name']] = facade['versions'][-1]
 
-    async def login(self, username, password):
-        if not username.startswith('user-'):
+    async def login(self, username, password, macaroons=None):
+        if macaroons:
+            username = ''
+            password = ''
+
+        if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
         result = await self.rpc({
@@ -190,9 +219,23 @@ class Connection:
                 "auth-tag": username,
                 "credentials": password,
                 "nonce": "".join(random.sample(string.printable, 12)),
+                "macaroons": macaroons or []
             }})
         return result['response']
 
+    async def redirect_info(self):
+        try:
+            result = await self.rpc({
+                "type": "Admin",
+                "request": "RedirectInfo",
+                "version": 3,
+            })
+        except JujuAPIError as e:
+            if e.message == 'not redirected':
+                return None
+            raise
+        return result['response']
+
 
 class JujuData:
     def __init__(self):
@@ -218,3 +261,26 @@ class JujuData:
         filepath = os.path.join(self.path, filename)
         with io.open(filepath, 'rt') as f:
             return yaml.safe_load(f)[key]
+
+
+def get_macaroons():
+    """Decode and return macaroons from default ~/.go-cookies
+
+    """
+    try:
+        cookie_file = os.path.expanduser('~/.go-cookies')
+        with open(cookie_file, 'r') as f:
+            cookies = json.load(f)
+    except (OSError, ValueError) as e:
+        log.warn("Couldn't load macaroons from %s", cookie_file)
+        return []
+
+    base64_macaroons = [
+        c['Value'] for c in cookies
+        if c['Name'].startswith('macaroon-') and c['Value']
+    ]
+
+    return [
+        json.loads(base64.b64decode(value).decode('utf-8'))
+        for value in base64_macaroons
+    ]