Allow auth with macaroons
authorTim Van Steenburgh <tvansteenburgh@gmail.com>
Thu, 17 Nov 2016 14:49:34 +0000 (09:49 -0500)
committerTim Van Steenburgh <tvansteenburgh@gmail.com>
Thu, 17 Nov 2016 14:49:34 +0000 (09:49 -0500)
If there's no password, attempt macaroon auth using macaroons
in ~/.go-cookies.

juju/client/connection.py

index cdd93d9..18111ce 100644 (file)
@@ -1,3 +1,4 @@
+import base64
 import io
 import json
 import logging
 import io
 import json
 import logging
@@ -31,11 +32,14 @@ class Connection:
         client = await Connection.connect_current()
 
     """
         client = await Connection.connect_current()
 
     """
-    def __init__(self, endpoint, uuid, username, password, cacert=None):
+    def __init__(
+            self, endpoint, uuid, username, password, cacert=None,
+            macaroons=None):
         self.endpoint = endpoint
         self.uuid = uuid
         self.username = username
         self.password = password
         self.endpoint = endpoint
         self.uuid = uuid
         self.username = username
         self.password = password
+        self.macaroons = macaroons
         self.cacert = cacert
 
         self.__request_id__ = 0
         self.cacert = cacert
 
         self.__request_id__ = 0
@@ -49,16 +53,21 @@ class Connection:
             return self.ws.open
         return False
 
             return self.ws.open
         return False
 
-    def _get_ssl(self, cert):
+    def _get_ssl(self, cert=None):
         return ssl.create_default_context(
             purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
 
         return ssl.create_default_context(
             purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
 
-    async def open(self, addr, cert=None):
+    async def open(self):
+        if self.uuid:
+            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
+        else:
+            url = "wss://{}/api".format(self.endpoint)
+
         kw = dict()
         kw = dict()
-        if cert:
-            kw['ssl'] = self._get_ssl(cert)
-        self.addr = addr
-        self.ws = await websockets.connect(addr, **kw)
+        kw['ssl'] = self._get_ssl(self.cacert)
+        self.addr = url
+        self.ws = await websockets.connect(url, **kw)
+        log.info("Driver connected to juju %s", url)
         return self
 
     async def close(self):
         return self
 
     async def close(self):
@@ -80,8 +89,6 @@ class Connection:
         outgoing = json.dumps(msg, indent=2, cls=encoder)
         await self.ws.send(outgoing)
         result = await self.recv()
         outgoing = json.dumps(msg, indent=2, cls=encoder)
         await self.ws.send(outgoing)
         result = await self.recv()
-        #log.debug("Send: %s", outgoing)
-        #log.debug("Recv: %s", result)
         if result and 'error' in result:
             raise JujuAPIError(result)
         return result
         if result and 'error' in result:
             raise JujuAPIError(result)
         return result
@@ -97,6 +104,7 @@ class Connection:
             self.username,
             self.password,
             self.cacert,
             self.username,
             self.password,
             self.cacert,
+            self.macaroons,
         )
 
     async def controller(self):
         )
 
     async def controller(self):
@@ -109,27 +117,50 @@ class Connection:
             self.username,
             self.password,
             self.cacert,
             self.username,
             self.password,
             self.cacert,
+            self.macaroons,
         )
 
     @classmethod
         )
 
     @classmethod
