import string
import subprocess
import websockets
+from concurrent.futures import CancelledError
from http.client import HTTPSConnection
import asyncio
import yaml
-from juju import tag
+from juju import tag, utils
from juju.client import client
from juju.client.version_map import VERSION_MAP
from juju.errors import JujuError, JujuAPIError, JujuConnectionError
def __init__(self, connection):
self.connection = connection
- self.receiver = None
- self.pinger = None
+ self.close_called = asyncio.Event(loop=self.connection.loop)
+ self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
+ self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
+ self.receiver_stopped.set()
+ self.pinger_stopped.set()
@property
def status(self):
# DISCONNECTED: connection not yet open
if not self.connection.ws:
return self.DISCONNECTED
- if not self.receiver:
+ if self.receiver_stopped.is_set():
return self.DISCONNECTED
# ERROR: Connection closed (or errored), but we didn't call
# connection.close
- if not self.connection.close_called and self.receiver_exceptions():
+ if not self.close_called.is_set() and self.receiver_stopped.is_set():
return self.ERROR
- if not self.connection.close_called and not self.connection.ws.open:
- # The check for self.receiver existing above guards against the
- # case where we're not open because we simply haven't
- # setup the connection yet.
+ if not self.close_called.is_set() and not self.connection.ws.open:
+ # The check for self.receiver_stopped existing above guards
+ # against the case where we're not open because we simply
+ # haven't setup the connection yet.
return self.ERROR
# DISCONNECTED: cleanly disconnected.
- if self.connection.close_called and not self.connection.ws.open:
+ if self.close_called.is_set() and not self.connection.ws.open:
return self.DISCONNECTED
# CONNECTED: everything is fine!
# know what state the connection is in.
return self.UNKNOWN
- def receiver_exceptions(self):
- """
- Return exceptions in the receiver, if any.
-
- """
- if not self.receiver:
- return None
- if not self.receiver.done():
- return None
- return self.receiver.exception()
-
class Connection:
"""
self.ws = None
self.facades = {}
self.messages = IdQueue(loop=self.loop)
- self.close_called = False
self.monitor = Monitor(connection=self)
@property
kw['loop'] = self.loop
self.addr = url
self.ws = await websockets.connect(url, **kw)
- self.monitor.receiver = self.loop.create_task(self.receiver())
+ self.loop.create_task(self.receiver())
+ self.monitor.receiver_stopped.clear()
log.info("Driver connected to juju %s", url)
+ self.monitor.close_called.clear()
return self
async def close(self):
if not self.is_open:
return
- self.close_called = True
- if self.monitor.pinger:
- # might be closing due to login failure,
- # in which case we won't have a pinger yet
- self.monitor.pinger.cancel()
- self.monitor.receiver.cancel()
+ self.monitor.close_called.set()
+ await self.monitor.pinger_stopped.wait()
+ await self.monitor.receiver_stopped.wait()
await self.ws.close()
async def recv(self, request_id):
return await self.messages.get(request_id)
async def receiver(self):
- while self.is_open:
- try:
- result = await self.ws.recv()
+ try:
+ while self.is_open:
+ result = await utils.run_with_interrupt(
+ self.ws.recv(),
+ self.monitor.close_called,
+ loop=self.loop)
+ if self.monitor.close_called.is_set():
+ break
if result is not None:
result = json.loads(result)
await self.messages.put(result['request-id'], result)
- except Exception as e:
- await self.messages.put_all(e)
- if isinstance(e, websockets.ConnectionClosed):
- # ConnectionClosed is not really exceptional for us,
- # but it may be for any pending message listeners
- return
- raise
+ except CancelledError:
+ pass
+ except Exception as e:
+ await self.messages.put_all(e)
+ if isinstance(e, websockets.ConnectionClosed):
+ # ConnectionClosed is not really exceptional for us,
+ # but it may be for any pending message listeners
+ return
+ log.exception("Error in receiver")
+ raise
+ finally:
+ self.monitor.receiver_stopped.set()
async def pinger(self):
'''
To prevent timing out, we send a ping every ten seconds.
'''
+ async def _do_ping():
+ try:
+ await pinger_facade.Ping()
+ await asyncio.sleep(10, loop=self.loop)
+ except CancelledError:
+ pass
+
pinger_facade = client.PingerFacade.from_connection(self)
- while self.is_open:
- await pinger_facade.Ping()
- await asyncio.sleep(10, loop=self.loop)
+ try:
+ while self.is_open:
+ await utils.run_with_interrupt(
+ _do_ping(),
+ self.monitor.close_called,
+ loop=self.loop)
+ if self.monitor.close_called.is_set():
+ break
+ finally:
+ self.monitor.pinger_stopped.set()
async def rpc(self, msg, encoder=None):
self.__request_id__ += 1
response = result['response']
client.info = response.copy()
client.build_facades(response.get('facades', {}))
- client.monitor.pinger = client.loop.create_task(client.pinger())
+ client.loop.create_task(client.pinger())
+ client.monitor.pinger_stopped.clear()
return client
import theblues.charmstore
import theblues.errors
-from . import tag
+from . import tag, utils
from .client import client
from .client import connection
from .constraints import parse as parse_constraints, normalize_key
self.observers = weakref.WeakValueDictionary()
self.state = ModelState(self)
self.info = None
- self._watcher_task = None
- self._watch_shutdown = asyncio.Event(loop=self.loop)
+ self._watch_stopping = asyncio.Event(loop=self.loop)
+ self._watch_stopped = asyncio.Event(loop=self.loop)
self._watch_received = asyncio.Event(loop=self.loop)
self._charmstore = CharmStore(self.loop)
"""Shut down the watcher task and close websockets.
"""
- self._stop_watching()
if self.connection and self.connection.is_open:
- await self._watch_shutdown.wait()
+ log.debug('Stopping watcher task')
+ self._watch_stopping.set()
+ await self._watch_stopped.wait()
log.debug('Closing model connection')
await self.connection.close()
self.connection = None
"""
async def _start_watch():
- self._watch_shutdown.clear()
try:
allwatcher = client.AllWatcherFacade.from_connection(
self.connection)
- while True:
- results = await allwatcher.Next()
+ while not self._watch_stopping.is_set():
+ results = await utils.run_with_interrupt(
+ allwatcher.Next(),
+ self._watch_stopping,
+ self.loop)
+ if self._watch_stopping.is_set():
+ break
for delta in results.deltas:
delta = get_entity_delta(delta)
old_obj, new_obj = self.state.apply_delta(delta)
- # XXX: Might not want to shield at this level
- # We are shielding because when the watcher is
- # canceled (on disconnect()), we don't want all of
- # its children (every observer callback) to be
- # canceled with it. So we shield them. But this means
- # they can *never* be canceled.
- await asyncio.shield(
- self._notify_observers(delta, old_obj, new_obj),
- loop=self.loop)
+ await self._notify_observers(delta, old_obj, new_obj)
self._watch_received.set()
except CancelledError:
- self._watch_shutdown.set()
+ pass
except Exception:
log.exception('Error in watcher')
raise
+ finally:
+ self._watch_stopped.set()
log.debug('Starting watcher task')
- self._watcher_task = self.loop.create_task(_start_watch())
-
- def _stop_watching(self):
- """Stop the asynchronous watch against this model.
-
- """
- log.debug('Stopping watcher task')
- if self._watcher_task:
- self._watcher_task.cancel()
+ self._watch_received.clear()
+ self._watch_stopping.clear()
+ self._watch_stopped.clear()
+ self.loop.create_task(_start_watch())
async def _notify_observers(self, delta, old_obj, new_obj):
"""Call observing callbacks, notifying them of a change in model state
async def put_all(self, value):
for queue in self._queues.values():
await queue.put(value)
+
+
+async def run_with_interrupt(task, event, loop=None):
+ """
+ Awaits a task while allowing it to be interrupted by an `asyncio.Event`.
+
+ If the task finishes without the event becoming set, the results of the
+ task will be returned. If the event becomes set, the task will be
+ cancelled ``None`` will be returned.
+
+ :param task: Task to run
+ :param event: An `asyncio.Event` which, if set, will interrupt `task`
+ and cause it to be cancelled.
+ :param loop: Optional event loop to use other than the default.
+ """
+ loop = loop or asyncio.get_event_loop()
+ event_task = loop.create_task(event.wait())
+ done, pending = await asyncio.wait([task, event_task],
+ loop=loop,
+ return_when=asyncio.FIRST_COMPLETED)
+ for f in pending:
+ f.cancel()
+ result = [f.result() for f in done if f is not event_task]
+ if result:
+ return result[0]
+ else:
+ return None