From cd48f185bcb9279a96d3ee85579d96ac10d12dd9 Mon Sep 17 00:00:00 2001
From: Cory Johns
Date: Thu, 22 Jun 2017 16:19:19 -0400
Subject: [PATCH] Improve handling of closed connections (#148)

* Prevent circular reference between Monitor and Connection

* Improve monitor status logic

* Stop watcher task cleanly on disconnect

Instead of the watcher task blowing up with an exception when the
connection is closed, it should exit cleanly and allow the Monitor to
report the connection error.

Fixes #147

* Raise ConnectionClosed exception in Model.block_until

Rather than blocking indefinitely when waiting on a model change that
will never happen if the connection gets closed out from under it, this
makes `Model.block_until` raise a `websockets.ConnectionClosed` error so
that it can be caught and dealt with by the client.

* Automatically reconnect lost websocket connections

The receiver or all_watcher should immediately reconnect a lost
connection, or it will be reconnected automatically upon issuance of
the next RPC call. However, there is still a small chance that the
disconnect could happen between sending an API call and receiving the
response, in which case a websockets.ConnectionClosed error will still
be raised to the caller. These should be quite rare, though.

* Gracefully handle add_signal_handler failing in a thread in loop.run

* Restart AllWatcher if controller stops it

Fixes conjure-up/conjure-up#965

* Explicitly let the controller know we're stopping the watcher

* Skip reconnect if close was requested
---
 juju/client/connection.py            | 140 ++++++++++++++++-----
 juju/loop.py                         |  12 ++-
 juju/model.py                        |  44 +++++++--
 tests/base.py                        |   5 +-
 tests/integration/test_connection.py |  28 +++++-
 tests/integration/test_model.py      |   9 ++
 tests/unit/test_connection.py        |   2 +
 7 files changed, 172 insertions(+), 68 deletions(-)

diff --git a/juju/client/connection.py b/juju/client/connection.py
index 6f2f2a2..7457391 100644
--- a/juju/client/connection.py
+++ b/juju/client/connection.py
@@ -8,6 +8,7 @@ import shlex
 import ssl
 import string
 import subprocess
+import weakref
 import websockets
 from concurrent.futures import CancelledError
 from http.client import HTTPSConnection
@@ -40,14 +41,15 @@ class Monitor:
     """
     ERROR = 'error'
     CONNECTED = 'connected'
+    DISCONNECTING = 'disconnecting'
     DISCONNECTED = 'disconnected'
-    UNKNOWN = 'unknown'
 
     def __init__(self, connection):
-        self.connection = connection
-        self.close_called = asyncio.Event(loop=self.connection.loop)
-        self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
-        self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
+        self.connection = weakref.ref(connection)
+        self.reconnecting = asyncio.Lock(loop=connection.loop)
+        self.close_called = asyncio.Event(loop=connection.loop)
+        self.receiver_stopped = asyncio.Event(loop=connection.loop)
+        self.pinger_stopped = asyncio.Event(loop=connection.loop)
         self.receiver_stopped.set()
         self.pinger_stopped.set()
 
@@ -63,35 +65,27 @@ class Monitor:
         isn't usable until that receiver has been started.
""" + connection = self.connection() - # DISCONNECTED: connection not yet open - if not self.connection.ws: + # the connection instance was destroyed but someone kept + # a separate reference to the monitor for some reason + if not connection: return self.DISCONNECTED - if self.receiver_stopped.is_set(): - return self.DISCONNECTED - - # ERROR: Connection closed (or errored), but we didn't call - # connection.close - if not self.close_called.is_set() and self.receiver_stopped.is_set(): - return self.ERROR - if not self.close_called.is_set() and not self.connection.ws.open: - # The check for self.receiver_stopped existing above guards - # against the case where we're not open because we simply - # haven't setup the connection yet. - return self.ERROR - # DISCONNECTED: cleanly disconnected. - if self.close_called.is_set() and not self.connection.ws.open: + # connection cleanly disconnected or not yet opened + if not connection.ws: return self.DISCONNECTED - # CONNECTED: everything is fine! - if self.connection.ws.open: - return self.CONNECTED + # close called but not yet complete + if self.close_called.is_set(): + return self.DISCONNECTING + + # connection closed uncleanly (we didn't call connection.close) + if self.receiver_stopped.is_set() or not connection.ws.open: + return self.ERROR - # UNKNOWN: We should never hit this state -- if we do, - # something went wrong with the logic above, and we do not - # know what state the connection is in. - return self.UNKNOWN + # everything is fine! + return self.CONNECTED class Connection: @@ -120,6 +114,7 @@ class Connection: self, endpoint, uuid, username, password, cacert=None, macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE): self.endpoint = endpoint + self._endpoint = endpoint self.uuid = uuid if macaroons: self.macaroons = macaroons @@ -130,6 +125,7 @@ class Connection: self.username = username self.password = password self.cacert = cacert + self._cacert = cacert self.loop = loop or asyncio.get_event_loop() self.__request_id__ = 0 @@ -144,9 +140,7 @@ class Connection: @property def is_open(self): - if self.ws: - return self.ws.open - return False + return self.monitor.status == Monitor.CONNECTED def _get_ssl(self, cert=None): return ssl.create_default_context( @@ -171,12 +165,13 @@ class Connection: return self async def close(self): - if not self.is_open: + if not self.ws: return self.monitor.close_called.set() await self.monitor.pinger_stopped.wait() await self.monitor.receiver_stopped.wait() await self.ws.close() + self.ws = None async def recv(self, request_id): if not self.is_open: @@ -197,13 +192,18 @@ class Connection: await self.messages.put(result['request-id'], result) except CancelledError: pass - except Exception as e: + except websockets.ConnectionClosed as e: + log.warning('Receiver: Connection closed, reconnecting') await self.messages.put_all(e) - if isinstance(e, websockets.ConnectionClosed): - # ConnectionClosed is not really exceptional for us, - # but it may be for any pending message listeners - return + # the reconnect has to be done as a task because the receiver will + # be cancelled by the reconnect and we don't want the reconnect + # to be aborted half-way through + self.loop.create_task(self.reconnect()) + return + except Exception as e: log.exception("Error in receiver") + # make pending listeners aware of the error + await self.messages.put_all(e) raise finally: self.monitor.receiver_stopped.set() @@ -225,7 +225,7 @@ class Connection: pinger_facade = client.PingerFacade.from_connection(self) try: - while 
self.is_open: + while True: await utils.run_with_interrupt( _do_ping(), self.monitor.close_called, @@ -234,6 +234,7 @@ class Connection: break finally: self.monitor.pinger_stopped.set() + return async def rpc(self, msg, encoder=None): self.__request_id__ += 1 @@ -243,7 +244,19 @@ class Connection: if "version" not in msg: msg['version'] = self.facades[msg['type']] outgoing = json.dumps(msg, indent=2, cls=encoder) - await self.ws.send(outgoing) + for attempt in range(3): + try: + await self.ws.send(outgoing) + break + except websockets.ConnectionClosed: + if attempt == 2: + raise + log.warning('RPC: Connection closed, reconnecting') + # the reconnect has to be done in a separate task because, + # if it is triggered by the pinger, then this RPC call will + # be cancelled when the pinger is cancelled by the reconnect, + # and we don't want the reconnect to be aborted halfway through + await asyncio.wait([self.reconnect()], loop=self.loop) result = await self.recv(msg['request-id']) if not result: @@ -379,36 +392,49 @@ class Connection: await self.close() return success, result, new_endpoints - @classmethod - async def connect( - cls, endpoint, uuid, username, password, cacert=None, - macaroons=None, loop=None, max_frame_size=None): - """Connect to the websocket. - - If uuid is None, the connection will be to the controller. Otherwise it - will be to the model. - + async def reconnect(self): + """ Force a reconnection. """ - client = cls(endpoint, uuid, username, password, cacert, macaroons, - loop, max_frame_size) - endpoints = [(endpoint, cacert)] + monitor = self.monitor + if monitor.reconnecting.locked() or monitor.close_called.is_set(): + return + async with monitor.reconnecting: + await self.close() + await self._connect() + + async def _connect(self): + endpoints = [(self._endpoint, self._cacert)] while endpoints: _endpoint, _cacert = endpoints.pop(0) - success, result, new_endpoints = await client._try_endpoint( + success, result, new_endpoints = await self._try_endpoint( _endpoint, _cacert) if success: break endpoints.extend(new_endpoints) else: # ran out of endpoints without a successful login - raise Exception("Couldn't authenticate to {}".format(endpoint)) + raise Exception("Couldn't authenticate to {}".format( + self._endpoint)) response = result['response'] - client.info = response.copy() - client.build_facades(response.get('facades', {})) - client.loop.create_task(client.pinger()) - client.monitor.pinger_stopped.clear() + self.info = response.copy() + self.build_facades(response.get('facades', {})) + self.loop.create_task(self.pinger()) + self.monitor.pinger_stopped.clear() + + @classmethod + async def connect( + cls, endpoint, uuid, username, password, cacert=None, + macaroons=None, loop=None, max_frame_size=None): + """Connect to the websocket. + + If uuid is None, the connection will be to the controller. Otherwise it + will be to the model. 
+ """ + client = cls(endpoint, uuid, username, password, cacert, macaroons, + loop, max_frame_size) + await client._connect() return client @classmethod diff --git a/juju/loop.py b/juju/loop.py index 3720159..4abedfc 100644 --- a/juju/loop.py +++ b/juju/loop.py @@ -20,7 +20,14 @@ def run(*steps): task.cancel() run._sigint = True - loop.add_signal_handler(signal.SIGINT, abort) + added = False + try: + loop.add_signal_handler(signal.SIGINT, abort) + added = True + except ValueError as e: + # add_signal_handler doesn't work in a thread + if 'main thread' not in str(e): + raise try: for step in steps: task = loop.create_task(step) @@ -31,4 +38,5 @@ def run(*steps): raise task.exception() return task.result() finally: - loop.remove_signal_handler(signal.SIGINT) + if added: + loop.remove_signal_handler(signal.SIGINT) diff --git a/juju/model.py b/juju/model.py index 61905c9..7b86ba3 100644 --- a/juju/model.py +++ b/juju/model.py @@ -14,6 +14,7 @@ from concurrent.futures import CancelledError from functools import partial from pathlib import Path +import websockets import yaml import theblues.charmstore import theblues.errors @@ -550,6 +551,8 @@ class Model(object): """ async def _block(): while not all(c() for c in conditions): + if not (self.connection and self.connection.is_open): + raise websockets.ConnectionClosed(1006, 'no reason') await asyncio.sleep(wait_period, loop=self.loop) await asyncio.wait_for(_block(), timeout, loop=self.loop) @@ -643,16 +646,45 @@ class Model(object): See :meth:`add_observer` to register an onchange callback. """ - async def _start_watch(): + async def _all_watcher(): try: allwatcher = client.AllWatcherFacade.from_connection( self.connection) while not self._watch_stopping.is_set(): - results = await utils.run_with_interrupt( - allwatcher.Next(), - self._watch_stopping, - self.loop) + try: + results = await utils.run_with_interrupt( + allwatcher.Next(), + self._watch_stopping, + self.loop) + except JujuAPIError as e: + if 'watcher was stopped' not in str(e): + raise + if self._watch_stopping.is_set(): + # this shouldn't ever actually happen, because + # the event should trigger before the controller + # has a chance to tell us the watcher is stopped + # but handle it gracefully, just in case + break + # controller stopped our watcher for some reason + # but we're not actually stopping, so just restart it + log.warning( + 'Watcher: watcher stopped, restarting') + del allwatcher.Id + continue + except websockets.ConnectionClosed: + monitor = self.connection.monitor + if monitor.status == monitor.ERROR: + # closed unexpectedly, try to reopen + log.warning( + 'Watcher: connection closed, reopening') + await self.connection.reconnect() + del allwatcher.Id + continue + else: + # closed on request, go ahead and shutdown + break if self._watch_stopping.is_set(): + await allwatcher.Stop() break for delta in results.deltas: delta = get_entity_delta(delta) @@ -671,7 +703,7 @@ class Model(object): self._watch_received.clear() self._watch_stopping.clear() self._watch_stopped.clear() - self.loop.create_task(_start_watch()) + self.loop.create_task(_all_watcher()) async def _notify_observers(self, delta, old_obj, new_obj): """Call observing callbacks, notifying them of a change in model state diff --git a/tests/base.py b/tests/base.py index 8ea5109..e1ec452 100644 --- a/tests/base.py +++ b/tests/base.py @@ -44,6 +44,9 @@ class CleanModel(): model_name = 'model-{}'.format(uuid.uuid4()) self.model = await self.controller.add_model(model_name) + # save the model UUID in case test 
closes model + self.model_uuid = self.model.info.uuid + # Ensure that we connect to the new model by default. This also # prevents failures if test was started with no current model. self._patch_cm = mock.patch.object(JujuData, 'current_model', @@ -55,7 +58,7 @@ class CleanModel(): async def __aexit__(self, exc_type, exc, tb): self._patch_cm.stop() await self.model.disconnect() - await self.controller.destroy_model(self.model.info.uuid) + await self.controller.destroy_model(self.model_uuid) await self.controller.disconnect() diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index 67dfb2e..290203d 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -1,3 +1,4 @@ +import asyncio import pytest from juju.client.connection import Connection @@ -47,7 +48,7 @@ async def test_monitor_catches_error(event_loop): @pytest.mark.asyncio async def test_full_status(event_loop): async with base.CleanModel() as model: - app = await model.deploy( + await model.deploy( 'ubuntu-0', application_name='ubuntu', series='trusty', @@ -56,4 +57,27 @@ async def test_full_status(event_loop): c = client.ClientFacade.from_connection(model.connection) - status = await c.FullStatus(None) + await c.FullStatus(None) + + +@base.bootstrapped +@pytest.mark.asyncio +async def test_reconnect(event_loop): + async with base.CleanModel() as model: + conn = await Connection.connect( + model.connection.endpoint, + model.connection.uuid, + model.connection.username, + model.connection.password, + model.connection.cacert, + model.connection.macaroons, + model.connection.loop, + model.connection.max_frame_size) + try: + await asyncio.sleep(0.1) + assert conn.is_open + await conn.ws.close() + assert not conn.is_open + await model.block_until(lambda: conn.is_open, timeout=3) + finally: + await conn.close() diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index 088dcd5..37f51c0 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -212,6 +212,15 @@ async def test_get_machines(event_loop): assert isinstance(result, list) +@base.bootstrapped +@pytest.mark.asyncio +async def test_watcher_reconnect(event_loop): + async with base.CleanModel() as model: + await model.connection.ws.close() + await asyncio.sleep(0.1) + assert model.connection.is_open + + # @base.bootstrapped # @pytest.mark.asyncio # async def test_grant(event_loop) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 340264e..f69b8d6 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -1,3 +1,4 @@ +import asyncio import json import mock import pytest @@ -20,6 +21,7 @@ class WebsocketMock: async def recv(self): if not self.responses: + await asyncio.sleep(1) # delay to give test time to finish raise ConnectionClosed(0, 'ran out of responses') return json.dumps(self.responses.popleft()) -- 2.25.1
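
Reviewer note on the circular-reference fix: Monitor now holds only a
weak reference back to its Connection, which is why Monitor.status must
dereference it and handle the case where the Connection is gone. A
minimal, self-contained sketch of the pattern (not libjuju code; the
class bodies are reduced to the bare mechanism):

    import weakref

    class Monitor:
        def __init__(self, connection):
            # weak reference: does not keep the Connection alive, so the
            # Connection/Monitor pair no longer forms a reference cycle
            self.connection = weakref.ref(connection)

        @property
        def status(self):
            connection = self.connection()  # None once it's collected
            if connection is None:
                # someone kept the Monitor after the Connection died
                return 'disconnected'
            return 'connected'

    class Connection:
        def __init__(self):
            self.monitor = Monitor(self)

    conn = Connection()
    monitor = conn.monitor
    assert monitor.status == 'connected'
    del conn  # no cycle, so refcounting frees it immediately
    assert monitor.status == 'disconnected'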
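
Since Model.block_until can now raise websockets.ConnectionClosed
instead of hanging forever, callers should be prepared to catch it. A
hedged usage sketch (the 'ubuntu' application name and the unit-count
condition are hypothetical):

    import websockets

    from juju import loop
    from juju.model import Model

    async def wait_for_units():
        model = Model()
        await model.connect_current()
        try:
            await model.block_until(
                lambda: len(model.applications['ubuntu'].units) > 0,
                timeout=300)
        except websockets.ConnectionClosed:
            # the connection died and was not recovered in time;
            # decide here whether to retry, reconnect, or give up
            raise
        await model.disconnect()

    loop.run(wait_for_units())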
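
Even with automatic reconnects, the commit message notes a window where
a disconnect lands between sending an RPC and receiving its response,
so websockets.ConnectionClosed can still reach the caller. A
hypothetical retry helper for idempotent calls (the attempt count and
the FullStatus usage are illustrative, not part of this patch):

    import websockets

    async def retry_on_disconnect(make_call, attempts=3):
        """Retry an idempotent facade call across dropped connections."""
        for attempt in range(attempts):
            try:
                return await make_call()
            except websockets.ConnectionClosed:
                if attempt == attempts - 1:
                    raise
                # the library reconnects on the next RPC, so simply
                # retrying is enough; never retry non-idempotent calls
                # blindly

    # usage, assuming an already-connected model:
    #   c = client.ClientFacade.from_connection(model.connection)
    #   status = await retry_on_disconnect(lambda: c.FullStatus(None))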
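
The add_signal_handler change makes juju.loop.run usable off the main
thread, where asyncio refuses to install signal handlers. A sketch of
that scenario, assuming loop.run picks up the thread's current event
loop (the noop coroutine is illustrative):

    import asyncio
    import threading

    from juju import loop as juju_loop

    async def noop():
        await asyncio.sleep(0)

    def worker():
        # a non-main thread has no default event loop, so create one;
        # juju.loop.run now skips SIGINT handling here instead of
        # dying on the ValueError raised by add_signal_handler
        asyncio.set_event_loop(asyncio.new_event_loop())
        juju_loop.run(noop())

    t = threading.Thread(target=worker)
    t.start()
    t.join()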