X-Git-Url: https://osm.etsi.org/gitweb/?a=blobdiff_plain;f=juju%2Fclient%2Fconnection.py;h=7457391877e481c1c237e3bdc8c68a8e6cdadd65;hb=cd48f185bcb9279a96d3ee85579d96ac10d12dd9;hp=2be360fe0d9306cfb298ea93b696273f67042eff;hpb=7c2a530853c95b8a3518f6db0870f94858f87c27;p=osm%2FN2VC.git diff --git a/juju/client/connection.py b/juju/client/connection.py index 2be360f..7457391 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -8,15 +8,17 @@ import shlex import ssl import string import subprocess +import weakref import websockets +from concurrent.futures import CancelledError from http.client import HTTPSConnection +from pathlib import Path import asyncio import yaml -from juju import tag +from juju import tag, utils from juju.client import client -from juju.client.version_map import VERSION_MAP from juju.errors import JujuError, JujuAPIError, JujuConnectionError from juju.utils import IdQueue @@ -39,13 +41,17 @@ class Monitor: """ ERROR = 'error' CONNECTED = 'connected' + DISCONNECTING = 'disconnecting' DISCONNECTED = 'disconnected' - UNKNOWN = 'unknown' def __init__(self, connection): - self.connection = connection - self.receiver = None - self.pinger = None + self.connection = weakref.ref(connection) + self.reconnecting = asyncio.Lock(loop=connection.loop) + self.close_called = asyncio.Event(loop=connection.loop) + self.receiver_stopped = asyncio.Event(loop=connection.loop) + self.pinger_stopped = asyncio.Event(loop=connection.loop) + self.receiver_stopped.set() + self.pinger_stopped.set() @property def status(self): @@ -59,46 +65,27 @@ class Monitor: isn't usable until that receiver has been started. """ + connection = self.connection() - # DISCONNECTED: connection not yet open - if not self.connection.ws: + # the connection instance was destroyed but someone kept + # a separate reference to the monitor for some reason + if not connection: return self.DISCONNECTED - if not self.receiver: - return self.DISCONNECTED - - # ERROR: Connection closed (or errored), but we didn't call - # connection.close - if not self.connection.close_called and self.receiver_exceptions(): - return self.ERROR - if not self.connection.close_called and not self.connection.ws.open: - # The check for self.receiver existing above guards against the - # case where we're not open because we simply haven't - # setup the connection yet. - return self.ERROR - # DISCONNECTED: cleanly disconnected. - if self.connection.close_called and not self.connection.ws.open: + # connection cleanly disconnected or not yet opened + if not connection.ws: return self.DISCONNECTED - # CONNECTED: everything is fine! - if self.connection.ws.open: - return self.CONNECTED + # close called but not yet complete + if self.close_called.is_set(): + return self.DISCONNECTING - # UNKNOWN: We should never hit this state -- if we do, - # something went wrong with the logic above, and we do not - # know what state the connection is in. - return self.UNKNOWN - - def receiver_exceptions(self): - """ - Return exceptions in the receiver, if any. + # connection closed uncleanly (we didn't call connection.close) + if self.receiver_stopped.is_set() or not connection.ws.open: + return self.ERROR - """ - if not self.receiver: - return None - if not self.receiver.done(): - return None - return self.receiver.exception() + # everything is fine! + return self.CONNECTED class Connection: @@ -118,10 +105,16 @@ class Connection: Note: Any connection method or constructor can accept an optional `loop` argument to override the default event loop from `asyncio.get_event_loop`. """ + + DEFAULT_FRAME_SIZE = 'default_frame_size' + MAX_FRAME_SIZE = 2**22 + "Maximum size for a single frame. Defaults to 4MB." + def __init__( self, endpoint, uuid, username, password, cacert=None, - macaroons=None, loop=None): + macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE): self.endpoint = endpoint + self._endpoint = endpoint self.uuid = uuid if macaroons: self.macaroons = macaroons @@ -132,6 +125,7 @@ class Connection: self.username = username self.password = password self.cacert = cacert + self._cacert = cacert self.loop = loop or asyncio.get_event_loop() self.__request_id__ = 0 @@ -139,14 +133,14 @@ class Connection: self.ws = None self.facades = {} self.messages = IdQueue(loop=self.loop) - self.close_called = False self.monitor = Monitor(connection=self) + if max_frame_size is self.DEFAULT_FRAME_SIZE: + max_frame_size = self.MAX_FRAME_SIZE + self.max_frame_size = max_frame_size @property def is_open(self): - if self.ws: - return self.ws.open - return False + return self.monitor.status == Monitor.CONNECTED def _get_ssl(self, cert=None): return ssl.create_default_context( @@ -161,22 +155,23 @@ class Connection: kw = dict() kw['ssl'] = self._get_ssl(self.cacert) kw['loop'] = self.loop + kw['max_size'] = self.max_frame_size self.addr = url self.ws = await websockets.connect(url, **kw) - self.monitor.receiver = self.loop.create_task(self.receiver()) + self.loop.create_task(self.receiver()) + self.monitor.receiver_stopped.clear() log.info("Driver connected to juju %s", url) + self.monitor.close_called.clear() return self async def close(self): - if not self.is_open: + if not self.ws: return - self.close_called = True - if self.monitor.pinger: - # might be closing due to login failure, - # in which case we won't have a pinger yet - self.monitor.pinger.cancel() - self.monitor.receiver.cancel() + self.monitor.close_called.set() + await self.monitor.pinger_stopped.wait() + await self.monitor.receiver_stopped.wait() await self.ws.close() + self.ws = None async def recv(self, request_id): if not self.is_open: @@ -184,19 +179,34 @@ class Connection: return await self.messages.get(request_id) async def receiver(self): - while self.is_open: - try: - result = await self.ws.recv() + try: + while self.is_open: + result = await utils.run_with_interrupt( + self.ws.recv(), + self.monitor.close_called, + loop=self.loop) + if self.monitor.close_called.is_set(): + break if result is not None: result = json.loads(result) await self.messages.put(result['request-id'], result) - except Exception as e: - await self.messages.put_all(e) - if isinstance(e, websockets.ConnectionClosed): - # ConnectionClosed is not really exceptional for us, - # but it may be for any pending message listeners - return - raise + except CancelledError: + pass + except websockets.ConnectionClosed as e: + log.warning('Receiver: Connection closed, reconnecting') + await self.messages.put_all(e) + # the reconnect has to be done as a task because the receiver will + # be cancelled by the reconnect and we don't want the reconnect + # to be aborted half-way through + self.loop.create_task(self.reconnect()) + return + except Exception as e: + log.exception("Error in receiver") + # make pending listeners aware of the error + await self.messages.put_all(e) + raise + finally: + self.monitor.receiver_stopped.set() async def pinger(self): ''' @@ -206,10 +216,25 @@ class Connection: To prevent timing out, we send a ping every ten seconds. ''' + async def _do_ping(): + try: + await pinger_facade.Ping() + await asyncio.sleep(10, loop=self.loop) + except CancelledError: + pass + pinger_facade = client.PingerFacade.from_connection(self) - while self.is_open: - await pinger_facade.Ping() - await asyncio.sleep(10, loop=self.loop) + try: + while True: + await utils.run_with_interrupt( + _do_ping(), + self.monitor.close_called, + loop=self.loop) + if self.monitor.close_called.is_set(): + break + finally: + self.monitor.pinger_stopped.set() + return async def rpc(self, msg, encoder=None): self.__request_id__ += 1 @@ -219,7 +244,19 @@ class Connection: if "version" not in msg: msg['version'] = self.facades[msg['type']] outgoing = json.dumps(msg, indent=2, cls=encoder) - await self.ws.send(outgoing) + for attempt in range(3): + try: + await self.ws.send(outgoing) + break + except websockets.ConnectionClosed: + if attempt == 2: + raise + log.warning('RPC: Connection closed, reconnecting') + # the reconnect has to be done in a separate task because, + # if it is triggered by the pinger, then this RPC call will + # be cancelled when the pinger is cancelled by the reconnect, + # and we don't want the reconnect to be aborted halfway through + await asyncio.wait([self.reconnect()], loop=self.loop) result = await self.recv(msg['request-id']) if not result: @@ -306,6 +343,7 @@ class Connection: self.cacert, self.macaroons, self.loop, + self.max_frame_size, ) async def controller(self): @@ -354,51 +392,69 @@ class Connection: await self.close() return success, result, new_endpoints - @classmethod - async def connect( - cls, endpoint, uuid, username, password, cacert=None, - macaroons=None, loop=None): - """Connect to the websocket. - - If uuid is None, the connection will be to the controller. Otherwise it - will be to the model. - + async def reconnect(self): + """ Force a reconnection. """ - client = cls(endpoint, uuid, username, password, cacert, macaroons, - loop) - endpoints = [(endpoint, cacert)] + monitor = self.monitor + if monitor.reconnecting.locked() or monitor.close_called.is_set(): + return + async with monitor.reconnecting: + await self.close() + await self._connect() + + async def _connect(self): + endpoints = [(self._endpoint, self._cacert)] while endpoints: _endpoint, _cacert = endpoints.pop(0) - success, result, new_endpoints = await client._try_endpoint( + success, result, new_endpoints = await self._try_endpoint( _endpoint, _cacert) if success: break endpoints.extend(new_endpoints) else: # ran out of endpoints without a successful login - raise Exception("Couldn't authenticate to {}".format(endpoint)) + raise Exception("Couldn't authenticate to {}".format( + self._endpoint)) response = result['response'] - client.info = response.copy() - client.build_facades(response.get('facades', {})) - client.monitor.pinger = client.loop.create_task(client.pinger()) + self.info = response.copy() + self.build_facades(response.get('facades', {})) + self.loop.create_task(self.pinger()) + self.monitor.pinger_stopped.clear() + + @classmethod + async def connect( + cls, endpoint, uuid, username, password, cacert=None, + macaroons=None, loop=None, max_frame_size=None): + """Connect to the websocket. + + If uuid is None, the connection will be to the controller. Otherwise it + will be to the model. + """ + client = cls(endpoint, uuid, username, password, cacert, macaroons, + loop, max_frame_size) + await client._connect() return client @classmethod - async def connect_current(cls, loop=None): + async def connect_current(cls, loop=None, max_frame_size=None): """Connect to the currently active model. """ jujudata = JujuData() + controller_name = jujudata.current_controller() + if not controller_name: + raise JujuConnectionError('No current controller') + model_name = jujudata.current_model() return await cls.connect_model( - '{}:{}'.format(controller_name, model_name), loop) + '{}:{}'.format(controller_name, model_name), loop, max_frame_size) @classmethod - async def connect_current_controller(cls, loop=None): + async def connect_current_controller(cls, loop=None, max_frame_size=None): """Connect to the currently active controller. """ @@ -407,10 +463,12 @@ class Connection: if not controller_name: raise JujuConnectionError('No current controller') - return await cls.connect_controller(controller_name, loop) + return await cls.connect_controller(controller_name, loop, + max_frame_size) @classmethod - async def connect_controller(cls, controller_name, loop=None): + async def connect_controller(cls, controller_name, loop=None, + max_frame_size=None): """Connect to a controller by name. """ @@ -421,13 +479,14 @@ class Connection: accounts = jujudata.accounts()[controller_name] username = accounts['user'] password = accounts.get('password') - macaroons = get_macaroons() if not password else None + macaroons = get_macaroons(controller_name) if not password else None return await cls.connect( - endpoint, None, username, password, cacert, macaroons, loop) + endpoint, None, username, password, cacert, macaroons, loop, + max_frame_size) @classmethod - async def connect_model(cls, model, loop=None): + async def connect_model(cls, model, loop=None, max_frame_size=None): """Connect to a model by name. :param str model: [:] @@ -455,29 +514,16 @@ class Connection: password = accounts.get('password') models = jujudata.models()[controller_name] model_uuid = models['models'][model_name]['uuid'] - macaroons = get_macaroons() if not password else None + macaroons = get_macaroons(controller_name) if not password else None return await cls.connect( - endpoint, model_uuid, username, password, cacert, macaroons, loop) + endpoint, model_uuid, username, password, cacert, macaroons, loop, + max_frame_size) def build_facades(self, facades): self.facades.clear() - # In order to work around an issue where the juju api is not - # returning a complete list of facades, we simply look up the - # juju version in a pregenerated map, and use that info to - # populate our list of facades. - - # TODO: if a future version of juju fixes this bug, restore - # the following code for that version and higher: - # for facade in facades: - # self.facades[facade['name']] = facade['versions'][-1] - try: - self.facades = VERSION_MAP[self.info['server-version']] - except KeyError: - log.warning("Could not find a set of facades for {}. Using " - "the latest facade set instead".format( - self.info['server-version'])) - self.facades = VERSION_MAP['latest'] + for facade in facades: + self.facades[facade['name']] = facade['versions'][-1] async def login(self): username = self.username @@ -544,16 +590,26 @@ class JujuData: return yaml.safe_load(f)[key] -def get_macaroons(): +def get_macaroons(controller_name=None): """Decode and return macaroons from default ~/.go-cookies """ - try: - cookie_file = os.path.expanduser('~/.go-cookies') - with open(cookie_file, 'r') as f: - cookies = json.load(f) - except (OSError, ValueError): - log.warn("Couldn't load macaroons from %s", cookie_file) + cookie_files = [] + if controller_name: + cookie_files.append('~/.local/share/juju/cookies/{}.json'.format( + controller_name)) + cookie_files.append('~/.go-cookies') + for cookie_file in cookie_files: + cookie_file = Path(cookie_file).expanduser() + if cookie_file.exists(): + try: + cookies = json.loads(cookie_file.read_text()) + break + except (OSError, ValueError): + log.warn("Couldn't load macaroons from %s", cookie_file) + return [] + else: + log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files)) return [] base64_macaroons = [