Improve handling of closed connections (#148)
authorCory Johns <johnsca@gmail.com>
Thu, 22 Jun 2017 20:19:19 +0000 (16:19 -0400)
committerGitHub <noreply@github.com>
Thu, 22 Jun 2017 20:19:19 +0000 (16:19 -0400)
* Prevent circular reference between Monitor and Connection

* Improve monitor status logic

* Stop watcher task cleanly on disconnect

Instead of the watcher task blowing up with an exception when the
connection is closed, it should exit cleanly and allow the Monitor to
report the connection error.

Fixes #147

* Raise ConnectionClosed exception in Model.block_until

Rather than blocking indefinitely when waiting on a model change that
will never happen if the connection gets closed out from under it, this
makes `Model.block_until` raise a `websockets.ConnectionClosed` error so
that it can be caught and dealt with by the client.

* Automatically reconnect lost websocket connections

The receiver or all_watcher should immediately reconnect a lost
connection, or it will be reconnected automatically upon issuance of the
next RPC call.  However, there is still a small chance that the
disconnect could happen between sending a API call and receiving the
response, in which case a websockets.ConnectionClosed error will still
be raised to the caller.  These should be quite rare, though.

* Gracefully handle add_signal_handler failing in a thread in loop.run

* Restart AllWatcher if controller stops it

Fixes conjure-up/conjure-up#965

* Explicitly let the controller know we're stopping the watcher

* Skip reconnect if close was requested

juju/client/connection.py
juju/loop.py
juju/model.py
tests/base.py
tests/integration/test_connection.py
tests/integration/test_model.py
tests/unit/test_connection.py

index 6f2f2a2..7457391 100644 (file)
@@ -8,6 +8,7 @@ import shlex
 import ssl
 import string
 import subprocess
+import weakref
 import websockets
 from concurrent.futures import CancelledError
 from http.client import HTTPSConnection
@@ -40,14 +41,15 @@ class Monitor:
     """
     ERROR = 'error'
     CONNECTED = 'connected'
+    DISCONNECTING = 'disconnecting'
     DISCONNECTED = 'disconnected'
-    UNKNOWN = 'unknown'
 
     def __init__(self, connection):
-        self.connection = connection
-        self.close_called = asyncio.Event(loop=self.connection.loop)
-        self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
-        self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
+        self.connection = weakref.ref(connection)
+        self.reconnecting = asyncio.Lock(loop=connection.loop)
+        self.close_called = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped = asyncio.Event(loop=connection.loop)
+        self.pinger_stopped = asyncio.Event(loop=connection.loop)
         self.receiver_stopped.set()
         self.pinger_stopped.set()
 
@@ -63,35 +65,27 @@ class Monitor:
         isn't usable until that receiver has been started.
 
         """
+        connection = self.connection()
 
-        # DISCONNECTED: connection not yet open
-        if not self.connection.ws:
+        # the connection instance was destroyed but someone kept
+        # a separate reference to the monitor for some reason
+        if not connection:
             return self.DISCONNECTED
-        if self.receiver_stopped.is_set():
-            return self.DISCONNECTED
-
-        # ERROR: Connection closed (or errored), but we didn't call
-        # connection.close
-        if not self.close_called.is_set() and self.receiver_stopped.is_set():
-            return self.ERROR
-        if not self.close_called.is_set() and not self.connection.ws.open:
-            # The check for self.receiver_stopped existing above guards
-            # against the case where we're not open because we simply
-            # haven't setup the connection yet.
-            return self.ERROR
 
-        # DISCONNECTED: cleanly disconnected.
-        if self.close_called.is_set() and not self.connection.ws.open:
+        # connection cleanly disconnected or not yet opened
+        if not connection.ws:
             return self.DISCONNECTED
 
-        # CONNECTED: everything is fine!
-        if self.connection.ws.open:
-            return self.CONNECTED
+        # close called but not yet complete
+        if self.close_called.is_set():
+            return self.DISCONNECTING
+
+        # connection closed uncleanly (we didn't call connection.close)
+        if self.receiver_stopped.is_set() or not connection.ws.open:
+            return self.ERROR
 
-        # UNKNOWN: We should never hit this state -- if we do,
-        # something went wrong with the logic above, and we do not
-        # know what state the connection is in.
-        return self.UNKNOWN
+        # everything is fine!
+        return self.CONNECTED
 
 
 class Connection:
@@ -120,6 +114,7 @@ class Connection:
             self, endpoint, uuid, username, password, cacert=None,
             macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
         self.endpoint = endpoint
+        self._endpoint = endpoint
         self.uuid = uuid
         if macaroons:
             self.macaroons = macaroons
@@ -130,6 +125,7 @@ class Connection:
             self.username = username
             self.password = password
         self.cacert = cacert
+        self._cacert = cacert
         self.loop = loop or asyncio.get_event_loop()
 
         self.__request_id__ = 0
@@ -144,9 +140,7 @@ class Connection:
 
     @property
     def is_open(self):
-        if self.ws:
-            return self.ws.open
-        return False
+        return self.monitor.status == Monitor.CONNECTED
 
     def _get_ssl(self, cert=None):
         return ssl.create_default_context(
@@ -171,12 +165,13 @@ class Connection:
         return self
 
     async def close(self):
-        if not self.is_open:
+        if not self.ws:
             return
         self.monitor.close_called.set()
         await self.monitor.pinger_stopped.wait()
         await self.monitor.receiver_stopped.wait()
         await self.ws.close()
+        self.ws = None
 
     async def recv(self, request_id):
         if not self.is_open:
@@ -197,13 +192,18 @@ class Connection:
                     await self.messages.put(result['request-id'], result)
         except CancelledError:
             pass
-        except Exception as e:
+        except websockets.ConnectionClosed as e:
+            log.warning('Receiver: Connection closed, reconnecting')
             await self.messages.put_all(e)
-            if isinstance(e, websockets.ConnectionClosed):
-                # ConnectionClosed is not really exceptional for us,
-                # but it may be for any pending message listeners
-                return
+            # the reconnect has to be done as a task because the receiver will
+            # be cancelled by the reconnect and we don't want the reconnect
+            # to be aborted half-way through
+            self.loop.create_task(self.reconnect())
+            return
+        except Exception as e:
             log.exception("Error in receiver")
+            # make pending listeners aware of the error
+            await self.messages.put_all(e)
             raise
         finally:
             self.monitor.receiver_stopped.set()
@@ -225,7 +225,7 @@ class Connection:
 
         pinger_facade = client.PingerFacade.from_connection(self)
         try:
-            while self.is_open:
+            while True:
                 await utils.run_with_interrupt(
                     _do_ping(),
                     self.monitor.close_called,
@@ -234,6 +234,7 @@ class Connection:
                     break
         finally:
             self.monitor.pinger_stopped.set()
+            return
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -243,7 +244,19 @@ class Connection:
         if "version" not in msg:
             msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
-        await self.ws.send(outgoing)
+        for attempt in range(3):
+            try:
+                await self.ws.send(outgoing)
+                break
+            except websockets.ConnectionClosed:
+                if attempt == 2:
+                    raise
+                log.warning('RPC: Connection closed, reconnecting')
+                # the reconnect has to be done in a separate task because,
+                # if it is triggered by the pinger, then this RPC call will
+                # be cancelled when the pinger is cancelled by the reconnect,
+                # and we don't want the reconnect to be aborted halfway through
+                await asyncio.wait([self.reconnect()], loop=self.loop)
         result = await self.recv(msg['request-id'])
 
         if not result:
@@ -379,36 +392,49 @@ class Connection:
                 await self.close()
         return success, result, new_endpoints
 
-    @classmethod
-    async def connect(
-            cls, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None, max_frame_size=None):
-        """Connect to the websocket.
-
-        If uuid is None, the connection will be to the controller. Otherwise it
-        will be to the model.
-
+    async def reconnect(self):
+        """ Force a reconnection.
         """
-        client = cls(endpoint, uuid, username, password, cacert, macaroons,
-                     loop, max_frame_size)
-        endpoints = [(endpoint, cacert)]
+        monitor = self.monitor
+        if monitor.reconnecting.locked() or monitor.close_called.is_set():
+            return
+        async with monitor.reconnecting:
+            await self.close()
+            await self._connect()
+
+    async def _connect(self):
+        endpoints = [(self._endpoint, self._cacert)]
         while endpoints:
             _endpoint, _cacert = endpoints.pop(0)
-            success, result, new_endpoints = await client._try_endpoint(
+            success, result, new_endpoints = await self._try_endpoint(
                 _endpoint, _cacert)
             if success:
                 break
             endpoints.extend(new_endpoints)
         else:
             # ran out of endpoints without a successful login
-            raise Exception("Couldn't authenticate to {}".format(endpoint))
+            raise Exception("Couldn't authenticate to {}".format(
+                self._endpoint))
 
         response = result['response']
-        client.info = response.copy()
-        client.build_facades(response.get('facades', {}))
-        client.loop.create_task(client.pinger())
-        client.monitor.pinger_stopped.clear()
+        self.info = response.copy()
+        self.build_facades(response.get('facades', {}))
+        self.loop.create_task(self.pinger())
+        self.monitor.pinger_stopped.clear()
+
+    @classmethod
+    async def connect(
+            cls, endpoint, uuid, username, password, cacert=None,
+            macaroons=None, loop=None, max_frame_size=None):
+        """Connect to the websocket.
+
+        If uuid is None, the connection will be to the controller. Otherwise it
+        will be to the model.
 
+        """
+        client = cls(endpoint, uuid, username, password, cacert, macaroons,
+                     loop, max_frame_size)
+        await client._connect()
         return client
 
     @classmethod
index 3720159..4abedfc 100644 (file)
@@ -20,7 +20,14 @@ def run(*steps):
         task.cancel()
         run._sigint = True
 
-    loop.add_signal_handler(signal.SIGINT, abort)
+    added = False
+    try:
+        loop.add_signal_handler(signal.SIGINT, abort)
+        added = True
+    except ValueError as e:
+        # add_signal_handler doesn't work in a thread
+        if 'main thread' not in str(e):
+            raise
     try:
         for step in steps:
             task = loop.create_task(step)
@@ -31,4 +38,5 @@ def run(*steps):
                 raise task.exception()
         return task.result()
     finally:
-        loop.remove_signal_handler(signal.SIGINT)
+        if added:
+            loop.remove_signal_handler(signal.SIGINT)
index 61905c9..7b86ba3 100644 (file)
@@ -14,6 +14,7 @@ from concurrent.futures import CancelledError
 from functools import partial
 from pathlib import Path
 
+import websockets
 import yaml
 import theblues.charmstore
 import theblues.errors
@@ -550,6 +551,8 @@ class Model(object):
         """
         async def _block():
             while not all(c() for c in conditions):
+                if not (self.connection and self.connection.is_open):
+                    raise websockets.ConnectionClosed(1006, 'no reason')
                 await asyncio.sleep(wait_period, loop=self.loop)
         await asyncio.wait_for(_block(), timeout, loop=self.loop)
 
@@ -643,16 +646,45 @@ class Model(object):
         See :meth:`add_observer` to register an onchange callback.
 
         """
-        async def _start_watch():
+        async def _all_watcher():
             try:
                 allwatcher = client.AllWatcherFacade.from_connection(
                     self.connection)
                 while not self._watch_stopping.is_set():
-                    results = await utils.run_with_interrupt(
-                        allwatcher.Next(),
-                        self._watch_stopping,
-                        self.loop)
+                    try:
+                        results = await utils.run_with_interrupt(
+                            allwatcher.Next(),
+                            self._watch_stopping,
+                            self.loop)
+                    except JujuAPIError as e:
+                        if 'watcher was stopped' not in str(e):
+                            raise
+                        if self._watch_stopping.is_set():
+                            # this shouldn't ever actually happen, because
+                            # the event should trigger before the controller
+                            # has a chance to tell us the watcher is stopped
+                            # but handle it gracefully, just in case
+                            break
+                        # controller stopped our watcher for some reason
+                        # but we're not actually stopping, so just restart it
+                        log.warning(
+                            'Watcher: watcher stopped, restarting')
+                        del allwatcher.Id
+                        continue
+                    except websockets.ConnectionClosed:
+                        monitor = self.connection.monitor
+                        if monitor.status == monitor.ERROR:
+                            # closed unexpectedly, try to reopen
+                            log.warning(
+                                'Watcher: connection closed, reopening')
+                            await self.connection.reconnect()
+                            del allwatcher.Id
+                            continue
+                        else:
+                            # closed on request, go ahead and shutdown
+                            break
                     if self._watch_stopping.is_set():
+                        await allwatcher.Stop()
                         break
                     for delta in results.deltas:
                         delta = get_entity_delta(delta)
@@ -671,7 +703,7 @@ class Model(object):
         self._watch_received.clear()
         self._watch_stopping.clear()
         self._watch_stopped.clear()
-        self.loop.create_task(_start_watch())
+        self.loop.create_task(_all_watcher())
 
     async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
index 8ea5109..e1ec452 100644 (file)
@@ -44,6 +44,9 @@ class CleanModel():
         model_name = 'model-{}'.format(uuid.uuid4())
         self.model = await self.controller.add_model(model_name)
 
+        # save the model UUID in case test closes model
+        self.model_uuid = self.model.info.uuid
+
         # Ensure that we connect to the new model by default.  This also
         # prevents failures if test was started with no current model.
         self._patch_cm = mock.patch.object(JujuData, 'current_model',
@@ -55,7 +58,7 @@ class CleanModel():
     async def __aexit__(self, exc_type, exc, tb):
         self._patch_cm.stop()
         await self.model.disconnect()
-        await self.controller.destroy_model(self.model.info.uuid)
+        await self.controller.destroy_model(self.model_uuid)
         await self.controller.disconnect()
 
 
index 67dfb2e..290203d 100644 (file)
@@ -1,3 +1,4 @@
+import asyncio
 import pytest
 
 from juju.client.connection import Connection
@@ -47,7 +48,7 @@ async def test_monitor_catches_error(event_loop):
 @pytest.mark.asyncio
 async def test_full_status(event_loop):
     async with base.CleanModel() as model:
-        app = await model.deploy(
+        await model.deploy(
             'ubuntu-0',
             application_name='ubuntu',
             series='trusty',
@@ -56,4 +57,27 @@ async def test_full_status(event_loop):
 
         c = client.ClientFacade.from_connection(model.connection)
 
-        status = await c.FullStatus(None)
+        await c.FullStatus(None)
+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_reconnect(event_loop):
+    async with base.CleanModel() as model:
+        conn = await Connection.connect(
+            model.connection.endpoint,
+            model.connection.uuid,
+            model.connection.username,
+            model.connection.password,
+            model.connection.cacert,
+            model.connection.macaroons,
+            model.connection.loop,
+            model.connection.max_frame_size)
+        try:
+            await asyncio.sleep(0.1)
+            assert conn.is_open
+            await conn.ws.close()
+            assert not conn.is_open
+            await model.block_until(lambda: conn.is_open, timeout=3)
+        finally:
+            await conn.close()
index 088dcd5..37f51c0 100644 (file)
@@ -212,6 +212,15 @@ async def test_get_machines(event_loop):
         assert isinstance(result, list)
 
 
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_watcher_reconnect(event_loop):
+    async with base.CleanModel() as model:
+        await model.connection.ws.close()
+        await asyncio.sleep(0.1)
+        assert model.connection.is_open
+
+
 # @base.bootstrapped
 # @pytest.mark.asyncio
 # async def test_grant(event_loop)
index 340264e..f69b8d6 100644 (file)
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import mock
 import pytest
@@ -20,6 +21,7 @@ class WebsocketMock:
 
     async def recv(self):
         if not self.responses:
+            await asyncio.sleep(1)  # delay to give test time to finish
             raise ConnectionClosed(0, 'ran out of responses')
         return json.dumps(self.responses.popleft())