Merge commit '19031b24b523c872c1ac367821dc60c950a09755' as 'modules/libjuju'
[osm/N2VC.git] / modules / libjuju / juju / client / connection.py
diff --git a/modules/libjuju/juju/client/connection.py b/modules/libjuju/juju/client/connection.py
new file mode 100644 (file)
index 0000000..7457391
--- /dev/null
@@ -0,0 +1,623 @@
+import base64
+import io
+import json
+import logging
+import os
+import random
+import shlex
+import ssl
+import string
+import subprocess
+import weakref
+import websockets
+from concurrent.futures import CancelledError
+from http.client import HTTPSConnection
+from pathlib import Path
+
+import asyncio
+import yaml
+
+from juju import tag, utils
+from juju.client import client
+from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
+
+log = logging.getLogger("websocket")
+
+
+class Monitor:
+    """
+    Monitor helper class for our Connection class.
+
+    Contains a reference to an instantiated Connection, along with a
+    reference to the Connection.receiver Future. Upon inspecttion of
+    these objects, this class determines whether the connection is in
+    an 'error', 'connected' or 'disconnected' state.
+
+    Use this class to stay up to date on the health of a connection,
+    and take appropriate action if the connection errors out due to
+    network issues or other unexpected circumstances.
+
+    """
+    ERROR = 'error'
+    CONNECTED = 'connected'
+    DISCONNECTING = 'disconnecting'
+    DISCONNECTED = 'disconnected'
+
+    def __init__(self, connection):
+        self.connection = weakref.ref(connection)
+        self.reconnecting = asyncio.Lock(loop=connection.loop)
+        self.close_called = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped = asyncio.Event(loop=connection.loop)
+        self.pinger_stopped = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped.set()
+        self.pinger_stopped.set()
+
+    @property
+    def status(self):
+        """
+        Determine the status of the connection and receiver, and return
+        ERROR, CONNECTED, or DISCONNECTED as appropriate.
+
+        For simplicity, we only consider ourselves to be connected
+        after the Connection class has setup a receiver task. This
+        only happens after the websocket is open, and the connection
+        isn't usable until that receiver has been started.
+
+        """
+        connection = self.connection()
+
+        # the connection instance was destroyed but someone kept
+        # a separate reference to the monitor for some reason
+        if not connection:
+            return self.DISCONNECTED
+
+        # connection cleanly disconnected or not yet opened
+        if not connection.ws:
+            return self.DISCONNECTED
+
+        # close called but not yet complete
+        if self.close_called.is_set():
+            return self.DISCONNECTING
+
+        # connection closed uncleanly (we didn't call connection.close)
+        if self.receiver_stopped.is_set() or not connection.ws.open:
+            return self.ERROR
+
+        # everything is fine!
+        return self.CONNECTED
+
+
+class Connection:
+    """
+    Usage::
+
+        # Connect to an arbitrary api server
+        client = await Connection.connect(
+            api_endpoint, model_uuid, username, password, cacert)
+
+        # Connect using a controller/model name
+        client = await Connection.connect_model('local.local:default')
+
+        # Connect to the currently active model
+        client = await Connection.connect_current()
+
+    Note: Any connection method or constructor can accept an optional `loop`
+    argument to override the default event loop from `asyncio.get_event_loop`.
+    """
+
+    DEFAULT_FRAME_SIZE = 'default_frame_size'
+    MAX_FRAME_SIZE = 2**22
+    "Maximum size for a single frame.  Defaults to 4MB."
+
+    def __init__(
+            self, endpoint, uuid, username, password, cacert=None,
+            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
+        self.endpoint = endpoint
+        self._endpoint = endpoint
+        self.uuid = uuid
+        if macaroons:
+            self.macaroons = macaroons
+            self.username = ''
+            self.password = ''
+        else:
+            self.macaroons = []
+            self.username = username
+            self.password = password
+        self.cacert = cacert
+        self._cacert = cacert
+        self.loop = loop or asyncio.get_event_loop()
+
+        self.__request_id__ = 0
+        self.addr = None
+        self.ws = None
+        self.facades = {}
+        self.messages = IdQueue(loop=self.loop)
+        self.monitor = Monitor(connection=self)
+        if max_frame_size is self.DEFAULT_FRAME_SIZE:
+            max_frame_size = self.MAX_FRAME_SIZE
+        self.max_frame_size = max_frame_size
+
+    @property
+    def is_open(self):
+        return self.monitor.status == Monitor.CONNECTED
+
+    def _get_ssl(self, cert=None):
+        return ssl.create_default_context(
+            purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
+
+    async def open(self):
+        if self.uuid:
+            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
+        else:
+            url = "wss://{}/api".format(self.endpoint)
+
+        kw = dict()
+        kw['ssl'] = self._get_ssl(self.cacert)
+        kw['loop'] = self.loop
+        kw['max_size'] = self.max_frame_size
+        self.addr = url
+        self.ws = await websockets.connect(url, **kw)
+        self.loop.create_task(self.receiver())
+        self.monitor.receiver_stopped.clear()
+        log.info("Driver connected to juju %s", url)
+        self.monitor.close_called.clear()
+        return self
+
+    async def close(self):
+        if not self.ws:
+            return
+        self.monitor.close_called.set()
+        await self.monitor.pinger_stopped.wait()
+        await self.monitor.receiver_stopped.wait()
+        await self.ws.close()
+        self.ws = None
+
+    async def recv(self, request_id):
+        if not self.is_open:
+            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+        return await self.messages.get(request_id)
+
+    async def receiver(self):
+        try:
+            while self.is_open:
+                result = await utils.run_with_interrupt(
+                    self.ws.recv(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+                if result is not None:
+                    result = json.loads(result)
+                    await self.messages.put(result['request-id'], result)
+        except CancelledError:
+            pass
+        except websockets.ConnectionClosed as e:
+            log.warning('Receiver: Connection closed, reconnecting')
+            await self.messages.put_all(e)
+            # the reconnect has to be done as a task because the receiver will
+            # be cancelled by the reconnect and we don't want the reconnect
+            # to be aborted half-way through
+            self.loop.create_task(self.reconnect())
+            return
+        except Exception as e:
+            log.exception("Error in receiver")
+            # make pending listeners aware of the error
+            await self.messages.put_all(e)
+            raise
+        finally:
+            self.monitor.receiver_stopped.set()
+
+    async def pinger(self):
+        '''
+        A Controller can time us out if we are silent for too long. This
+        is especially true in JaaS, which has a fairly strict timeout.
+
+        To prevent timing out, we send a ping every ten seconds.
+
+        '''
+        async def _do_ping():
+            try:
+                await pinger_facade.Ping()
+                await asyncio.sleep(10, loop=self.loop)
+            except CancelledError:
+                pass
+
+        pinger_facade = client.PingerFacade.from_connection(self)
+        try:
+            while True:
+                await utils.run_with_interrupt(
+                    _do_ping(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+        finally:
+            self.monitor.pinger_stopped.set()
+            return
+
+    async def rpc(self, msg, encoder=None):
+        self.__request_id__ += 1
+        msg['request-id'] = self.__request_id__
+        if'params' not in msg:
+            msg['params'] = {}
+        if "version" not in msg:
+            msg['version'] = self.facades[msg['type']]
+        outgoing = json.dumps(msg, indent=2, cls=encoder)
+        for attempt in range(3):
+            try:
+                await self.ws.send(outgoing)
+                break
+            except websockets.ConnectionClosed:
+                if attempt == 2:
+                    raise
+                log.warning('RPC: Connection closed, reconnecting')
+                # the reconnect has to be done in a separate task because,
+                # if it is triggered by the pinger, then this RPC call will
+                # be cancelled when the pinger is cancelled by the reconnect,
+                # and we don't want the reconnect to be aborted halfway through
+                await asyncio.wait([self.reconnect()], loop=self.loop)
+        result = await self.recv(msg['request-id'])
+
+        if not result:
+            return result
+
+        if 'error' in result:
+            # API Error Response
+            raise JujuAPIError(result)
+
+        if 'response' not in result:
+            # This may never happen
+            return result
+
+        if 'results' in result['response']:
+            # Check for errors in a result list.
+            errors = []
+            for res in result['response']['results']:
+                if res.get('error', {}).get('message'):
+                    errors.append(res['error']['message'])
+            if errors:
+                raise JujuError(errors)
+
+        elif result['response'].get('error', {}).get('message'):
+            raise JujuError(result['response']['error']['message'])
+
+        return result
+
+    def http_headers(self):
+        """Return dictionary of http headers necessary for making an http
+        connection to the endpoint of this Connection.
+
+        :return: Dictionary of headers
+
+        """
+        if not self.username:
+            return {}
+
+        creds = u'{}:{}'.format(
+            tag.user(self.username),
+            self.password or ''
+        )
+        token = base64.b64encode(creds.encode())
+        return {
+            'Authorization': 'Basic {}'.format(token.decode())
+        }
+
+    def https_connection(self):
+        """Return an https connection to this Connection's endpoint.
+
+        Returns a 3-tuple containing::
+
+            1. The :class:`HTTPSConnection` instance
+            2. Dictionary of auth headers to be used with the connection
+            3. The root url path (str) to be used for requests.
+
+        """
+        endpoint = self.endpoint
+        host, remainder = endpoint.split(':', 1)
+        port = remainder
+        if '/' in remainder:
+            port, _ = remainder.split('/', 1)
+
+        conn = HTTPSConnection(
+            host, int(port),
+            context=self._get_ssl(self.cacert),
+        )
+
+        path = (
+            "/model/{}".format(self.uuid)
+            if self.uuid else ""
+        )
+        return conn, self.http_headers(), path
+
+    async def clone(self):
+        """Return a new Connection, connected to the same websocket endpoint
+        as this one.
+
+        """
+        return await Connection.connect(
+            self.endpoint,
+            self.uuid,
+            self.username,
+            self.password,
+            self.cacert,
+            self.macaroons,
+            self.loop,
+            self.max_frame_size,
+        )
+
+    async def controller(self):
+        """Return a Connection to the controller at self.endpoint
+
+        """
+        return await Connection.connect(
+            self.endpoint,
+            None,
+            self.username,
+            self.password,
+            self.cacert,
+            self.macaroons,
+            self.loop,
+        )
+
+    async def _try_endpoint(self, endpoint, cacert):
+        success = False
+        result = None
+        new_endpoints = []
+
+        self.endpoint = endpoint
+        self.cacert = cacert
+        await self.open()
+        try:
+            result = await self.login()
+            if 'discharge-required-error' in result['response']:
+                log.info('Macaroon discharge required, disconnecting')
+            else:
+                # successful login!
+                log.info('Authenticated')
+                success = True
+        except JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            redirect_info = await self.redirect_info()
+            redir_cacert = redirect_info['ca-cert']
+            new_endpoints = [
+                ("{value}:{port}".format(**s), redir_cacert)
+                for servers in redirect_info['servers']
+                for s in servers if s["scope"] == 'public'
+            ]
+        finally:
+            if not success:
+                await self.close()
+        return success, result, new_endpoints
+
+    async def reconnect(self):
+        """ Force a reconnection.
+        """
+        monitor = self.monitor
+        if monitor.reconnecting.locked() or monitor.close_called.is_set():
+            return
+        async with monitor.reconnecting:
+            await self.close()
+            await self._connect()
+
+    async def _connect(self):
+        endpoints = [(self._endpoint, self._cacert)]
+        while endpoints:
+            _endpoint, _cacert = endpoints.pop(0)
+            success, result, new_endpoints = await self._try_endpoint(
+                _endpoint, _cacert)
+            if success:
+                break
+            endpoints.extend(new_endpoints)
+        else:
+            # ran out of endpoints without a successful login
+            raise Exception("Couldn't authenticate to {}".format(
+                self._endpoint))
+
+        response = result['response']
+        self.info = response.copy()
+        self.build_facades(response.get('facades', {}))
+        self.loop.create_task(self.pinger())
+        self.monitor.pinger_stopped.clear()
+
+    @classmethod
+    async def connect(
+            cls, endpoint, uuid, username, password, cacert=None,
+            macaroons=None, loop=None, max_frame_size=None):
+        """Connect to the websocket.
+
+        If uuid is None, the connection will be to the controller. Otherwise it
+        will be to the model.
+
+        """
+        client = cls(endpoint, uuid, username, password, cacert, macaroons,
+                     loop, max_frame_size)
+        await client._connect()
+        return client
+
+    @classmethod
+    async def connect_current(cls, loop=None, max_frame_size=None):
+        """Connect to the currently active model.
+
+        """
+        jujudata = JujuData()
+
+        controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
+        model_name = jujudata.current_model()
+
+        return await cls.connect_model(
+            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
+
+    @classmethod
+    async def connect_current_controller(cls, loop=None, max_frame_size=None):
+        """Connect to the currently active controller.
+
+        """
+        jujudata = JujuData()
+        controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
+        return await cls.connect_controller(controller_name, loop,
+                                            max_frame_size)
+
+    @classmethod
+    async def connect_controller(cls, controller_name, loop=None,
+                                 max_frame_size=None):
+        """Connect to a controller by name.
+
+        """
+        jujudata = JujuData()
+        controller = jujudata.controllers()[controller_name]
+        endpoint = controller['api-endpoints'][0]
+        cacert = controller.get('ca-cert')
+        accounts = jujudata.accounts()[controller_name]
+        username = accounts['user']
+        password = accounts.get('password')
+        macaroons = get_macaroons(controller_name) if not password else None
+
+        return await cls.connect(
+            endpoint, None, username, password, cacert, macaroons, loop,
+            max_frame_size)
+
+    @classmethod
+    async def connect_model(cls, model, loop=None, max_frame_size=None):
+        """Connect to a model by name.
+
+        :param str model: [<controller>:]<model>
+
+        """
+        jujudata = JujuData()
+
+        if ':' in model:
+            # explicit controller given
+            controller_name, model_name = model.split(':')
+        else:
+            # use the current controller if one isn't explicitly given
+            controller_name = jujudata.current_controller()
+            model_name = model
+
+        accounts = jujudata.accounts()[controller_name]
+        username = accounts['user']
+        # model name must include a user prefix, so add it if it doesn't
+        if '/' not in model_name:
+            model_name = '{}/{}'.format(username, model_name)
+
+        controller = jujudata.controllers()[controller_name]
+        endpoint = controller['api-endpoints'][0]
+        cacert = controller.get('ca-cert')
+        password = accounts.get('password')
+        models = jujudata.models()[controller_name]
+        model_uuid = models['models'][model_name]['uuid']
+        macaroons = get_macaroons(controller_name) if not password else None
+
+        return await cls.connect(
+            endpoint, model_uuid, username, password, cacert, macaroons, loop,
+            max_frame_size)
+
+    def build_facades(self, facades):
+        self.facades.clear()
+        for facade in facades:
+            self.facades[facade['name']] = facade['versions'][-1]
+
+    async def login(self):
+        username = self.username
+        if username and not username.startswith('user-'):
+            username = 'user-{}'.format(username)
+
+        result = await self.rpc({
+            "type": "Admin",
+            "request": "Login",
+            "version": 3,
+            "params": {
+                "auth-tag": username,
+                "credentials": self.password,
+                "nonce": "".join(random.sample(string.printable, 12)),
+                "macaroons": self.macaroons
+            }})
+        return result
+
+    async def redirect_info(self):
+        try:
+            result = await self.rpc({
+                "type": "Admin",
+                "request": "RedirectInfo",
+                "version": 3,
+            })
+        except JujuAPIError as e:
+            if e.message == 'not redirected':
+                return None
+            raise
+        return result['response']
+
+
+class JujuData:
+    def __init__(self):
+        self.path = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
+        self.path = os.path.abspath(os.path.expanduser(self.path))
+
+    def current_controller(self):
+        cmd = shlex.split('juju list-controllers --format yaml')
+        output = subprocess.check_output(cmd)
+        output = yaml.safe_load(output)
+        return output.get('current-controller', '')
+
+    def current_model(self, controller_name=None):
+        if not controller_name:
+            controller_name = self.current_controller()
+        models = self.models()[controller_name]
+        if 'current-model' not in models:
+            raise JujuError('No current model')
+        return models['current-model']
+
+    def controllers(self):
+        return self._load_yaml('controllers.yaml', 'controllers')
+
+    def models(self):
+        return self._load_yaml('models.yaml', 'controllers')
+
+    def accounts(self):
+        return self._load_yaml('accounts.yaml', 'controllers')
+
+    def _load_yaml(self, filename, key):
+        filepath = os.path.join(self.path, filename)
+        with io.open(filepath, 'rt') as f:
+            return yaml.safe_load(f)[key]
+
+
+def get_macaroons(controller_name=None):
+    """Decode and return macaroons from default ~/.go-cookies
+
+    """
+    cookie_files = []
+    if controller_name:
+        cookie_files.append('~/.local/share/juju/cookies/{}.json'.format(
+            controller_name))
+    cookie_files.append('~/.go-cookies')
+    for cookie_file in cookie_files:
+        cookie_file = Path(cookie_file).expanduser()
+        if cookie_file.exists():
+            try:
+                cookies = json.loads(cookie_file.read_text())
+                break
+            except (OSError, ValueError):
+                log.warn("Couldn't load macaroons from %s", cookie_file)
+                return []
+    else:
+        log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files))
+        return []
+
+    base64_macaroons = [
+        c['Value'] for c in cookies
+        if c['Name'].startswith('macaroon-') and c['Value']
+    ]
+
+    return [
+        json.loads(base64.b64decode(value).decode('utf-8'))
+        for value in base64_macaroons
+    ]