-    async def connect(cls, endpoint, uuid, username, password, cacert=None):
+    async def connect(
+            cls, endpoint, uuid, username, password, cacert=None,
+            macaroons=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
-        if uuid:
-            url = "wss://{}/model/{}/api".format(endpoint, uuid)
-        else:
-            url = "wss://{}/api".format(endpoint)
-        client = cls(endpoint, uuid, username, password, cacert)
-        await client.open(url, cacert)
-        server_info = await client.login(username, password)
-        client.build_facades(server_info['facades'])
-        log.info("Driver connected to juju %s", url)
-
-        return client
+        client = cls(endpoint, uuid, username, password, cacert, macaroons)
+        await client.open()
+
+        redirect_info = await client.redirect_info()
+        if not redirect_info:
+            server_info = await client.login(username, password, macaroons)
+            client.build_facades(server_info['facades'])
+            return client
+
+        await client.close()
+        servers = [
+            s for servers in redirect_info['servers']
+            for s in servers if s["scope"] == 'public'
+        ]
+        for server in servers:
+            client = cls(
+                "{value}:{port}".format(**server), uuid, username,
+                password, redirect_info['ca-cert'], macaroons)
+            await client.open()
+            try:
+                result = await client.login(username, password, macaroons)
+                if 'discharge-required-error' in result:
+                    continue
+                client.build_facades(result['facades'])
+                return client
+            except Exception as e:
+                await client.close()
+                log.exception(e)
+
+        raise Exception(
+            "Couldn't authenticate to %s", endpoint)
 
     @classmethod
     async def connect_current(cls):
 
     @classmethod
     async def connect_current(cls):
@@ -138,18 +169,11 @@ class Connection:
         """
         jujudata = JujuData()
         controller_name = jujudata.current_controller()
         """
         jujudata = JujuData()
         controller_name = jujudata.current_controller()
-        controller = jujudata.controllers()[controller_name]
-        endpoint = controller['api-endpoints'][0]
-        cacert = controller.get('ca-cert')
-        accounts = jujudata.accounts()[controller_name]
-        username = accounts['user']
-        password = accounts['password']
         models = jujudata.models()[controller_name]
         model_name = models['current-model']
         models = jujudata.models()[controller_name]
         model_name = models['current-model']
-        model_uuid = models['models'][model_name]['uuid']
 
 
-        return await cls.connect(
-            endpoint, model_uuid, username, password, cacert)
+        return await cls.connect_model(
+            '{}:{}'.format(controller_name, model_name))
 
     @classmethod
     async def connect_model(cls, model):
 
     @classmethod
     async def connect_model(cls, model):
@@ -166,20 +190,25 @@ class Connection:
         cacert = controller.get('ca-cert')
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
         cacert = controller.get('ca-cert')
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
-        password = accounts['password']
+        password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
+        macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert)
+            endpoint, model_uuid, username, password, cacert, macaroons)
 
     def build_facades(self, info):
         self.facades.clear()
         for facade in info:
             self.facades[facade['name']] = facade['versions'][-1]
 
 
     def build_facades(self, info):
         self.facades.clear()
         for facade in info:
             self.facades[facade['name']] = facade['versions'][-1]
 
-    async def login(self, username, password):
-        if not username.startswith('user-'):
+    async def login(self, username, password, macaroons=None):
+        if macaroons:
+            username = ''
+            password = ''
+
+        if username and not username.startswith('user-'):
             username = 'user-{}'.format(username)
 
         result = await self.rpc({
             username = 'user-{}'.format(username)
 
         result = await self.rpc({
@@ -190,9 +219,23 @@ class Connection:
                 "auth-tag": username,
                 "credentials": password,
                 "nonce": "".join(random.sample(string.printable, 12)),
                 "auth-tag": username,
                 "credentials": password,
                 "nonce": "".join(random.sample(string.printable, 12)),
+                "macaroons": macaroons or []
             }})
         return result['response']
 
             }})
         return result['response']
 
+    async def redirect_info(self):
+        try:
+            result = await self.rpc({
+                "type": "Admin",
+                "request": "RedirectInfo",
+                "version": 3,
+            })
+        except JujuAPIError as e:
+            if e.message == 'not redirected':
+                return None
+            raise
+        return result['response']
+
 
 class JujuData:
     def __init__(self):
 
 class JujuData:
     def __init__(self):
@@ -218,3 +261,26 @@ class JujuData:
         filepath = os.path.join(self.path, filename)
         with io.open(filepath, 'rt') as f:
             return yaml.safe_load(f)[key]
         filepath = os.path.join(self.path, filename)
         with io.open(filepath, 'rt') as f:
             return yaml.safe_load(f)[key]
+
+
+def get_macaroons():
+    """Decode and return macaroons from default ~/.go-cookies
+
+    """
+    try:
+        cookie_file = os.path.expanduser('~/.go-cookies')
+        with open(cookie_file, 'r') as f:
+            cookies = json.load(f)
+    except (OSError, ValueError) as e:
+        log.warn("Couldn't load macaroons from %s", cookie_file)
+        return []
+
+    base64_macaroons = [
+        c['Value'] for c in cookies
+        if c['Name'].startswith('macaroon-') and c['Value']
+    ]
+
+    return [
+        json.loads(base64.b64decode(value).decode('utf-8'))
+        for value in base64_macaroons
+    ]