Revert "Remove vendored libjuju"
[osm/N2VC.git] / modules / libjuju / juju / client / connection.py
diff --git a/modules/libjuju/juju/client/connection.py b/modules/libjuju/juju/client/connection.py
new file mode 100644 (file)
index 0000000..f2150b7
--- /dev/null
@@ -0,0 +1,601 @@
+import asyncio
+import base64
+import json
+import logging
+import ssl
+import urllib.request
+import weakref
+from concurrent.futures import CancelledError
+from http.client import HTTPSConnection
+
+import macaroonbakery.httpbakery as httpbakery
+import macaroonbakery.bakery as bakery
+import websockets
+from juju import errors, tag, utils
+from juju.client import client
+from juju.utils import IdQueue
+
+log = logging.getLogger('juju.client.connection')
+
+
+class Monitor:
+    """
+    Monitor helper class for our Connection class.
+
+    Contains a reference to an instantiated Connection, along with a
+    reference to the Connection.receiver Future. Upon inspection of
+    these objects, this class determines whether the connection is in
+    an 'error', 'connected' or 'disconnected' state.
+
+    Use this class to stay up to date on the health of a connection,
+    and take appropriate action if the connection errors out due to
+    network issues or other unexpected circumstances.
+
+    """
+    ERROR = 'error'
+    CONNECTED = 'connected'
+    DISCONNECTING = 'disconnecting'
+    DISCONNECTED = 'disconnected'
+
+    def __init__(self, connection):
+        self.connection = weakref.ref(connection)
+        self.reconnecting = asyncio.Lock(loop=connection.loop)
+        self.close_called = asyncio.Event(loop=connection.loop)
+
+    @property
+    def status(self):
+        """
+        Determine the status of the connection and receiver, and return
+        ERROR, CONNECTED, or DISCONNECTED as appropriate.
+
+        For simplicity, we only consider ourselves to be connected
+        after the Connection class has setup a receiver task. This
+        only happens after the websocket is open, and the connection
+        isn't usable until that receiver has been started.
+
+        """
+        connection = self.connection()
+
+        # the connection instance was destroyed but someone kept
+        # a separate reference to the monitor for some reason
+        if not connection:
+            return self.DISCONNECTED
+
+        # connection cleanly disconnected or not yet opened
+        if not connection.ws:
+            return self.DISCONNECTED
+
+        # close called but not yet complete
+        if self.close_called.is_set():
+            return self.DISCONNECTING
+
+        # connection closed uncleanly (we didn't call connection.close)
+        stopped = connection._receiver_task.stopped.is_set()
+        if stopped or not connection.ws.open:
+            return self.ERROR
+
+        # everything is fine!
+        return self.CONNECTED
+
+
+class Connection:
+    """
+    Usage::
+
+        # Connect to an arbitrary api server
+        client = await Connection.connect(
+            api_endpoint, model_uuid, username, password, cacert)
+
+    Note: Any connection method or constructor can accept an optional `loop`
+    argument to override the default event loop from `asyncio.get_event_loop`.
+    """
+
+    MAX_FRAME_SIZE = 2**22
+    "Maximum size for a single frame.  Defaults to 4MB."
+
+    @classmethod
+    async def connect(
+            cls,
+            endpoint=None,
+            uuid=None,
+            username=None,
+            password=None,
+            cacert=None,
+            bakery_client=None,
+            loop=None,
+            max_frame_size=None,
+            retries=3,
+            retry_backoff=10,
+    ):
+        """Connect to the websocket.
+
+        If uuid is None, the connection will be to the controller. Otherwise it
+        will be to the model.
+
+        :param str endpoint: The hostname:port of the controller to connect to.
+        :param str uuid: The model UUID to connect to (None for a
+            controller-only connection).
+        :param str username: The username for controller-local users (or None
+            to use macaroon-based login.)
+        :param str password: The password for controller-local users.
+        :param str cacert: The CA certificate of the controller
+            (PEM formatted).
+        :param httpbakery.Client bakery_client: The macaroon bakery client to
+            to use when performing macaroon-based login. Macaroon tokens
+            acquired when logging will be saved to bakery_client.cookies.
+            If this is None, a default bakery_client will be used.
+        :param asyncio.BaseEventLoop loop: The event loop to use for async
+            operations.
+        :param int max_frame_size: The maximum websocket frame size to allow.
+        :param int retries: When connecting or reconnecting, and all endpoints
+            fail, how many times to retry the connection before giving up.
+        :param int retry_backoff: Number of seconds to increase the wait
+            between connection retry attempts (a backoff of 10 with 3 retries
+            would wait 10s, 20s, and 30s).
+        """
+        self = cls()
+        if endpoint is None:
+            raise ValueError('no endpoint provided')
+        self.uuid = uuid
+        if bakery_client is None:
+            bakery_client = httpbakery.Client()
+        self.bakery_client = bakery_client
+        if username and '@' in username and not username.endswith('@local'):
+            # We're trying to log in as an external user - we need to use
+            # macaroon authentication with no username or password.
+            if password is not None:
+                raise errors.JujuAuthError('cannot log in as external '
+                                           'user with a password')
+            username = None
+        self.usertag = tag.user(username)
+        self.password = password
+        self.loop = loop or asyncio.get_event_loop()
+
+        self.__request_id__ = 0
+
+        # The following instance variables are initialized by the
+        # _connect_with_redirect method, but create them here
+        # as a reminder that they will exist.
+        self.addr = None
+        self.ws = None
+        self.endpoint = None
+        self.cacert = None
+        self.info = None
+
+        # Create that _Task objects but don't start the tasks yet.
+        self._pinger_task = _Task(self._pinger, self.loop)
+        self._receiver_task = _Task(self._receiver, self.loop)
+
+        self._retries = retries
+        self._retry_backoff = retry_backoff
+
+        self.facades = {}
+        self.messages = IdQueue(loop=self.loop)
+        self.monitor = Monitor(connection=self)
+        if max_frame_size is None:
+            max_frame_size = self.MAX_FRAME_SIZE
+        self.max_frame_size = max_frame_size
+        await self._connect_with_redirect([(endpoint, cacert)])
+        return self
+
+    @property
+    def username(self):
+        if not self.usertag:
+            return None
+        return self.usertag[len('user-'):]
+
+    @property
+    def is_open(self):
+        return self.monitor.status == Monitor.CONNECTED
+
+    def _get_ssl(self, cert=None):
+        return ssl.create_default_context(
+            purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
+
+    async def _open(self, endpoint, cacert):
+        if self.uuid:
+            url = "wss://{}/model/{}/api".format(endpoint, self.uuid)
+        else:
+            url = "wss://{}/api".format(endpoint)
+
+        return (await websockets.connect(
+            url,
+            ssl=self._get_ssl(cacert),
+            loop=self.loop,
+            max_size=self.max_frame_size,
+        ), url, endpoint, cacert)
+
+    async def close(self):
+        if not self.ws:
+            return
+        self.monitor.close_called.set()
+        await self._pinger_task.stopped.wait()
+        await self._receiver_task.stopped.wait()
+        await self.ws.close()
+        self.ws = None
+
+    async def _recv(self, request_id):
+        if not self.is_open:
+            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+        return await self.messages.get(request_id)
+
+    async def _receiver(self):
+        try:
+            while self.is_open:
+                result = await utils.run_with_interrupt(
+                    self.ws.recv(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+                if result is not None:
+                    result = json.loads(result)
+                    await self.messages.put(result['request-id'], result)
+        except CancelledError:
+            pass
+        except websockets.ConnectionClosed as e:
+            log.warning('Receiver: Connection closed, reconnecting')
+            await self.messages.put_all(e)
+            # the reconnect has to be done as a task because the receiver will
+            # be cancelled by the reconnect and we don't want the reconnect
+            # to be aborted half-way through
+            self.loop.create_task(self.reconnect())
+            return
+        except Exception as e:
+            log.exception("Error in receiver")
+            # make pending listeners aware of the error
+            await self.messages.put_all(e)
+            raise
+
+    async def _pinger(self):
+        '''
+        A Controller can time us out if we are silent for too long. This
+        is especially true in JaaS, which has a fairly strict timeout.
+
+        To prevent timing out, we send a ping every ten seconds.
+
+        '''
+        async def _do_ping():
+            try:
+                await pinger_facade.Ping()
+                await asyncio.sleep(10, loop=self.loop)
+            except CancelledError:
+                pass
+
+        pinger_facade = client.PingerFacade.from_connection(self)
+        try:
+            while True:
+                await utils.run_with_interrupt(
+                    _do_ping(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+        except websockets.exceptions.ConnectionClosed:
+            # The connection has closed - we can't do anything
+            # more until the connection is restarted.
+            log.debug('ping failed because of closed connection')
+            pass
+
+    async def rpc(self, msg, encoder=None):
+        '''Make an RPC to the API. The message is encoded as JSON
+        using the given encoder if any.
+        :param msg: Parameters for the call (will be encoded as JSON).
+        :param encoder: Encoder to be used when encoding the message.
+        :return: The result of the call.
+        :raises JujuAPIError: When there's an error returned.
+        :raises JujuError:
+        '''
+        self.__request_id__ += 1
+        msg['request-id'] = self.__request_id__
+        if'params' not in msg:
+            msg['params'] = {}
+        if "version" not in msg:
+            msg['version'] = self.facades[msg['type']]
+        outgoing = json.dumps(msg, indent=2, cls=encoder)
+        log.debug('connection {} -> {}'.format(id(self), outgoing))
+        for attempt in range(3):
+            if self.monitor.status == Monitor.DISCONNECTED:
+                # closed cleanly; shouldn't try to reconnect
+                raise websockets.exceptions.ConnectionClosed(
+                    0, 'websocket closed')
+            try:
+                await self.ws.send(outgoing)
+                break
+            except websockets.ConnectionClosed:
+                if attempt == 2:
+                    raise
+                log.warning('RPC: Connection closed, reconnecting')
+                # the reconnect has to be done in a separate task because,
+                # if it is triggered by the pinger, then this RPC call will
+                # be cancelled when the pinger is cancelled by the reconnect,
+                # and we don't want the reconnect to be aborted halfway through
+                await asyncio.wait([self.reconnect()], loop=self.loop)
+                if self.monitor.status != Monitor.CONNECTED:
+                    # reconnect failed; abort and shutdown
+                    log.error('RPC: Automatic reconnect failed')
+                    raise
+        result = await self._recv(msg['request-id'])
+        log.debug('connection {} <- {}'.format(id(self), result))
+
+        if not result:
+            return result
+
+        if 'error' in result:
+            # API Error Response
+            raise errors.JujuAPIError(result)
+
+        if 'response' not in result:
+            # This may never happen
+            return result
+
+        if 'results' in result['response']:
+            # Check for errors in a result list.
+            # TODO This loses the results that might have succeeded.
+            # Perhaps JujuError should return all the results including
+            # errors, or perhaps a keyword parameter to the rpc method
+            # could be added to trigger this behaviour.
+            err_results = []
+            for res in result['response']['results']:
+                if res.get('error', {}).get('message'):
+                    err_results.append(res['error']['message'])
+            if err_results:
+                raise errors.JujuError(err_results)
+
+        elif result['response'].get('error', {}).get('message'):
+            raise errors.JujuError(result['response']['error']['message'])
+
+        return result
+
+    def _http_headers(self):
+        """Return dictionary of http headers necessary for making an http
+        connection to the endpoint of this Connection.
+
+        :return: Dictionary of headers
+
+        """
+        if not self.usertag:
+            return {}
+
+        creds = u'{}:{}'.format(
+            self.usertag,
+            self.password or ''
+        )
+        token = base64.b64encode(creds.encode())
+        return {
+            'Authorization': 'Basic {}'.format(token.decode())
+        }
+
+    def https_connection(self):
+        """Return an https connection to this Connection's endpoint.
+
+        Returns a 3-tuple containing::
+
+            1. The :class:`HTTPSConnection` instance
+            2. Dictionary of auth headers to be used with the connection
+            3. The root url path (str) to be used for requests.
+
+        """
+        endpoint = self.endpoint
+        host, remainder = endpoint.split(':', 1)
+        port = remainder
+        if '/' in remainder:
+            port, _ = remainder.split('/', 1)
+
+        conn = HTTPSConnection(
+            host, int(port),
+            context=self._get_ssl(self.cacert),
+        )
+
+        path = (
+            "/model/{}".format(self.uuid)
+            if self.uuid else ""
+        )
+        return conn, self._http_headers(), path
+
+    async def clone(self):
+        """Return a new Connection, connected to the same websocket endpoint
+        as this one.
+
+        """
+        return await Connection.connect(**self.connect_params())
+
+    def connect_params(self):
+        """Return a tuple of parameters suitable for passing to
+        Connection.connect that can be used to make a new connection
+        to the same controller (and model if specified. The first
+        element in the returned tuple holds the endpoint argument;
+        the other holds a dict of the keyword args.
+        """
+        return {
+            'endpoint': self.endpoint,
+            'uuid': self.uuid,
+            'username': self.username,
+            'password': self.password,
+            'cacert': self.cacert,
+            'bakery_client': self.bakery_client,
+            'loop': self.loop,
+            'max_frame_size': self.max_frame_size,
+        }
+
+    async def controller(self):
+        """Return a Connection to the controller at self.endpoint
+        """
+        return await Connection.connect(
+            self.endpoint,
+            username=self.username,
+            password=self.password,
+            cacert=self.cacert,
+            bakery_client=self.bakery_client,
+            loop=self.loop,
+            max_frame_size=self.max_frame_size,
+        )
+
+    async def reconnect(self):
+        """ Force a reconnection.
+        """
+        monitor = self.monitor
+        if monitor.reconnecting.locked() or monitor.close_called.is_set():
+            return
+        async with monitor.reconnecting:
+            await self.close()
+            await self._connect_with_login([(self.endpoint, self.cacert)])
+
+    async def _connect(self, endpoints):
+        if len(endpoints) == 0:
+            raise errors.JujuConnectionError('no endpoints to connect to')
+
+        async def _try_endpoint(endpoint, cacert, delay):
+            if delay:
+                await asyncio.sleep(delay)
+            return await self._open(endpoint, cacert)
+
+        # Try all endpoints in parallel, with slight increasing delay (+100ms
+        # for each subsequent endpoint); the delay allows us to prefer the
+        # earlier endpoints over the latter. Use first successful connection.
+        tasks = [self.loop.create_task(_try_endpoint(endpoint, cacert,
+                                                     0.1 * i))
+                 for i, (endpoint, cacert) in enumerate(endpoints)]
+        for attempt in range(self._retries + 1):
+            for task in asyncio.as_completed(tasks, loop=self.loop):
+                try:
+                    result = await task
+                    break
+                except ConnectionError:
+                    continue  # ignore; try another endpoint
+            else:
+                _endpoints_str = ', '.join([endpoint
+                                            for endpoint, cacert in endpoints])
+                if attempt < self._retries:
+                    log.debug('Retrying connection to endpoints: {}; '
+                              'attempt {} of {}'.format(_endpoints_str,
+                                                        attempt + 1,
+                                                        self._retries + 1))
+                    await asyncio.sleep((attempt + 1) * self._retry_backoff)
+                    continue
+                else:
+                    raise errors.JujuConnectionError(
+                        'Unable to connect to any endpoint: '
+                        '{}'.format(_endpoints_str))
+            # only executed if inner loop's else did not continue
+            # (i.e., inner loop did break due to successful connection)
+            break
+        for task in tasks:
+            task.cancel()
+        self.ws = result[0]
+        self.addr = result[1]
+        self.endpoint = result[2]
+        self.cacert = result[3]
+        self._receiver_task.start()
+        log.debug("Driver connected to juju %s", self.addr)
+        self.monitor.close_called.clear()
+
+    async def _connect_with_login(self, endpoints):
+        """Connect to the websocket.
+
+        If uuid is None, the connection will be to the controller. Otherwise it
+        will be to the model.
+        :return: The response field of login response JSON object.
+        """
+        success = False
+        try:
+            await self._connect(endpoints)
+            # It's possible that we may get several discharge-required errors,
+            # corresponding to different levels of authentication, so retry
+            # a few times.
+            for i in range(0, 2):
+                result = (await self.login())['response']
+                macaroonJSON = result.get('discharge-required')
+                if macaroonJSON is None:
+                    self.info = result
+                    success = True
+                    return result
+                macaroon = bakery.Macaroon.from_dict(macaroonJSON)
+                self.bakery_client.handle_error(
+                    httpbakery.Error(
+                        code=httpbakery.ERR_DISCHARGE_REQUIRED,
+                        message=result.get('discharge-required-error'),
+                        version=macaroon.version,
+                        info=httpbakery.ErrorInfo(
+                            macaroon=macaroon,
+                            macaroon_path=result.get('macaroon-path'),
+                        ),
+                    ),
+                    # note: remove the port number.
+                    'https://' + self.endpoint + '/',
+                )
+            raise errors.JujuAuthError('failed to authenticate '
+                                       'after several attempts')
+        finally:
+            if not success:
+                await self.close()
+
+    async def _connect_with_redirect(self, endpoints):
+        try:
+            login_result = await self._connect_with_login(endpoints)
+        except errors.JujuRedirectException as e:
+            login_result = await self._connect_with_login(e.endpoints)
+        self._build_facades(login_result.get('facades', {}))
+        self._pinger_task.start()
+
+    def _build_facades(self, facades):
+        self.facades.clear()
+        for facade in facades:
+            self.facades[facade['name']] = facade['versions'][-1]
+
+    async def login(self):
+        params = {}
+        params['auth-tag'] = self.usertag
+        if self.password:
+            params['credentials'] = self.password
+        else:
+            macaroons = _macaroons_for_domain(self.bakery_client.cookies,
+                                              self.endpoint)
+            params['macaroons'] = [[bakery.macaroon_to_dict(m) for m in ms]
+                                   for ms in macaroons]
+
+        try:
+            return await self.rpc({
+                "type": "Admin",
+                "request": "Login",
+                "version": 3,
+                "params": params,
+            })
+        except errors.JujuAPIError as e:
+            if e.error_code != 'redirection required':
+                raise
+            log.info('Controller requested redirect')
+            # Fetch additional redirection information now so that
+            # we can safely close the connection after login
+            # fails.
+            redirect_info = (await self.rpc({
+                "type": "Admin",
+                "request": "RedirectInfo",
+                "version": 3,
+            }))['response']
+            raise errors.JujuRedirectException(redirect_info) from e
+
+
+class _Task:
+    def __init__(self, task, loop):
+        self.stopped = asyncio.Event(loop=loop)
+        self.stopped.set()
+        self.task = task
+        self.loop = loop
+
+    def start(self):
+        async def run():
+            try:
+                return await self.task()
+            finally:
+                self.stopped.set()
+        self.stopped.clear()
+        self.loop.create_task(run())
+
+
+def _macaroons_for_domain(cookies, domain):
+    '''Return any macaroons from the given cookie jar that
+    apply to the given domain name.'''
+    req = urllib.request.Request('https://' + domain + '/')
+    cookies.add_cookie_header(req)
+    return httpbakery.extract_macaroons(req)