Make observers async; make model.reset() blocking
authorTim Van Steenburgh <tvansteenburgh@gmail.com>
Wed, 14 Sep 2016 00:21:26 +0000 (20:21 -0400)
committerTim Van Steenburgh <tvansteenburgh@gmail.com>
Wed, 14 Sep 2016 00:21:26 +0000 (20:21 -0400)
examples/relate.py
juju/model.py

index 967a785..a8e39a0 100644 (file)
@@ -9,11 +9,11 @@ from juju.model import Model, ModelObserver
 
 
 class MyModelObserver(ModelObserver):
-    def on_change(self, delta, old, new, model):
+    async def on_change(self, delta, old, new, model):
         if model.all_units_idle():
             logging.debug('All units idle, disconnecting')
-            task = model.loop.create_task(model.disconnect())
-            task.add_done_callback(lambda fut: model.loop.stop())
+            await model.disconnect()
+            model.loop.stop()
 
 
 async def run():
@@ -21,9 +21,6 @@ async def run():
     await model.connect_current()
 
     await model.reset(force=True)
-    await model.block_until(
-        lambda: len(model.machines) == 0
-    )
     model.add_observer(MyModelObserver())
 
     await model.deploy(
@@ -44,10 +41,10 @@ async def run():
         'nrpe',
     )
 
-
 logging.basicConfig(level=logging.DEBUG)
 ws_logger = logging.getLogger('websockets.protocol')
 ws_logger.setLevel(logging.INFO)
 loop = asyncio.get_event_loop()
+loop.set_debug(False)
 loop.create_task(run())
 loop.run_forever()
index 04f3437..677657f 100644 (file)
@@ -11,7 +11,7 @@ log = logging.getLogger(__name__)
 
 
 class ModelObserver(object):
-    def __call__(self, delta, old, new, model):
+    async def __call__(self, delta, old, new, model):
         if old is None and new is not None:
             type_ = 'add'
         else:
@@ -21,9 +21,9 @@ class ModelObserver(object):
         log.debug(
             'Model changed: %s %s %s',
             delta.entity, delta.type, delta.get_id())
-        method(delta, old, new, model)
+        await method(delta, old, new, model)
 
-    def on_change(self, delta, old, new, model):
+    async def on_change(self, delta, old, new, model):
         pass
 
 
@@ -62,16 +62,22 @@ class Model(object):
         self._watch_received = asyncio.Event(loop=loop)
 
     async def connect_current(self):
+        """Connect to the current Juju model.
+
+        """
         self.connection = await connection.Connection.connect_current()
         self._watch()
         await self._watch_received.wait()
 
     async def disconnect(self):
+        """Shut down the watcher task and close websockets.
+
+        """
         self._stop_watching()
         if self.connection and self.connection.is_open:
             await self._watch_shutdown.wait()
             log.debug('Closing model connection')
-            await asyncio.wait_for(self.connection.close(), None)
+            await self.connection.close()
             self.connection = None
 
     def all_units_idle(self):
@@ -85,27 +91,53 @@ class Model(object):
         return True
 
     async def reset(self, force=False):
+        """Reset the model to a clean state.
+
+        :param bool force: Force-terminate machines.
+
+        This returns only after the model has reached a clean state. "Clean"
+        means no applications or machines exist in the model.
+
+        """
         for app in self.applications.values():
             await app.destroy()
         for machine in self.machines.values():
             await machine.destroy(force=force)
+        await self.block_until(
+            lambda: len(self.machines) == 0
+        )
 
-    async def block_until(self, func):
+    async def block_until(self, *conditions, timeout=None):
+        """Return only after all conditions are true.
+
+        """
         async def _block():
-            while not func():
+            while not all(c() for c in conditions):
                 await asyncio.sleep(.1)
-        await asyncio.wait_for(_block(), None)
+        await asyncio.wait_for(_block(), timeout)
 
     @property
     def applications(self):
+        """Return a map of application-name:Application for all applications
+        currently in the model.
+
+        """
         return self.state.get('application', {})
 
     @property
     def machines(self):
+        """Return a map of machine-id:Machine for all machines currently in
+        the model.
+
+        """
         return self.state.get('machine', {})
 
     @property
     def units(self):
+        """Return a map of unit-id:Unit for all units currently in
+        the model.
+
+        """
         return self.state.get('unit', {})
 
     def add_observer(self, callable_):
@@ -113,7 +145,7 @@ class Model(object):
 
         Once a watch is started (Model.watch() is called), ``callable_``
         will be called each time the model changes. callable_ should
-        accept the following positional arguments:
+        be Awaitable and accept the following positional arguments:
 
             delta - An instance of :class:`juju.delta.EntityDelta`
                 containing the raw delta data recv'd from the Juju
@@ -150,11 +182,18 @@ class Model(object):
                     for delta in results.deltas:
                         delta = get_entity_delta(delta)
                         old_obj, new_obj = self._apply_delta(delta)
-                        self._notify_observers(delta, old_obj, new_obj)
+                        # XXX: Might not want to shield at this level
+                        # We are shielding because when the watcher is
+                        # canceled (on disconnect()), we don't want all of
+                        # its children (every observer callback) to be
+                        # canceled with it. So we shield them. But this means
+                        # they can *never* be canceled.
+                        await asyncio.shield(
+                            self._notify_observers(delta, old_obj, new_obj))
                     self._watch_received.set()
             except CancelledError:
                 log.debug('Closing watcher connection')
-                await asyncio.wait_for(self._watch_conn.close(), None)
+                await self._watch_conn.close()
                 self._watch_shutdown.set()
                 self._watch_conn = None
 
@@ -204,7 +243,7 @@ class Model(object):
         entity_class = delta.get_entity_class()
         return entity_class(delta.data, self)
 
-    def _notify_observers(self, delta, old_obj, new_obj):
+    async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
 
         :param delta: The raw change from the watcher
@@ -216,7 +255,7 @@ class Model(object):
 
         """
         for o in self.observers:
-            o(delta, old_obj, new_obj, self)
+            asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,