X-Git-Url: https://osm.etsi.org/gitweb/?p=osm%2FN2VC.git;a=blobdiff_plain;f=modules%2Flibjuju%2Fjuju%2Fclient%2Fconnection.py;h=13770a5343d841bbb8369034c03499eaf7d8bc5e;hp=c09468c65eef6f1d265fc9107f99ebd1a05bfa8b;hb=29ad6453fb8cdece73b8c2f623cf81d5d730982d;hpb=1a15d1c84fc826fa7996c1c9d221a324edd33432 diff --git a/modules/libjuju/juju/client/connection.py b/modules/libjuju/juju/client/connection.py index c09468c..13770a5 100644 --- a/modules/libjuju/juju/client/connection.py +++ b/modules/libjuju/juju/client/connection.py @@ -1,28 +1,21 @@ +import asyncio import base64 -import io import json import logging -import os -import random -import shlex import ssl -import string -import subprocess +import urllib.request import weakref -import websockets from concurrent.futures import CancelledError from http.client import HTTPSConnection -from pathlib import Path - -import asyncio -import yaml -from juju import tag, utils +import macaroonbakery.httpbakery as httpbakery +import macaroonbakery.bakery as bakery +import websockets +from juju import errors, tag, utils from juju.client import client -from juju.errors import JujuError, JujuAPIError, JujuConnectionError from juju.utils import IdQueue -log = logging.getLogger("websocket") +log = logging.getLogger('juju.client.connection') class Monitor: @@ -30,7 +23,7 @@ class Monitor: Monitor helper class for our Connection class. Contains a reference to an instantiated Connection, along with a - reference to the Connection.receiver Future. Upon inspecttion of + reference to the Connection.receiver Future. Upon inspection of these objects, this class determines whether the connection is in an 'error', 'connected' or 'disconnected' state. @@ -48,10 +41,6 @@ class Monitor: self.connection = weakref.ref(connection) self.reconnecting = asyncio.Lock(loop=connection.loop) self.close_called = asyncio.Event(loop=connection.loop) - self.receiver_stopped = asyncio.Event(loop=connection.loop) - self.pinger_stopped = asyncio.Event(loop=connection.loop) - self.receiver_stopped.set() - self.pinger_stopped.set() @property def status(self): @@ -81,7 +70,8 @@ class Monitor: return self.DISCONNECTING # connection closed uncleanly (we didn't call connection.close) - if self.receiver_stopped.is_set() or not connection.ws.open: + stopped = connection._receiver_task.stopped.is_set() + if stopped or not connection.ws.open: return self.ERROR # everything is fine! @@ -96,47 +86,93 @@ class Connection: client = await Connection.connect( api_endpoint, model_uuid, username, password, cacert) - # Connect using a controller/model name - client = await Connection.connect_model('local.local:default') - - # Connect to the currently active model - client = await Connection.connect_current() - Note: Any connection method or constructor can accept an optional `loop` argument to override the default event loop from `asyncio.get_event_loop`. """ - DEFAULT_FRAME_SIZE = 'default_frame_size' MAX_FRAME_SIZE = 2**22 "Maximum size for a single frame. Defaults to 4MB." - def __init__( - self, endpoint, uuid, username, password, cacert=None, - macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE): - self.endpoint = endpoint - self._endpoint = endpoint + @classmethod + async def connect( + cls, + endpoint=None, + uuid=None, + username=None, + password=None, + cacert=None, + bakery_client=None, + loop=None, + max_frame_size=None, + ): + """Connect to the websocket. + + If uuid is None, the connection will be to the controller. Otherwise it + will be to the model. + + :param str endpoint: The hostname:port of the controller to connect to. + :param str uuid: The model UUID to connect to (None for a + controller-only connection). + :param str username: The username for controller-local users (or None + to use macaroon-based login.) + :param str password: The password for controller-local users. + :param str cacert: The CA certificate of the controller + (PEM formatted). + :param httpbakery.Client bakery_client: The macaroon bakery client to + to use when performing macaroon-based login. Macaroon tokens + acquired when logging will be saved to bakery_client.cookies. + If this is None, a default bakery_client will be used. + :param asyncio.BaseEventLoop loop: The event loop to use for async + operations. + :param int max_frame_size: The maximum websocket frame size to allow. + """ + self = cls() + if endpoint is None: + raise ValueError('no endpoint provided') self.uuid = uuid - if macaroons: - self.macaroons = macaroons - self.username = '' - self.password = '' - else: - self.macaroons = [] - self.username = username - self.password = password - self.cacert = cacert - self._cacert = cacert + if bakery_client is None: + bakery_client = httpbakery.Client() + self.bakery_client = bakery_client + if username and '@' in username and not username.endswith('@local'): + # We're trying to log in as an external user - we need to use + # macaroon authentication with no username or password. + if password is not None: + raise errors.JujuAuthError('cannot log in as external ' + 'user with a password') + username = None + self.usertag = tag.user(username) + self.password = password self.loop = loop or asyncio.get_event_loop() self.__request_id__ = 0 + + # The following instance variables are initialized by the + # _connect_with_redirect method, but create them here + # as a reminder that they will exist. self.addr = None self.ws = None + self.endpoint = None + self.cacert = None + self.info = None + + # Create that _Task objects but don't start the tasks yet. + self._pinger_task = _Task(self._pinger, self.loop) + self._receiver_task = _Task(self._receiver, self.loop) + self.facades = {} self.messages = IdQueue(loop=self.loop) self.monitor = Monitor(connection=self) - if max_frame_size is self.DEFAULT_FRAME_SIZE: + if max_frame_size is None: max_frame_size = self.MAX_FRAME_SIZE self.max_frame_size = max_frame_size + await self._connect_with_redirect([(endpoint, cacert)]) + return self + + @property + def username(self): + if not self.usertag: + return None + return self.usertag[len('user-'):] @property def is_open(self): @@ -146,39 +182,34 @@ class Connection: return ssl.create_default_context( purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert) - async def open(self): + async def _open(self, endpoint, cacert): if self.uuid: - url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid) + url = "wss://{}/model/{}/api".format(endpoint, self.uuid) else: - url = "wss://{}/api".format(self.endpoint) - - kw = dict() - kw['ssl'] = self._get_ssl(self.cacert) - kw['loop'] = self.loop - kw['max_size'] = self.max_frame_size - self.addr = url - self.ws = await websockets.connect(url, **kw) - self.loop.create_task(self.receiver()) - self.monitor.receiver_stopped.clear() - log.info("Driver connected to juju %s", url) - self.monitor.close_called.clear() - return self + url = "wss://{}/api".format(endpoint) + + return (await websockets.connect( + url, + ssl=self._get_ssl(cacert), + loop=self.loop, + max_size=self.max_frame_size, + ), url, endpoint, cacert) async def close(self): if not self.ws: return self.monitor.close_called.set() - await self.monitor.pinger_stopped.wait() - await self.monitor.receiver_stopped.wait() + await self._pinger_task.stopped.wait() + await self._receiver_task.stopped.wait() await self.ws.close() self.ws = None - async def recv(self, request_id): + async def _recv(self, request_id): if not self.is_open: raise websockets.exceptions.ConnectionClosed(0, 'websocket closed') return await self.messages.get(request_id) - async def receiver(self): + async def _receiver(self): try: while self.is_open: result = await utils.run_with_interrupt( @@ -205,10 +236,8 @@ class Connection: # make pending listeners aware of the error await self.messages.put_all(e) raise - finally: - self.monitor.receiver_stopped.set() - async def pinger(self): + async def _pinger(self): ''' A Controller can time us out if we are silent for too long. This is especially true in JaaS, which has a fairly strict timeout. @@ -232,11 +261,21 @@ class Connection: loop=self.loop) if self.monitor.close_called.is_set(): break - finally: - self.monitor.pinger_stopped.set() - return + except websockets.exceptions.ConnectionClosed: + # The connection has closed - we can't do anything + # more until the connection is restarted. + log.debug('ping failed because of closed connection') + pass async def rpc(self, msg, encoder=None): + '''Make an RPC to the API. The message is encoded as JSON + using the given encoder if any. + :param msg: Parameters for the call (will be encoded as JSON). + :param encoder: Encoder to be used when encoding the message. + :return: The result of the call. + :raises JujuAPIError: When there's an error returned. + :raises JujuError: + ''' self.__request_id__ += 1 msg['request-id'] = self.__request_id__ if'params' not in msg: @@ -244,7 +283,12 @@ class Connection: if "version" not in msg: msg['version'] = self.facades[msg['type']] outgoing = json.dumps(msg, indent=2, cls=encoder) + log.debug('connection {} -> {}'.format(id(self), outgoing)) for attempt in range(3): + if self.monitor.status == Monitor.DISCONNECTED: + # closed cleanly; shouldn't try to reconnect + raise websockets.exceptions.ConnectionClosed( + 0, 'websocket closed') try: await self.ws.send(outgoing) break @@ -257,14 +301,19 @@ class Connection: # be cancelled when the pinger is cancelled by the reconnect, # and we don't want the reconnect to be aborted halfway through await asyncio.wait([self.reconnect()], loop=self.loop) - result = await self.recv(msg['request-id']) + if self.monitor.status != Monitor.CONNECTED: + # reconnect failed; abort and shutdown + log.error('RPC: Automatic reconnect failed') + raise + result = await self._recv(msg['request-id']) + log.debug('connection {} <- {}'.format(id(self), result)) if not result: return result if 'error' in result: # API Error Response - raise JujuAPIError(result) + raise errors.JujuAPIError(result) if 'response' not in result: # This may never happen @@ -272,30 +321,34 @@ class Connection: if 'results' in result['response']: # Check for errors in a result list. - errors = [] + # TODO This loses the results that might have succeeded. + # Perhaps JujuError should return all the results including + # errors, or perhaps a keyword parameter to the rpc method + # could be added to trigger this behaviour. + err_results = [] for res in result['response']['results']: if res.get('error', {}).get('message'): - errors.append(res['error']['message']) - if errors: - raise JujuError(errors) + err_results.append(res['error']['message']) + if err_results: + raise errors.JujuError(err_results) elif result['response'].get('error', {}).get('message'): - raise JujuError(result['response']['error']['message']) + raise errors.JujuError(result['response']['error']['message']) return result - def http_headers(self): + def _http_headers(self): """Return dictionary of http headers necessary for making an http connection to the endpoint of this Connection. :return: Dictionary of headers """ - if not self.username: + if not self.usertag: return {} creds = u'{}:{}'.format( - tag.user(self.username), + self.usertag, self.password or '' ) token = base64.b64encode(creds.encode()) @@ -328,70 +381,46 @@ class Connection: "/model/{}".format(self.uuid) if self.uuid else "" ) - return conn, self.http_headers(), path + return conn, self._http_headers(), path async def clone(self): """Return a new Connection, connected to the same websocket endpoint as this one. """ - return await Connection.connect( - self.endpoint, - self.uuid, - self.username, - self.password, - self.cacert, - self.macaroons, - self.loop, - self.max_frame_size, - ) + return await Connection.connect(**self.connect_params()) + + def connect_params(self): + """Return a tuple of parameters suitable for passing to + Connection.connect that can be used to make a new connection + to the same controller (and model if specified. The first + element in the returned tuple holds the endpoint argument; + the other holds a dict of the keyword args. + """ + return { + 'endpoint': self.endpoint, + 'uuid': self.uuid, + 'username': self.username, + 'password': self.password, + 'cacert': self.cacert, + 'bakery_client': self.bakery_client, + 'loop': self.loop, + 'max_frame_size': self.max_frame_size, + } async def controller(self): """Return a Connection to the controller at self.endpoint - """ return await Connection.connect( self.endpoint, - None, - self.username, - self.password, - self.cacert, - self.macaroons, - self.loop, + username=self.username, + password=self.password, + cacert=self.cacert, + bakery_client=self.bakery_client, + loop=self.loop, + max_frame_size=self.max_frame_size, ) - async def _try_endpoint(self, endpoint, cacert): - success = False - result = None - new_endpoints = [] - - self.endpoint = endpoint - self.cacert = cacert - await self.open() - try: - result = await self.login() - if 'discharge-required-error' in result['response']: - log.info('Macaroon discharge required, disconnecting') - else: - # successful login! - log.info('Authenticated') - success = True - except JujuAPIError as e: - if e.error_code != 'redirection required': - raise - log.info('Controller requested redirect') - redirect_info = await self.redirect_info() - redir_cacert = redirect_info['ca-cert'] - new_endpoints = [ - ("{value}:{port}".format(**s), redir_cacert) - for servers in redirect_info['servers'] - for s in servers if s["scope"] == 'public' - ] - finally: - if not success: - await self.close() - return success, result, new_endpoints - async def reconnect(self): """ Force a reconnection. """ @@ -400,256 +429,149 @@ class Connection: return async with monitor.reconnecting: await self.close() - await self._connect() - - async def _connect(self): - endpoints = [(self._endpoint, self._cacert)] - while endpoints: - _endpoint, _cacert = endpoints.pop(0) - success, result, new_endpoints = await self._try_endpoint( - _endpoint, _cacert) - if success: + await self._connect_with_login([(self.endpoint, self.cacert)]) + + async def _connect(self, endpoints): + if len(endpoints) == 0: + raise errors.JujuConnectionError('no endpoints to connect to') + + async def _try_endpoint(endpoint, cacert, delay): + if delay: + await asyncio.sleep(delay) + return await self._open(endpoint, cacert) + + # Try all endpoints in parallel, with slight increasing delay (+100ms + # for each subsequent endpoint); the delay allows us to prefer the + # earlier endpoints over the latter. Use first successful connection. + tasks = [self.loop.create_task(_try_endpoint(endpoint, cacert, + 0.1 * i)) + for i, (endpoint, cacert) in enumerate(endpoints)] + for task in asyncio.as_completed(tasks, loop=self.loop): + try: + result = await task break - endpoints.extend(new_endpoints) + except ConnectionError: + continue # ignore; try another endpoint else: - # ran out of endpoints without a successful login - raise JujuConnectionError("Couldn't authenticate to {}".format( - self._endpoint)) - - response = result['response'] - self.info = response.copy() - self.build_facades(response.get('facades', {})) - self.loop.create_task(self.pinger()) - self.monitor.pinger_stopped.clear() + raise errors.JujuConnectionError( + 'Unable to connect to any endpoint: {}'.format(', '.join([ + endpoint for endpoint, cacert in endpoints]))) + for task in tasks: + task.cancel() + self.ws = result[0] + self.addr = result[1] + self.endpoint = result[2] + self.cacert = result[3] + self._receiver_task.start() + log.info("Driver connected to juju %s", self.addr) + self.monitor.close_called.clear() - @classmethod - async def connect( - cls, endpoint, uuid, username, password, cacert=None, - macaroons=None, loop=None, max_frame_size=None): + async def _connect_with_login(self, endpoints): """Connect to the websocket. If uuid is None, the connection will be to the controller. Otherwise it will be to the model. - - """ - client = cls(endpoint, uuid, username, password, cacert, macaroons, - loop, max_frame_size) - await client._connect() - return client - - @classmethod - async def connect_current(cls, loop=None, max_frame_size=None): - """Connect to the currently active model. - - """ - jujudata = JujuData() - - controller_name = jujudata.current_controller() - if not controller_name: - raise JujuConnectionError('No current controller') - - model_name = jujudata.current_model() - - return await cls.connect_model( - '{}:{}'.format(controller_name, model_name), loop, max_frame_size) - - @classmethod - async def connect_current_controller(cls, loop=None, max_frame_size=None): - """Connect to the currently active controller. - + :return: The response field of login response JSON object. """ - jujudata = JujuData() - controller_name = jujudata.current_controller() - if not controller_name: - raise JujuConnectionError('No current controller') - - return await cls.connect_controller(controller_name, loop, - max_frame_size) - - @classmethod - async def connect_controller(cls, controller_name, loop=None, - max_frame_size=None): - """Connect to a controller by name. - - """ - jujudata = JujuData() - controller = jujudata.controllers()[controller_name] - endpoint = controller['api-endpoints'][0] - cacert = controller.get('ca-cert') - accounts = jujudata.accounts()[controller_name] - username = accounts['user'] - password = accounts.get('password') - macaroons = get_macaroons(controller_name) if not password else None - - return await cls.connect( - endpoint, None, username, password, cacert, macaroons, loop, - max_frame_size) - - @classmethod - async def connect_model(cls, model, loop=None, max_frame_size=None): - """Connect to a model by name. - - :param str model: [:] + success = False + try: + await self._connect(endpoints) + # It's possible that we may get several discharge-required errors, + # corresponding to different levels of authentication, so retry + # a few times. + for i in range(0, 2): + result = (await self.login())['response'] + macaroonJSON = result.get('discharge-required') + if macaroonJSON is None: + self.info = result + success = True + return result + macaroon = bakery.Macaroon.from_dict(macaroonJSON) + self.bakery_client.handle_error( + httpbakery.Error( + code=httpbakery.ERR_DISCHARGE_REQUIRED, + message=result.get('discharge-required-error'), + version=macaroon.version, + info=httpbakery.ErrorInfo( + macaroon=macaroon, + macaroon_path=result.get('macaroon-path'), + ), + ), + # note: remove the port number. + 'https://' + self.endpoint + '/', + ) + raise errors.JujuAuthError('failed to authenticate ' + 'after several attempts') + finally: + if not success: + await self.close() - """ - jujudata = JujuData() + async def _connect_with_redirect(self, endpoints): + try: + login_result = await self._connect_with_login(endpoints) + except errors.JujuRedirectException as e: + login_result = await self._connect_with_login(e.endpoints) + self._build_facades(login_result.get('facades', {})) + self._pinger_task.start() - if ':' in model: - # explicit controller given - controller_name, model_name = model.split(':') - else: - # use the current controller if one isn't explicitly given - controller_name = jujudata.current_controller() - model_name = model - - accounts = jujudata.accounts()[controller_name] - username = accounts['user'] - # model name must include a user prefix, so add it if it doesn't - if '/' not in model_name: - model_name = '{}/{}'.format(username, model_name) - - controller = jujudata.controllers()[controller_name] - endpoint = controller['api-endpoints'][0] - cacert = controller.get('ca-cert') - password = accounts.get('password') - models = jujudata.models()[controller_name] - model_uuid = models['models'][model_name]['uuid'] - macaroons = get_macaroons(controller_name) if not password else None - - return await cls.connect( - endpoint, model_uuid, username, password, cacert, macaroons, loop, - max_frame_size) - - def build_facades(self, facades): + def _build_facades(self, facades): self.facades.clear() for facade in facades: self.facades[facade['name']] = facade['versions'][-1] async def login(self): - username = self.username - if username and not username.startswith('user-'): - username = 'user-{}'.format(username) - - result = await self.rpc({ - "type": "Admin", - "request": "Login", - "version": 3, - "params": { - "auth-tag": username, - "credentials": self.password, - "nonce": "".join(random.sample(string.printable, 12)), - "macaroons": self.macaroons - }}) - return result + params = {} + params['auth-tag'] = self.usertag + if self.password: + params['credentials'] = self.password + else: + macaroons = _macaroons_for_domain(self.bakery_client.cookies, + self.endpoint) + params['macaroons'] = [[bakery.macaroon_to_dict(m) for m in ms] + for ms in macaroons] - async def redirect_info(self): try: - result = await self.rpc({ + return await self.rpc({ "type": "Admin", - "request": "RedirectInfo", + "request": "Login", "version": 3, + "params": params, }) - except JujuAPIError as e: - if e.message == 'not redirected': - return None - raise - return result['response'] - - -class JujuData: - def __init__(self): - self.path = os.environ.get('JUJU_DATA') or '~/.local/share/juju' - self.path = os.path.abspath(os.path.expanduser(self.path)) - - def current_controller(self): - cmd = shlex.split('juju list-controllers --format yaml') - output = subprocess.check_output(cmd) - output = yaml.safe_load(output) - return output.get('current-controller', '') - - def current_model(self, controller_name=None): - if not controller_name: - controller_name = self.current_controller() - models = self.models()[controller_name] - if 'current-model' not in models: - raise JujuError('No current model') - return models['current-model'] - - def controllers(self): - return self._load_yaml('controllers.yaml', 'controllers') - - def models(self): - return self._load_yaml('models.yaml', 'controllers') - - def accounts(self): - return self._load_yaml('accounts.yaml', 'controllers') - - def credentials(self): - return self._load_yaml('credentials.yaml', 'credentials') + except errors.JujuAPIError as e: + if e.error_code != 'redirection required': + raise + log.info('Controller requested redirect') + # Fetch additional redirection information now so that + # we can safely close the connection after login + # fails. + redirect_info = (await self.rpc({ + "type": "Admin", + "request": "RedirectInfo", + "version": 3, + }))['response'] + raise errors.JujuRedirectException(redirect_info) from e - def load_credential(self, cloud, name=None): - """Load a local credential. - :param str cloud: Name of cloud to load credentials from. - :param str name: Name of credential. If None, the default credential - will be used, if available. - :returns: A CloudCredential instance, or None. - """ - try: - cloud = tag.untag('cloud-', cloud) - creds_data = self.credentials()[cloud] - if not name: - default_credential = creds_data.pop('default-credential', None) - default_region = creds_data.pop('default-region', None) # noqa - if default_credential: - name = creds_data['default-credential'] - elif len(creds_data) == 1: - name = list(creds_data)[0] - else: - return None, None - cred_data = creds_data[name] - auth_type = cred_data.pop('auth-type') - return name, client.CloudCredential( - auth_type=auth_type, - attrs=cred_data, - ) - except (KeyError, FileNotFoundError): - return None, None - - def _load_yaml(self, filename, key): - filepath = os.path.join(self.path, filename) - with io.open(filepath, 'rt') as f: - return yaml.safe_load(f)[key] - - -def get_macaroons(controller_name=None): - """Decode and return macaroons from default ~/.go-cookies +class _Task: + def __init__(self, task, loop): + self.stopped = asyncio.Event(loop=loop) + self.stopped.set() + self.task = task + self.loop = loop - """ - cookie_files = [] - if controller_name: - cookie_files.append('~/.local/share/juju/cookies/{}.json'.format( - controller_name)) - cookie_files.append('~/.go-cookies') - for cookie_file in cookie_files: - cookie_file = Path(cookie_file).expanduser() - if cookie_file.exists(): + def start(self): + async def run(): try: - cookies = json.loads(cookie_file.read_text()) - break - except (OSError, ValueError): - log.warn("Couldn't load macaroons from %s", cookie_file) - return [] - else: - log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files)) - return [] - - base64_macaroons = [ - c['Value'] for c in cookies - if c['Name'].startswith('macaroon-') and c['Value'] - ] - - return [ - json.loads(base64.b64decode(value).decode('utf-8')) - for value in base64_macaroons - ] + return await(self.task()) + finally: + self.stopped.set() + self.stopped.clear() + self.loop.create_task(run()) + + +def _macaroons_for_domain(cookies, domain): + '''Return any macaroons from the given cookie jar that + apply to the given domain name.''' + req = urllib.request.Request('https://' + domain + '/') + cookies.add_cookie_header(req) + return httpbakery.extract_macaroons(req)