Refactor connection task management to avoid cancels (#117)
authorCory Johns <johnsca@gmail.com>
Thu, 27 Apr 2017 20:27:32 +0000 (16:27 -0400)
committerPete Vander Giessen <petevg@gmail.com>
Thu, 27 Apr 2017 20:27:32 +0000 (16:27 -0400)
Connections in libjuju require a few tasks to help manage the data
flowing through it.  The model AllWatcher task, the connection packet
receiver, and the connection keep-alive pinger.  Due to how these
connections were being managed, they had to be tracked and cancelled
externally, leading to exceptions being masked and the need to shield
AllWatcher notifier callbacks, making them entirely uncancellable.  This
refactor cleans that up by using events to track when the connection is
being shutdown and cleanly stopping the tasks internally.

juju/client/connection.py
juju/model.py
juju/utils.py
tests/unit/test_connection.py

index 2be360f..c2c6b2d 100644 (file)
@@ -9,12 +9,13 @@ import ssl
 import string
 import subprocess
 import websockets
+from concurrent.futures import CancelledError
 from http.client import HTTPSConnection
 
 import asyncio
 import yaml
 
-from juju import tag
+from juju import tag, utils
 from juju.client import client
 from juju.client.version_map import VERSION_MAP
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
@@ -44,8 +45,11 @@ class Monitor:
 
     def __init__(self, connection):
         self.connection = connection
-        self.receiver = None
-        self.pinger = None
+        self.close_called = asyncio.Event(loop=self.connection.loop)
+        self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
+        self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
+        self.receiver_stopped.set()
+        self.pinger_stopped.set()
 
     @property
     def status(self):
@@ -63,21 +67,21 @@ class Monitor:
         # DISCONNECTED: connection not yet open
         if not self.connection.ws:
             return self.DISCONNECTED
-        if not self.receiver:
+        if self.receiver_stopped.is_set():
             return self.DISCONNECTED
 
         # ERROR: Connection closed (or errored), but we didn't call
         # connection.close
-        if not self.connection.close_called and self.receiver_exceptions():
+        if not self.close_called.is_set() and self.receiver_stopped.is_set():
             return self.ERROR
-        if not self.connection.close_called and not self.connection.ws.open:
-            # The check for self.receiver existing above guards against the
-            # case where we're not open because we simply haven't
-            # setup the connection yet.
+        if not self.close_called.is_set() and not self.connection.ws.open:
+            # The check for self.receiver_stopped existing above guards
+            # against the case where we're not open because we simply
+            # haven't setup the connection yet.
             return self.ERROR
 
         # DISCONNECTED: cleanly disconnected.
-        if self.connection.close_called and not self.connection.ws.open:
+        if self.close_called.is_set() and not self.connection.ws.open:
             return self.DISCONNECTED
 
         # CONNECTED: everything is fine!
@@ -89,17 +93,6 @@ class Monitor:
         # know what state the connection is in.
         return self.UNKNOWN
 
-    def receiver_exceptions(self):
-        """
-        Return exceptions in the receiver, if any.
-
-        """
-        if not self.receiver:
-            return None
-        if not self.receiver.done():
-            return None
-        return self.receiver.exception()
-
 
 class Connection:
     """
@@ -139,7 +132,6 @@ class Connection:
         self.ws = None
         self.facades = {}
         self.messages = IdQueue(loop=self.loop)
-        self.close_called = False
         self.monitor = Monitor(connection=self)
 
     @property
@@ -163,19 +155,18 @@ class Connection:
         kw['loop'] = self.loop
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
-        self.monitor.receiver = self.loop.create_task(self.receiver())
+        self.loop.create_task(self.receiver())
+        self.monitor.receiver_stopped.clear()
         log.info("Driver connected to juju %s", url)
+        self.monitor.close_called.clear()
         return self
 
     async def close(self):
         if not self.is_open:
             return
-        self.close_called = True
-        if self.monitor.pinger:
-            # might be closing due to login failure,
-            # in which case we won't have a pinger yet
-            self.monitor.pinger.cancel()
-        self.monitor.receiver.cancel()
+        self.monitor.close_called.set()
+        await self.monitor.pinger_stopped.wait()
+        await self.monitor.receiver_stopped.wait()
         await self.ws.close()
 
     async def recv(self, request_id):
@@ -184,19 +175,29 @@ class Connection:
         return await self.messages.get(request_id)
 
     async def receiver(self):
-        while self.is_open:
-            try:
-                result = await self.ws.recv()
+        try:
+            while self.is_open:
+                result = await utils.run_with_interrupt(
+                    self.ws.recv(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
                 if result is not None:
                     result = json.loads(result)
                     await self.messages.put(result['request-id'], result)
-            except Exception as e:
-                await self.messages.put_all(e)
-                if isinstance(e, websockets.ConnectionClosed):
-                    # ConnectionClosed is not really exceptional for us,
-                    # but it may be for any pending message listeners
-                    return
-                raise
+        except CancelledError:
+            pass
+        except Exception as e:
+            await self.messages.put_all(e)
+            if isinstance(e, websockets.ConnectionClosed):
+                # ConnectionClosed is not really exceptional for us,
+                # but it may be for any pending message listeners
+                return
+            log.exception("Error in receiver")
+            raise
+        finally:
+            self.monitor.receiver_stopped.set()
 
     async def pinger(self):
         '''
@@ -206,10 +207,24 @@ class Connection:
         To prevent timing out, we send a ping every ten seconds.
 
         '''
+        async def _do_ping():
+            try:
+                await pinger_facade.Ping()
+                await asyncio.sleep(10, loop=self.loop)
+            except CancelledError:
+                pass
+
         pinger_facade = client.PingerFacade.from_connection(self)
-        while self.is_open:
-            await pinger_facade.Ping()
-            await asyncio.sleep(10, loop=self.loop)
+        try:
+            while self.is_open:
+                await utils.run_with_interrupt(
+                    _do_ping(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+        finally:
+            self.monitor.pinger_stopped.set()
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -381,7 +396,8 @@ class Connection:
         response = result['response']
         client.info = response.copy()
         client.build_facades(response.get('facades', {}))
-        client.monitor.pinger = client.loop.create_task(client.pinger())
+        client.loop.create_task(client.pinger())
+        client.monitor.pinger_stopped.clear()
 
         return client
 
index 3ed8fa7..e024c65 100644 (file)
@@ -18,7 +18,7 @@ import yaml
 import theblues.charmstore
 import theblues.errors
 
-from . import tag
+from . import tag, utils
 from .client import client
 from .client import connection
 from .constraints import parse as parse_constraints, normalize_key
@@ -385,8 +385,8 @@ class Model(object):
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
         self.info = None
-        self._watcher_task = None
-        self._watch_shutdown = asyncio.Event(loop=self.loop)
+        self._watch_stopping = asyncio.Event(loop=self.loop)
+        self._watch_stopped = asyncio.Event(loop=self.loop)
         self._watch_received = asyncio.Event(loop=self.loop)
         self._charmstore = CharmStore(self.loop)
 
@@ -431,9 +431,10 @@ class Model(object):
         """Shut down the watcher task and close websockets.
 
         """
-        self._stop_watching()
         if self.connection and self.connection.is_open:
-            await self._watch_shutdown.wait()
+            log.debug('Stopping watcher task')
+            self._watch_stopping.set()
+            await self._watch_stopped.wait()
             log.debug('Closing model connection')
             await self.connection.close()
             self.connection = None
@@ -619,41 +620,34 @@ class Model(object):
 
         """
         async def _start_watch():
-            self._watch_shutdown.clear()
             try:
                 allwatcher = client.AllWatcherFacade.from_connection(
                     self.connection)
-                while True:
-                    results = await allwatcher.Next()
+                while not self._watch_stopping.is_set():
+                    results = await utils.run_with_interrupt(
+                        allwatcher.Next(),
+                        self._watch_stopping,
+                        self.loop)
+                    if self._watch_stopping.is_set():
+                        break
                     for delta in results.deltas:
                         delta = get_entity_delta(delta)
                         old_obj, new_obj = self.state.apply_delta(delta)
-                        # XXX: Might not want to shield at this level
-                        # We are shielding because when the watcher is
-                        # canceled (on disconnect()), we don't want all of
-                        # its children (every observer callback) to be
-                        # canceled with it. So we shield them. But this means
-                        # they can *never* be canceled.
-                        await asyncio.shield(
-                            self._notify_observers(delta, old_obj, new_obj),
-                            loop=self.loop)
+                        await self._notify_observers(delta, old_obj, new_obj)
                     self._watch_received.set()
             except CancelledError:
-                self._watch_shutdown.set()
+                pass
             except Exception:
                 log.exception('Error in watcher')
                 raise
+            finally:
+                self._watch_stopped.set()
 
         log.debug('Starting watcher task')
-        self._watcher_task = self.loop.create_task(_start_watch())
-
-    def _stop_watching(self):
-        """Stop the asynchronous watch against this model.
-
-        """
-        log.debug('Stopping watcher task')
-        if self._watcher_task:
-            self._watcher_task.cancel()
+        self._watch_received.clear()
+        self._watch_stopping.clear()
+        self._watch_stopped.clear()
+        self.loop.create_task(_start_watch())
 
     async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
index f4db66e..1d1b24e 100644 (file)
@@ -69,3 +69,30 @@ class IdQueue:
     async def put_all(self, value):
         for queue in self._queues.values():
             await queue.put(value)
+
+
+async def run_with_interrupt(task, event, loop=None):
+    """
+    Awaits a task while allowing it to be interrupted by an `asyncio.Event`.
+
+    If the task finishes without the event becoming set, the results of the
+    task will be returned.  If the event becomes set, the task will be
+    cancelled ``None`` will be returned.
+
+    :param task: Task to run
+    :param event: An `asyncio.Event` which, if set, will interrupt `task`
+        and cause it to be cancelled.
+    :param loop: Optional event loop to use other than the default.
+    """
+    loop = loop or asyncio.get_event_loop()
+    event_task = loop.create_task(event.wait())
+    done, pending = await asyncio.wait([task, event_task],
+                                       loop=loop,
+                                       return_when=asyncio.FIRST_COMPLETED)
+    for f in pending:
+        f.cancel()
+    result = [f.result() for f in done if f is not event_task]
+    if result:
+        return result[0]
+    else:
+        return None
index 5371fdb..340264e 100644 (file)
@@ -20,7 +20,7 @@ class WebsocketMock:
 
     async def recv(self):
         if not self.responses:
-            raise ConnectionClosed(0, 'no reason')
+            raise ConnectionClosed(0, 'ran out of responses')
         return json.dumps(self.responses.popleft())
 
     async def close(self):
@@ -41,9 +41,12 @@ async def test_out_of_order(event_loop):
         {'request-id': 3},
     ]
     con._get_sll = mock.MagicMock()
-    with mock.patch('websockets.connect', base.AsyncMock(return_value=ws)):
-        await con.open()
-    actual_responses = []
-    for i in range(3):
-        actual_responses.append(await con.rpc({'version': 1}))
-    assert actual_responses == expected_responses
+    try:
+        with mock.patch('websockets.connect', base.AsyncMock(return_value=ws)):
+            await con.open()
+        actual_responses = []
+        for i in range(3):
+            actual_responses.append(await con.rpc({'version': 1}))
+        assert actual_responses == expected_responses
+    finally:
+        await con.close()