X-Git-Url: https://osm.etsi.org/gitweb/?p=osm%2FN2VC.git;a=blobdiff_plain;f=modules%2Flibjuju%2Fjuju%2Fclient%2Fconnection.py;fp=modules%2Flibjuju%2Fjuju%2Fclient%2Fconnection.py;h=7457391877e481c1c237e3bdc8c68a8e6cdadd65;hp=0000000000000000000000000000000000000000;hb=68858c1915122c2dbc8999a5cd3229694abf5f3a;hpb=032a71b2a6692b8b4e30f629a1f906d246f06736 diff --git a/modules/libjuju/juju/client/connection.py b/modules/libjuju/juju/client/connection.py new file mode 100644 index 0000000..7457391 --- /dev/null +++ b/modules/libjuju/juju/client/connection.py @@ -0,0 +1,623 @@ +import base64 +import io +import json +import logging +import os +import random +import shlex +import ssl +import string +import subprocess +import weakref +import websockets +from concurrent.futures import CancelledError +from http.client import HTTPSConnection +from pathlib import Path + +import asyncio +import yaml + +from juju import tag, utils +from juju.client import client +from juju.errors import JujuError, JujuAPIError, JujuConnectionError +from juju.utils import IdQueue + +log = logging.getLogger("websocket") + + +class Monitor: + """ + Monitor helper class for our Connection class. + + Contains a reference to an instantiated Connection, along with a + reference to the Connection.receiver Future. Upon inspecttion of + these objects, this class determines whether the connection is in + an 'error', 'connected' or 'disconnected' state. + + Use this class to stay up to date on the health of a connection, + and take appropriate action if the connection errors out due to + network issues or other unexpected circumstances. + + """ + ERROR = 'error' + CONNECTED = 'connected' + DISCONNECTING = 'disconnecting' + DISCONNECTED = 'disconnected' + + def __init__(self, connection): + self.connection = weakref.ref(connection) + self.reconnecting = asyncio.Lock(loop=connection.loop) + self.close_called = asyncio.Event(loop=connection.loop) + self.receiver_stopped = asyncio.Event(loop=connection.loop) + self.pinger_stopped = asyncio.Event(loop=connection.loop) + self.receiver_stopped.set() + self.pinger_stopped.set() + + @property + def status(self): + """ + Determine the status of the connection and receiver, and return + ERROR, CONNECTED, or DISCONNECTED as appropriate. + + For simplicity, we only consider ourselves to be connected + after the Connection class has setup a receiver task. This + only happens after the websocket is open, and the connection + isn't usable until that receiver has been started. + + """ + connection = self.connection() + + # the connection instance was destroyed but someone kept + # a separate reference to the monitor for some reason + if not connection: + return self.DISCONNECTED + + # connection cleanly disconnected or not yet opened + if not connection.ws: + return self.DISCONNECTED + + # close called but not yet complete + if self.close_called.is_set(): + return self.DISCONNECTING + + # connection closed uncleanly (we didn't call connection.close) + if self.receiver_stopped.is_set() or not connection.ws.open: + return self.ERROR + + # everything is fine! + return self.CONNECTED + + +class Connection: + """ + Usage:: + + # Connect to an arbitrary api server + client = await Connection.connect( + api_endpoint, model_uuid, username, password, cacert) + + # Connect using a controller/model name + client = await Connection.connect_model('local.local:default') + + # Connect to the currently active model + client = await Connection.connect_current() + + Note: Any connection method or constructor can accept an optional `loop` + argument to override the default event loop from `asyncio.get_event_loop`. + """ + + DEFAULT_FRAME_SIZE = 'default_frame_size' + MAX_FRAME_SIZE = 2**22 + "Maximum size for a single frame. Defaults to 4MB." + + def __init__( + self, endpoint, uuid, username, password, cacert=None, + macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE): + self.endpoint = endpoint + self._endpoint = endpoint + self.uuid = uuid + if macaroons: + self.macaroons = macaroons + self.username = '' + self.password = '' + else: + self.macaroons = [] + self.username = username + self.password = password + self.cacert = cacert + self._cacert = cacert + self.loop = loop or asyncio.get_event_loop() + + self.__request_id__ = 0 + self.addr = None + self.ws = None + self.facades = {} + self.messages = IdQueue(loop=self.loop) + self.monitor = Monitor(connection=self) + if max_frame_size is self.DEFAULT_FRAME_SIZE: + max_frame_size = self.MAX_FRAME_SIZE + self.max_frame_size = max_frame_size + + @property + def is_open(self): + return self.monitor.status == Monitor.CONNECTED + + def _get_ssl(self, cert=None): + return ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert) + + async def open(self): + if self.uuid: + url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid) + else: + url = "wss://{}/api".format(self.endpoint) + + kw = dict() + kw['ssl'] = self._get_ssl(self.cacert) + kw['loop'] = self.loop + kw['max_size'] = self.max_frame_size + self.addr = url + self.ws = await websockets.connect(url, **kw) + self.loop.create_task(self.receiver()) + self.monitor.receiver_stopped.clear() + log.info("Driver connected to juju %s", url) + self.monitor.close_called.clear() + return self + + async def close(self): + if not self.ws: + return + self.monitor.close_called.set() + await self.monitor.pinger_stopped.wait() + await self.monitor.receiver_stopped.wait() + await self.ws.close() + self.ws = None + + async def recv(self, request_id): + if not self.is_open: + raise websockets.exceptions.ConnectionClosed(0, 'websocket closed') + return await self.messages.get(request_id) + + async def receiver(self): + try: + while self.is_open: + result = await utils.run_with_interrupt( + self.ws.recv(), + self.monitor.close_called, + loop=self.loop) + if self.monitor.close_called.is_set(): + break + if result is not None: + result = json.loads(result) + await self.messages.put(result['request-id'], result) + except CancelledError: + pass + except websockets.ConnectionClosed as e: + log.warning('Receiver: Connection closed, reconnecting') + await self.messages.put_all(e) + # the reconnect has to be done as a task because the receiver will + # be cancelled by the reconnect and we don't want the reconnect + # to be aborted half-way through + self.loop.create_task(self.reconnect()) + return + except Exception as e: + log.exception("Error in receiver") + # make pending listeners aware of the error + await self.messages.put_all(e) + raise + finally: + self.monitor.receiver_stopped.set() + + async def pinger(self): + ''' + A Controller can time us out if we are silent for too long. This + is especially true in JaaS, which has a fairly strict timeout. + + To prevent timing out, we send a ping every ten seconds. + + ''' + async def _do_ping(): + try: + await pinger_facade.Ping() + await asyncio.sleep(10, loop=self.loop) + except CancelledError: + pass + + pinger_facade = client.PingerFacade.from_connection(self) + try: + while True: + await utils.run_with_interrupt( + _do_ping(), + self.monitor.close_called, + loop=self.loop) + if self.monitor.close_called.is_set(): + break + finally: + self.monitor.pinger_stopped.set() + return + + async def rpc(self, msg, encoder=None): + self.__request_id__ += 1 + msg['request-id'] = self.__request_id__ + if'params' not in msg: + msg['params'] = {} + if "version" not in msg: + msg['version'] = self.facades[msg['type']] + outgoing = json.dumps(msg, indent=2, cls=encoder) + for attempt in range(3): + try: + await self.ws.send(outgoing) + break + except websockets.ConnectionClosed: + if attempt == 2: + raise + log.warning('RPC: Connection closed, reconnecting') + # the reconnect has to be done in a separate task because, + # if it is triggered by the pinger, then this RPC call will + # be cancelled when the pinger is cancelled by the reconnect, + # and we don't want the reconnect to be aborted halfway through + await asyncio.wait([self.reconnect()], loop=self.loop) + result = await self.recv(msg['request-id']) + + if not result: + return result + + if 'error' in result: + # API Error Response + raise JujuAPIError(result) + + if 'response' not in result: + # This may never happen + return result + + if 'results' in result['response']: + # Check for errors in a result list. + errors = [] + for res in result['response']['results']: + if res.get('error', {}).get('message'): + errors.append(res['error']['message']) + if errors: + raise JujuError(errors) + + elif result['response'].get('error', {}).get('message'): + raise JujuError(result['response']['error']['message']) + + return result + + def http_headers(self): + """Return dictionary of http headers necessary for making an http + connection to the endpoint of this Connection. + + :return: Dictionary of headers + + """ + if not self.username: + return {} + + creds = u'{}:{}'.format( + tag.user(self.username), + self.password or '' + ) + token = base64.b64encode(creds.encode()) + return { + 'Authorization': 'Basic {}'.format(token.decode()) + } + + def https_connection(self): + """Return an https connection to this Connection's endpoint. + + Returns a 3-tuple containing:: + + 1. The :class:`HTTPSConnection` instance + 2. Dictionary of auth headers to be used with the connection + 3. The root url path (str) to be used for requests. + + """ + endpoint = self.endpoint + host, remainder = endpoint.split(':', 1) + port = remainder + if '/' in remainder: + port, _ = remainder.split('/', 1) + + conn = HTTPSConnection( + host, int(port), + context=self._get_ssl(self.cacert), + ) + + path = ( + "/model/{}".format(self.uuid) + if self.uuid else "" + ) + return conn, self.http_headers(), path + + async def clone(self): + """Return a new Connection, connected to the same websocket endpoint + as this one. + + """ + return await Connection.connect( + self.endpoint, + self.uuid, + self.username, + self.password, + self.cacert, + self.macaroons, + self.loop, + self.max_frame_size, + ) + + async def controller(self): + """Return a Connection to the controller at self.endpoint + + """ + return await Connection.connect( + self.endpoint, + None, + self.username, + self.password, + self.cacert, + self.macaroons, + self.loop, + ) + + async def _try_endpoint(self, endpoint, cacert): + success = False + result = None + new_endpoints = [] + + self.endpoint = endpoint + self.cacert = cacert + await self.open() + try: + result = await self.login() + if 'discharge-required-error' in result['response']: + log.info('Macaroon discharge required, disconnecting') + else: + # successful login! + log.info('Authenticated') + success = True + except JujuAPIError as e: + if e.error_code != 'redirection required': + raise + log.info('Controller requested redirect') + redirect_info = await self.redirect_info() + redir_cacert = redirect_info['ca-cert'] + new_endpoints = [ + ("{value}:{port}".format(**s), redir_cacert) + for servers in redirect_info['servers'] + for s in servers if s["scope"] == 'public' + ] + finally: + if not success: + await self.close() + return success, result, new_endpoints + + async def reconnect(self): + """ Force a reconnection. + """ + monitor = self.monitor + if monitor.reconnecting.locked() or monitor.close_called.is_set(): + return + async with monitor.reconnecting: + await self.close() + await self._connect() + + async def _connect(self): + endpoints = [(self._endpoint, self._cacert)] + while endpoints: + _endpoint, _cacert = endpoints.pop(0) + success, result, new_endpoints = await self._try_endpoint( + _endpoint, _cacert) + if success: + break + endpoints.extend(new_endpoints) + else: + # ran out of endpoints without a successful login + raise Exception("Couldn't authenticate to {}".format( + self._endpoint)) + + response = result['response'] + self.info = response.copy() + self.build_facades(response.get('facades', {})) + self.loop.create_task(self.pinger()) + self.monitor.pinger_stopped.clear() + + @classmethod + async def connect( + cls, endpoint, uuid, username, password, cacert=None, + macaroons=None, loop=None, max_frame_size=None): + """Connect to the websocket. + + If uuid is None, the connection will be to the controller. Otherwise it + will be to the model. + + """ + client = cls(endpoint, uuid, username, password, cacert, macaroons, + loop, max_frame_size) + await client._connect() + return client + + @classmethod + async def connect_current(cls, loop=None, max_frame_size=None): + """Connect to the currently active model. + + """ + jujudata = JujuData() + + controller_name = jujudata.current_controller() + if not controller_name: + raise JujuConnectionError('No current controller') + + model_name = jujudata.current_model() + + return await cls.connect_model( + '{}:{}'.format(controller_name, model_name), loop, max_frame_size) + + @classmethod + async def connect_current_controller(cls, loop=None, max_frame_size=None): + """Connect to the currently active controller. + + """ + jujudata = JujuData() + controller_name = jujudata.current_controller() + if not controller_name: + raise JujuConnectionError('No current controller') + + return await cls.connect_controller(controller_name, loop, + max_frame_size) + + @classmethod + async def connect_controller(cls, controller_name, loop=None, + max_frame_size=None): + """Connect to a controller by name. + + """ + jujudata = JujuData() + controller = jujudata.controllers()[controller_name] + endpoint = controller['api-endpoints'][0] + cacert = controller.get('ca-cert') + accounts = jujudata.accounts()[controller_name] + username = accounts['user'] + password = accounts.get('password') + macaroons = get_macaroons(controller_name) if not password else None + + return await cls.connect( + endpoint, None, username, password, cacert, macaroons, loop, + max_frame_size) + + @classmethod + async def connect_model(cls, model, loop=None, max_frame_size=None): + """Connect to a model by name. + + :param str model: [:] + + """ + jujudata = JujuData() + + if ':' in model: + # explicit controller given + controller_name, model_name = model.split(':') + else: + # use the current controller if one isn't explicitly given + controller_name = jujudata.current_controller() + model_name = model + + accounts = jujudata.accounts()[controller_name] + username = accounts['user'] + # model name must include a user prefix, so add it if it doesn't + if '/' not in model_name: + model_name = '{}/{}'.format(username, model_name) + + controller = jujudata.controllers()[controller_name] + endpoint = controller['api-endpoints'][0] + cacert = controller.get('ca-cert') + password = accounts.get('password') + models = jujudata.models()[controller_name] + model_uuid = models['models'][model_name]['uuid'] + macaroons = get_macaroons(controller_name) if not password else None + + return await cls.connect( + endpoint, model_uuid, username, password, cacert, macaroons, loop, + max_frame_size) + + def build_facades(self, facades): + self.facades.clear() + for facade in facades: + self.facades[facade['name']] = facade['versions'][-1] + + async def login(self): + username = self.username + if username and not username.startswith('user-'): + username = 'user-{}'.format(username) + + result = await self.rpc({ + "type": "Admin", + "request": "Login", + "version": 3, + "params": { + "auth-tag": username, + "credentials": self.password, + "nonce": "".join(random.sample(string.printable, 12)), + "macaroons": self.macaroons + }}) + return result + + async def redirect_info(self): + try: + result = await self.rpc({ + "type": "Admin", + "request": "RedirectInfo", + "version": 3, + }) + except JujuAPIError as e: + if e.message == 'not redirected': + return None + raise + return result['response'] + + +class JujuData: + def __init__(self): + self.path = os.environ.get('JUJU_DATA') or '~/.local/share/juju' + self.path = os.path.abspath(os.path.expanduser(self.path)) + + def current_controller(self): + cmd = shlex.split('juju list-controllers --format yaml') + output = subprocess.check_output(cmd) + output = yaml.safe_load(output) + return output.get('current-controller', '') + + def current_model(self, controller_name=None): + if not controller_name: + controller_name = self.current_controller() + models = self.models()[controller_name] + if 'current-model' not in models: + raise JujuError('No current model') + return models['current-model'] + + def controllers(self): + return self._load_yaml('controllers.yaml', 'controllers') + + def models(self): + return self._load_yaml('models.yaml', 'controllers') + + def accounts(self): + return self._load_yaml('accounts.yaml', 'controllers') + + def _load_yaml(self, filename, key): + filepath = os.path.join(self.path, filename) + with io.open(filepath, 'rt') as f: + return yaml.safe_load(f)[key] + + +def get_macaroons(controller_name=None): + """Decode and return macaroons from default ~/.go-cookies + + """ + cookie_files = [] + if controller_name: + cookie_files.append('~/.local/share/juju/cookies/{}.json'.format( + controller_name)) + cookie_files.append('~/.go-cookies') + for cookie_file in cookie_files: + cookie_file = Path(cookie_file).expanduser() + if cookie_file.exists(): + try: + cookies = json.loads(cookie_file.read_text()) + break + except (OSError, ValueError): + log.warn("Couldn't load macaroons from %s", cookie_file) + return [] + else: + log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files)) + return [] + + base64_macaroons = [ + c['Value'] for c in cookies + if c['Name'].startswith('macaroon-') and c['Value'] + ] + + return [ + json.loads(base64.b64decode(value).decode('utf-8')) + for value in base64_macaroons + ]