--- /dev/null
+import asyncio
+import base64
+import json
+import logging
+import ssl
+import urllib.request
+import weakref
+from concurrent.futures import CancelledError
+from http.client import HTTPSConnection
+
+import macaroonbakery.httpbakery as httpbakery
+import macaroonbakery.bakery as bakery
+import websockets
+from juju import errors, tag, utils
+from juju.client import client
+from juju.utils import IdQueue
+
+log = logging.getLogger('juju.client.connection')
+
+
+class Monitor:
+ """
+ Monitor helper class for our Connection class.
+
+ Contains a reference to an instantiated Connection, along with a
+ reference to the Connection.receiver Future. Upon inspection of
+ these objects, this class determines whether the connection is in
+ an 'error', 'connected' or 'disconnected' state.
+
+ Use this class to stay up to date on the health of a connection,
+ and take appropriate action if the connection errors out due to
+ network issues or other unexpected circumstances.
+
+ """
+ ERROR = 'error'
+ CONNECTED = 'connected'
+ DISCONNECTING = 'disconnecting'
+ DISCONNECTED = 'disconnected'
+
+ def __init__(self, connection):
+ self.connection = weakref.ref(connection)
+ self.reconnecting = asyncio.Lock(loop=connection.loop)
+ self.close_called = asyncio.Event(loop=connection.loop)
+
+ @property
+ def status(self):
+ """
+ Determine the status of the connection and receiver, and return
+ ERROR, CONNECTED, or DISCONNECTED as appropriate.
+
+ For simplicity, we only consider ourselves to be connected
+ after the Connection class has setup a receiver task. This
+ only happens after the websocket is open, and the connection
+ isn't usable until that receiver has been started.
+
+ """
+ connection = self.connection()
+
+ # the connection instance was destroyed but someone kept
+ # a separate reference to the monitor for some reason
+ if not connection:
+ return self.DISCONNECTED
+
+ # connection cleanly disconnected or not yet opened
+ if not connection.ws:
+ return self.DISCONNECTED
+
+ # close called but not yet complete
+ if self.close_called.is_set():
+ return self.DISCONNECTING
+
+ # connection closed uncleanly (we didn't call connection.close)
+ stopped = connection._receiver_task.stopped.is_set()
+ if stopped or not connection.ws.open:
+ return self.ERROR
+
+ # everything is fine!
+ return self.CONNECTED
+
+
+class Connection:
+ """
+ Usage::
+
+ # Connect to an arbitrary api server
+ client = await Connection.connect(
+ api_endpoint, model_uuid, username, password, cacert)
+
+ Note: Any connection method or constructor can accept an optional `loop`
+ argument to override the default event loop from `asyncio.get_event_loop`.
+ """
+
+ MAX_FRAME_SIZE = 2**22
+ "Maximum size for a single frame. Defaults to 4MB."
+
+ @classmethod
+ async def connect(
+ cls,
+ endpoint=None,
+ uuid=None,
+ username=None,
+ password=None,
+ cacert=None,
+ bakery_client=None,
+ loop=None,
+ max_frame_size=None,
+ retries=3,
+ retry_backoff=10,
+ ):
+ """Connect to the websocket.
+
+ If uuid is None, the connection will be to the controller. Otherwise it
+ will be to the model.
+
+ :param str endpoint: The hostname:port of the controller to connect to.
+ :param str uuid: The model UUID to connect to (None for a
+ controller-only connection).
+ :param str username: The username for controller-local users (or None
+ to use macaroon-based login.)
+ :param str password: The password for controller-local users.
+ :param str cacert: The CA certificate of the controller
+ (PEM formatted).
+ :param httpbakery.Client bakery_client: The macaroon bakery client to
+ to use when performing macaroon-based login. Macaroon tokens
+ acquired when logging will be saved to bakery_client.cookies.
+ If this is None, a default bakery_client will be used.
+ :param asyncio.BaseEventLoop loop: The event loop to use for async
+ operations.
+ :param int max_frame_size: The maximum websocket frame size to allow.
+ :param int retries: When connecting or reconnecting, and all endpoints
+ fail, how many times to retry the connection before giving up.
+ :param int retry_backoff: Number of seconds to increase the wait
+ between connection retry attempts (a backoff of 10 with 3 retries
+ would wait 10s, 20s, and 30s).
+ """
+ self = cls()
+ if endpoint is None:
+ raise ValueError('no endpoint provided')
+ self.uuid = uuid
+ if bakery_client is None:
+ bakery_client = httpbakery.Client()
+ self.bakery_client = bakery_client
+ if username and '@' in username and not username.endswith('@local'):
+ # We're trying to log in as an external user - we need to use
+ # macaroon authentication with no username or password.
+ if password is not None:
+ raise errors.JujuAuthError('cannot log in as external '
+ 'user with a password')
+ username = None
+ self.usertag = tag.user(username)
+ self.password = password
+ self.loop = loop or asyncio.get_event_loop()
+
+ self.__request_id__ = 0
+
+ # The following instance variables are initialized by the
+ # _connect_with_redirect method, but create them here
+ # as a reminder that they will exist.
+ self.addr = None
+ self.ws = None
+ self.endpoint = None
+ self.cacert = None
+ self.info = None
+
+ # Create that _Task objects but don't start the tasks yet.
+ self._pinger_task = _Task(self._pinger, self.loop)
+ self._receiver_task = _Task(self._receiver, self.loop)
+
+ self._retries = retries
+ self._retry_backoff = retry_backoff
+
+ self.facades = {}
+ self.messages = IdQueue(loop=self.loop)
+ self.monitor = Monitor(connection=self)
+ if max_frame_size is None:
+ max_frame_size = self.MAX_FRAME_SIZE
+ self.max_frame_size = max_frame_size
+ await self._connect_with_redirect([(endpoint, cacert)])
+ return self
+
+ @property
+ def username(self):
+ if not self.usertag:
+ return None
+ return self.usertag[len('user-'):]
+
+ @property
+ def is_open(self):
+ return self.monitor.status == Monitor.CONNECTED
+
+ def _get_ssl(self, cert=None):
+ return ssl.create_default_context(
+ purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
+
+ async def _open(self, endpoint, cacert):
+ if self.uuid:
+ url = "wss://{}/model/{}/api".format(endpoint, self.uuid)
+ else:
+ url = "wss://{}/api".format(endpoint)
+
+ return (await websockets.connect(
+ url,
+ ssl=self._get_ssl(cacert),
+ loop=self.loop,
+ max_size=self.max_frame_size,
+ ), url, endpoint, cacert)
+
+ async def close(self):
+ if not self.ws:
+ return
+ self.monitor.close_called.set()
+ await self._pinger_task.stopped.wait()
+ await self._receiver_task.stopped.wait()
+ await self.ws.close()
+ self.ws = None
+
+ async def _recv(self, request_id):
+ if not self.is_open:
+ raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+ return await self.messages.get(request_id)
+
+ async def _receiver(self):
+ try:
+ while self.is_open:
+ result = await utils.run_with_interrupt(
+ self.ws.recv(),
+ self.monitor.close_called,
+ loop=self.loop)
+ if self.monitor.close_called.is_set():
+ break
+ if result is not None:
+ result = json.loads(result)
+ await self.messages.put(result['request-id'], result)
+ except CancelledError:
+ pass
+ except websockets.ConnectionClosed as e:
+ log.warning('Receiver: Connection closed, reconnecting')
+ await self.messages.put_all(e)
+ # the reconnect has to be done as a task because the receiver will
+ # be cancelled by the reconnect and we don't want the reconnect
+ # to be aborted half-way through
+ self.loop.create_task(self.reconnect())
+ return
+ except Exception as e:
+ log.exception("Error in receiver")
+ # make pending listeners aware of the error
+ await self.messages.put_all(e)
+ raise
+
+ async def _pinger(self):
+ '''
+ A Controller can time us out if we are silent for too long. This
+ is especially true in JaaS, which has a fairly strict timeout.
+
+ To prevent timing out, we send a ping every ten seconds.
+
+ '''
+ async def _do_ping():
+ try:
+ await pinger_facade.Ping()
+ await asyncio.sleep(10, loop=self.loop)
+ except CancelledError:
+ pass
+
+ pinger_facade = client.PingerFacade.from_connection(self)
+ try:
+ while True:
+ await utils.run_with_interrupt(
+ _do_ping(),
+ self.monitor.close_called,
+ loop=self.loop)
+ if self.monitor.close_called.is_set():
+ break
+ except websockets.exceptions.ConnectionClosed:
+ # The connection has closed - we can't do anything
+ # more until the connection is restarted.
+ log.debug('ping failed because of closed connection')
+ pass
+
+ async def rpc(self, msg, encoder=None):
+ '''Make an RPC to the API. The message is encoded as JSON
+ using the given encoder if any.
+ :param msg: Parameters for the call (will be encoded as JSON).
+ :param encoder: Encoder to be used when encoding the message.
+ :return: The result of the call.
+ :raises JujuAPIError: When there's an error returned.
+ :raises JujuError:
+ '''
+ self.__request_id__ += 1
+ msg['request-id'] = self.__request_id__
+ if'params' not in msg:
+ msg['params'] = {}
+ if "version" not in msg:
+ msg['version'] = self.facades[msg['type']]
+ outgoing = json.dumps(msg, indent=2, cls=encoder)
+ log.debug('connection {} -> {}'.format(id(self), outgoing))
+ for attempt in range(3):
+ if self.monitor.status == Monitor.DISCONNECTED:
+ # closed cleanly; shouldn't try to reconnect
+ raise websockets.exceptions.ConnectionClosed(
+ 0, 'websocket closed')
+ try:
+ await self.ws.send(outgoing)
+ break
+ except websockets.ConnectionClosed:
+ if attempt == 2:
+ raise
+ log.warning('RPC: Connection closed, reconnecting')
+ # the reconnect has to be done in a separate task because,
+ # if it is triggered by the pinger, then this RPC call will
+ # be cancelled when the pinger is cancelled by the reconnect,
+ # and we don't want the reconnect to be aborted halfway through
+ await asyncio.wait([self.reconnect()], loop=self.loop)
+ if self.monitor.status != Monitor.CONNECTED:
+ # reconnect failed; abort and shutdown
+ log.error('RPC: Automatic reconnect failed')
+ raise
+ result = await self._recv(msg['request-id'])
+ log.debug('connection {} <- {}'.format(id(self), result))
+
+ if not result:
+ return result
+
+ if 'error' in result:
+ # API Error Response
+ raise errors.JujuAPIError(result)
+
+ if 'response' not in result:
+ # This may never happen
+ return result
+
+ if 'results' in result['response']:
+ # Check for errors in a result list.
+ # TODO This loses the results that might have succeeded.
+ # Perhaps JujuError should return all the results including
+ # errors, or perhaps a keyword parameter to the rpc method
+ # could be added to trigger this behaviour.
+ err_results = []
+ for res in result['response']['results']:
+ if res.get('error', {}).get('message'):
+ err_results.append(res['error']['message'])
+ if err_results:
+ raise errors.JujuError(err_results)
+
+ elif result['response'].get('error', {}).get('message'):
+ raise errors.JujuError(result['response']['error']['message'])
+
+ return result
+
+ def _http_headers(self):
+ """Return dictionary of http headers necessary for making an http
+ connection to the endpoint of this Connection.
+
+ :return: Dictionary of headers
+
+ """
+ if not self.usertag:
+ return {}
+
+ creds = u'{}:{}'.format(
+ self.usertag,
+ self.password or ''
+ )
+ token = base64.b64encode(creds.encode())
+ return {
+ 'Authorization': 'Basic {}'.format(token.decode())
+ }
+
+ def https_connection(self):
+ """Return an https connection to this Connection's endpoint.
+
+ Returns a 3-tuple containing::
+
+ 1. The :class:`HTTPSConnection` instance
+ 2. Dictionary of auth headers to be used with the connection
+ 3. The root url path (str) to be used for requests.
+
+ """
+ endpoint = self.endpoint
+ host, remainder = endpoint.split(':', 1)
+ port = remainder
+ if '/' in remainder:
+ port, _ = remainder.split('/', 1)
+
+ conn = HTTPSConnection(
+ host, int(port),
+ context=self._get_ssl(self.cacert),
+ )
+
+ path = (
+ "/model/{}".format(self.uuid)
+ if self.uuid else ""
+ )
+ return conn, self._http_headers(), path
+
+ async def clone(self):
+ """Return a new Connection, connected to the same websocket endpoint
+ as this one.
+
+ """
+ return await Connection.connect(**self.connect_params())
+
+ def connect_params(self):
+ """Return a tuple of parameters suitable for passing to
+ Connection.connect that can be used to make a new connection
+ to the same controller (and model if specified. The first
+ element in the returned tuple holds the endpoint argument;
+ the other holds a dict of the keyword args.
+ """
+ return {
+ 'endpoint': self.endpoint,
+ 'uuid': self.uuid,
+ 'username': self.username,
+ 'password': self.password,
+ 'cacert': self.cacert,
+ 'bakery_client': self.bakery_client,
+ 'loop': self.loop,
+ 'max_frame_size': self.max_frame_size,
+ }
+
+ async def controller(self):
+ """Return a Connection to the controller at self.endpoint
+ """
+ return await Connection.connect(
+ self.endpoint,
+ username=self.username,
+ password=self.password,
+ cacert=self.cacert,
+ bakery_client=self.bakery_client,
+ loop=self.loop,
+ max_frame_size=self.max_frame_size,
+ )
+
+ async def reconnect(self):
+ """ Force a reconnection.
+ """
+ monitor = self.monitor
+ if monitor.reconnecting.locked() or monitor.close_called.is_set():
+ return
+ async with monitor.reconnecting:
+ await self.close()
+ await self._connect_with_login([(self.endpoint, self.cacert)])
+
+ async def _connect(self, endpoints):
+ if len(endpoints) == 0:
+ raise errors.JujuConnectionError('no endpoints to connect to')
+
+ async def _try_endpoint(endpoint, cacert, delay):
+ if delay:
+ await asyncio.sleep(delay)
+ return await self._open(endpoint, cacert)
+
+ # Try all endpoints in parallel, with slight increasing delay (+100ms
+ # for each subsequent endpoint); the delay allows us to prefer the
+ # earlier endpoints over the latter. Use first successful connection.
+ tasks = [self.loop.create_task(_try_endpoint(endpoint, cacert,
+ 0.1 * i))
+ for i, (endpoint, cacert) in enumerate(endpoints)]
+ for attempt in range(self._retries + 1):
+ for task in asyncio.as_completed(tasks, loop=self.loop):
+ try:
+ result = await task
+ break
+ except ConnectionError:
+ continue # ignore; try another endpoint
+ else:
+ _endpoints_str = ', '.join([endpoint
+ for endpoint, cacert in endpoints])
+ if attempt < self._retries:
+ log.debug('Retrying connection to endpoints: {}; '
+ 'attempt {} of {}'.format(_endpoints_str,
+ attempt + 1,
+ self._retries + 1))
+ await asyncio.sleep((attempt + 1) * self._retry_backoff)
+ continue
+ else:
+ raise errors.JujuConnectionError(
+ 'Unable to connect to any endpoint: '
+ '{}'.format(_endpoints_str))
+ # only executed if inner loop's else did not continue
+ # (i.e., inner loop did break due to successful connection)
+ break
+ for task in tasks:
+ task.cancel()
+ self.ws = result[0]
+ self.addr = result[1]
+ self.endpoint = result[2]
+ self.cacert = result[3]
+ self._receiver_task.start()
+ log.debug("Driver connected to juju %s", self.addr)
+ self.monitor.close_called.clear()
+
+ async def _connect_with_login(self, endpoints):
+ """Connect to the websocket.
+
+ If uuid is None, the connection will be to the controller. Otherwise it
+ will be to the model.
+ :return: The response field of login response JSON object.
+ """
+ success = False
+ try:
+ await self._connect(endpoints)
+ # It's possible that we may get several discharge-required errors,
+ # corresponding to different levels of authentication, so retry
+ # a few times.
+ for i in range(0, 2):
+ result = (await self.login())['response']
+ macaroonJSON = result.get('discharge-required')
+ if macaroonJSON is None:
+ self.info = result
+ success = True
+ return result
+ macaroon = bakery.Macaroon.from_dict(macaroonJSON)
+ self.bakery_client.handle_error(
+ httpbakery.Error(
+ code=httpbakery.ERR_DISCHARGE_REQUIRED,
+ message=result.get('discharge-required-error'),
+ version=macaroon.version,
+ info=httpbakery.ErrorInfo(
+ macaroon=macaroon,
+ macaroon_path=result.get('macaroon-path'),
+ ),
+ ),
+ # note: remove the port number.
+ 'https://' + self.endpoint + '/',
+ )
+ raise errors.JujuAuthError('failed to authenticate '
+ 'after several attempts')
+ finally:
+ if not success:
+ await self.close()
+
+ async def _connect_with_redirect(self, endpoints):
+ try:
+ login_result = await self._connect_with_login(endpoints)
+ except errors.JujuRedirectException as e:
+ login_result = await self._connect_with_login(e.endpoints)
+ self._build_facades(login_result.get('facades', {}))
+ self._pinger_task.start()
+
+ def _build_facades(self, facades):
+ self.facades.clear()
+ for facade in facades:
+ self.facades[facade['name']] = facade['versions'][-1]
+
+ async def login(self):
+ params = {}
+ params['auth-tag'] = self.usertag
+ if self.password:
+ params['credentials'] = self.password
+ else:
+ macaroons = _macaroons_for_domain(self.bakery_client.cookies,
+ self.endpoint)
+ params['macaroons'] = [[bakery.macaroon_to_dict(m) for m in ms]
+ for ms in macaroons]
+
+ try:
+ return await self.rpc({
+ "type": "Admin",
+ "request": "Login",
+ "version": 3,
+ "params": params,
+ })
+ except errors.JujuAPIError as e:
+ if e.error_code != 'redirection required':
+ raise
+ log.info('Controller requested redirect')
+ # Fetch additional redirection information now so that
+ # we can safely close the connection after login
+ # fails.
+ redirect_info = (await self.rpc({
+ "type": "Admin",
+ "request": "RedirectInfo",
+ "version": 3,
+ }))['response']
+ raise errors.JujuRedirectException(redirect_info) from e
+
+
+class _Task:
+ def __init__(self, task, loop):
+ self.stopped = asyncio.Event(loop=loop)
+ self.stopped.set()
+ self.task = task
+ self.loop = loop
+
+ def start(self):
+ async def run():
+ try:
+ return await self.task()
+ finally:
+ self.stopped.set()
+ self.stopped.clear()
+ self.loop.create_task(run())
+
+
+def _macaroons_for_domain(cookies, domain):
+ '''Return any macaroons from the given cookie jar that
+ apply to the given domain name.'''
+ req = urllib.request.Request('https://' + domain + '/')
+ cookies.add_cookie_header(req)
+ return httpbakery.extract_macaroons(req)