From 054805953c7c8ed11ce341adf55bf6e19e589fbe Mon Sep 17 00:00:00 2001 From: Cory Johns Date: Thu, 27 Apr 2017 16:27:32 -0400 Subject: [PATCH] Refactor connection task management to avoid cancels (#117) Connections in libjuju require a few tasks to help manage the data flowing through it. The model AllWatcher task, the connection packet receiver, and the connection keep-alive pinger. Due to how these connections were being managed, they had to be tracked and cancelled externally, leading to exceptions being masked and the need to shield AllWatcher notifier callbacks, making them entirely uncancellable. This refactor cleans that up by using events to track when the connection is being shutdown and cleanly stopping the tasks internally. --- juju/client/connection.py | 102 ++++++++++++++++++++-------------- juju/model.py | 48 +++++++--------- juju/utils.py | 27 +++++++++ tests/unit/test_connection.py | 17 +++--- 4 files changed, 117 insertions(+), 77 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index 2be360f..c2c6b2d 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -9,12 +9,13 @@ import ssl import string import subprocess import websockets +from concurrent.futures import CancelledError from http.client import HTTPSConnection import asyncio import yaml -from juju import tag +from juju import tag, utils from juju.client import client from juju.client.version_map import VERSION_MAP from juju.errors import JujuError, JujuAPIError, JujuConnectionError @@ -44,8 +45,11 @@ class Monitor: def __init__(self, connection): self.connection = connection - self.receiver = None - self.pinger = None + self.close_called = asyncio.Event(loop=self.connection.loop) + self.receiver_stopped = asyncio.Event(loop=self.connection.loop) + self.pinger_stopped = asyncio.Event(loop=self.connection.loop) + self.receiver_stopped.set() + self.pinger_stopped.set() @property def status(self): @@ -63,21 +67,21 @@ class Monitor: # DISCONNECTED: connection not yet open if not self.connection.ws: return self.DISCONNECTED - if not self.receiver: + if self.receiver_stopped.is_set(): return self.DISCONNECTED # ERROR: Connection closed (or errored), but we didn't call # connection.close - if not self.connection.close_called and self.receiver_exceptions(): + if not self.close_called.is_set() and self.receiver_stopped.is_set(): return self.ERROR - if not self.connection.close_called and not self.connection.ws.open: - # The check for self.receiver existing above guards against the - # case where we're not open because we simply haven't - # setup the connection yet. + if not self.close_called.is_set() and not self.connection.ws.open: + # The check for self.receiver_stopped existing above guards + # against the case where we're not open because we simply + # haven't setup the connection yet. return self.ERROR # DISCONNECTED: cleanly disconnected. - if self.connection.close_called and not self.connection.ws.open: + if self.close_called.is_set() and not self.connection.ws.open: return self.DISCONNECTED # CONNECTED: everything is fine! @@ -89,17 +93,6 @@ class Monitor: # know what state the connection is in. return self.UNKNOWN - def receiver_exceptions(self): - """ - Return exceptions in the receiver, if any. - - """ - if not self.receiver: - return None - if not self.receiver.done(): - return None - return self.receiver.exception() - class Connection: """ @@ -139,7 +132,6 @@ class Connection: self.ws = None self.facades = {} self.messages = IdQueue(loop=self.loop) - self.close_called = False self.monitor = Monitor(connection=self) @property @@ -163,19 +155,18 @@ class Connection: kw['loop'] = self.loop self.addr = url self.ws = await websockets.connect(url, **kw) - self.monitor.receiver = self.loop.create_task(self.receiver()) + self.loop.create_task(self.receiver()) + self.monitor.receiver_stopped.clear() log.info("Driver connected to juju %s", url) + self.monitor.close_called.clear() return self async def close(self): if not self.is_open: return - self.close_called = True - if self.monitor.pinger: - # might be closing due to login failure, - # in which case we won't have a pinger yet - self.monitor.pinger.cancel() - self.monitor.receiver.cancel() + self.monitor.close_called.set() + await self.monitor.pinger_stopped.wait() + await self.monitor.receiver_stopped.wait() await self.ws.close() async def recv(self, request_id): @@ -184,19 +175,29 @@ class Connection: return await self.messages.get(request_id) async def receiver(self): - while self.is_open: - try: - result = await self.ws.recv() + try: + while self.is_open: + result = await utils.run_with_interrupt( + self.ws.recv(), + self.monitor.close_called, + loop=self.loop) + if self.monitor.close_called.is_set(): + break if result is not None: result = json.loads(result) await self.messages.put(result['request-id'], result) - except Exception as e: - await self.messages.put_all(e) - if isinstance(e, websockets.ConnectionClosed): - # ConnectionClosed is not really exceptional for us, - # but it may be for any pending message listeners - return - raise + except CancelledError: + pass + except Exception as e: + await self.messages.put_all(e) + if isinstance(e, websockets.ConnectionClosed): + # ConnectionClosed is not really exceptional for us, + # but it may be for any pending message listeners + return + log.exception("Error in receiver") + raise + finally: + self.monitor.receiver_stopped.set() async def pinger(self): ''' @@ -206,10 +207,24 @@ class Connection: To prevent timing out, we send a ping every ten seconds. ''' + async def _do_ping(): + try: + await pinger_facade.Ping() + await asyncio.sleep(10, loop=self.loop) + except CancelledError: + pass + pinger_facade = client.PingerFacade.from_connection(self) - while self.is_open: - await pinger_facade.Ping() - await asyncio.sleep(10, loop=self.loop) + try: + while self.is_open: + await utils.run_with_interrupt( + _do_ping(), + self.monitor.close_called, + loop=self.loop) + if self.monitor.close_called.is_set(): + break + finally: + self.monitor.pinger_stopped.set() async def rpc(self, msg, encoder=None): self.__request_id__ += 1 @@ -381,7 +396,8 @@ class Connection: response = result['response'] client.info = response.copy() client.build_facades(response.get('facades', {})) - client.monitor.pinger = client.loop.create_task(client.pinger()) + client.loop.create_task(client.pinger()) + client.monitor.pinger_stopped.clear() return client diff --git a/juju/model.py b/juju/model.py index 3ed8fa7..e024c65 100644 --- a/juju/model.py +++ b/juju/model.py @@ -18,7 +18,7 @@ import yaml import theblues.charmstore import theblues.errors -from . import tag +from . import tag, utils from .client import client from .client import connection from .constraints import parse as parse_constraints, normalize_key @@ -385,8 +385,8 @@ class Model(object): self.observers = weakref.WeakValueDictionary() self.state = ModelState(self) self.info = None - self._watcher_task = None - self._watch_shutdown = asyncio.Event(loop=self.loop) + self._watch_stopping = asyncio.Event(loop=self.loop) + self._watch_stopped = asyncio.Event(loop=self.loop) self._watch_received = asyncio.Event(loop=self.loop) self._charmstore = CharmStore(self.loop) @@ -431,9 +431,10 @@ class Model(object): """Shut down the watcher task and close websockets. """ - self._stop_watching() if self.connection and self.connection.is_open: - await self._watch_shutdown.wait() + log.debug('Stopping watcher task') + self._watch_stopping.set() + await self._watch_stopped.wait() log.debug('Closing model connection') await self.connection.close() self.connection = None @@ -619,41 +620,34 @@ class Model(object): """ async def _start_watch(): - self._watch_shutdown.clear() try: allwatcher = client.AllWatcherFacade.from_connection( self.connection) - while True: - results = await allwatcher.Next() + while not self._watch_stopping.is_set(): + results = await utils.run_with_interrupt( + allwatcher.Next(), + self._watch_stopping, + self.loop) + if self._watch_stopping.is_set(): + break for delta in results.deltas: delta = get_entity_delta(delta) old_obj, new_obj = self.state.apply_delta(delta) - # XXX: Might not want to shield at this level - # We are shielding because when the watcher is - # canceled (on disconnect()), we don't want all of - # its children (every observer callback) to be - # canceled with it. So we shield them. But this means - # they can *never* be canceled. - await asyncio.shield( - self._notify_observers(delta, old_obj, new_obj), - loop=self.loop) + await self._notify_observers(delta, old_obj, new_obj) self._watch_received.set() except CancelledError: - self._watch_shutdown.set() + pass except Exception: log.exception('Error in watcher') raise + finally: + self._watch_stopped.set() log.debug('Starting watcher task') - self._watcher_task = self.loop.create_task(_start_watch()) - - def _stop_watching(self): - """Stop the asynchronous watch against this model. - - """ - log.debug('Stopping watcher task') - if self._watcher_task: - self._watcher_task.cancel() + self._watch_received.clear() + self._watch_stopping.clear() + self._watch_stopped.clear() + self.loop.create_task(_start_watch()) async def _notify_observers(self, delta, old_obj, new_obj): """Call observing callbacks, notifying them of a change in model state diff --git a/juju/utils.py b/juju/utils.py index f4db66e..1d1b24e 100644 --- a/juju/utils.py +++ b/juju/utils.py @@ -69,3 +69,30 @@ class IdQueue: async def put_all(self, value): for queue in self._queues.values(): await queue.put(value) + + +async def run_with_interrupt(task, event, loop=None): + """ + Awaits a task while allowing it to be interrupted by an `asyncio.Event`. + + If the task finishes without the event becoming set, the results of the + task will be returned. If the event becomes set, the task will be + cancelled ``None`` will be returned. + + :param task: Task to run + :param event: An `asyncio.Event` which, if set, will interrupt `task` + and cause it to be cancelled. + :param loop: Optional event loop to use other than the default. + """ + loop = loop or asyncio.get_event_loop() + event_task = loop.create_task(event.wait()) + done, pending = await asyncio.wait([task, event_task], + loop=loop, + return_when=asyncio.FIRST_COMPLETED) + for f in pending: + f.cancel() + result = [f.result() for f in done if f is not event_task] + if result: + return result[0] + else: + return None diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 5371fdb..340264e 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -20,7 +20,7 @@ class WebsocketMock: async def recv(self): if not self.responses: - raise ConnectionClosed(0, 'no reason') + raise ConnectionClosed(0, 'ran out of responses') return json.dumps(self.responses.popleft()) async def close(self): @@ -41,9 +41,12 @@ async def test_out_of_order(event_loop): {'request-id': 3}, ] con._get_sll = mock.MagicMock() - with mock.patch('websockets.connect', base.AsyncMock(return_value=ws)): - await con.open() - actual_responses = [] - for i in range(3): - actual_responses.append(await con.rpc({'version': 1})) - assert actual_responses == expected_responses + try: + with mock.patch('websockets.connect', base.AsyncMock(return_value=ws)): + await con.open() + actual_responses = [] + for i in range(3): + actual_responses.append(await con.rpc({'version': 1})) + assert actual_responses == expected_responses + finally: + await con.close() -- 2.25.1