Add predicate support for observers and Model._wait_for_new()
authorTim Van Steenburgh <tvansteenburgh@gmail.com>
Mon, 17 Oct 2016 22:39:45 +0000 (18:39 -0400)
committerTim Van Steenburgh <tvansteenburgh@gmail.com>
Mon, 17 Oct 2016 22:39:45 +0000 (18:39 -0400)
examples/relate.py
juju/model.py

index fa32900..241d1a6 100644 (file)
@@ -30,7 +30,7 @@ class MyModelObserver(ModelObserver):
     _shutting_down = False
 
     async def on_change(self, delta, old, new, model):
-        if model.all_units_idle() and not self._shutting_down:
+        if model.units and model.all_units_idle() and not self._shutting_down:
             self._shutting_down = True
             logging.debug('All units idle, disconnecting')
             await model.reset(force=True)
@@ -47,7 +47,7 @@ async def run():
     model.add_observer(MyModelObserver())
 
     ubuntu_app = await model.deploy(
-        'ubuntu-0',
+        'ubuntu',
         service_name='ubuntu',
         series='trusty',
         channel='stable',
@@ -69,16 +69,20 @@ async def run():
             print('Unit removed: {}'.format(old_unit.entity_id))
     ))
     await model.deploy(
-        'nrpe-11',
+        'nrpe',
         service_name='nrpe',
         series='trusty',
         channel='stable',
         num_units=0,
     )
-    await model.add_relation(
+    my_relation = await model.add_relation(
         'ubuntu',
         'nrpe',
     )
+    my_relation.on_remove(asyncio.coroutine(
+        lambda delta, old_rel, new_rel, model:
+            print('Relation removed: {}'.format(old_rel.endpoints))
+    ))
 
 logging.basicConfig(level=logging.DEBUG)
 ws_logger = logging.getLogger('websockets.protocol')
index 5d436fd..49f07a9 100644 (file)
@@ -26,12 +26,14 @@ class _Observer(object):
     callable so that it's only called for changes that meet the criteria.
 
     """
-    def __init__(self, callable_, entity_type, action, entity_id):
+    def __init__(self, callable_, entity_type, action, entity_id, predicate):
         self.callable_ = callable_
         self.entity_type = entity_type
         self.action = action
         self.entity_id = entity_id
+        self.predicate = predicate
         if self.entity_id:
+            self.entity_id = str(self.entity_id)
             if not self.entity_id.startswith('^'):
                 self.entity_id = '^' + self.entity_id
             if not self.entity_id.endswith('$'):
@@ -40,20 +42,22 @@ class _Observer(object):
     async def __call__(self, delta, old, new, model):
         await self.callable_(delta, old, new, model)
 
-    def cares_about(self, entity_type, action, entity_id):
+    def cares_about(self, delta):
         """Return True if this observer "cares about" (i.e. wants to be
-        called) for a change matching the entity_type, action, and
-        entity_id parameters.
+        called) for a this delta.
 
         """
-        if (self.entity_id and entity_id and
-                not re.match(self.entity_id, str(entity_id))):
+        if (self.entity_id and delta.get_id() and
+                not re.match(self.entity_id, str(delta.get_id()))):
             return False
 
-        if self.entity_type and self.entity_type != entity_type:
+        if self.entity_type and self.entity_type != delta.entity:
             return False
 
-        if self.action and self.action != action:
+        if self.action and self.action != delta.type:
+            return False
+
+        if self.predicate and not self.predicate(delta):
             return False
 
         return True
@@ -78,9 +82,6 @@ class ModelState(object):
         self.model = model
         self.state = dict()
 
-    def clear(self):
-        self.state.clear()
-
     def _live_entity_map(self, entity_type):
         """Return an id:Entity map of all the living entities of
         type ``entity_type``.
@@ -395,7 +396,6 @@ class Model(object):
         await self.block_until(
             lambda: len(self.machines) == 0
         )
-        self.state.clear()
 
     async def block_until(self, *conditions, timeout=None):
         """Return only after all conditions are true.
@@ -431,7 +431,8 @@ class Model(object):
         return self.state.units
 
     def add_observer(
-            self, callable_, entity_type=None, action=None, entity_id=None):
+            self, callable_, entity_type=None, action=None, entity_id=None,
+            predicate=None):
         """Register an "on-model-change" callback
 
         Once the model is connected, ``callable_``
@@ -459,8 +460,13 @@ class Model(object):
             add_observer(
                 myfunc, entity_type='application', action='add', id_='ubuntu')
 
+        For more complex filtering conditions, pass a predicate function. It
+        will called with a delta as it's only argument. If the predicate
+        function returns True, the callable_ will be called.
+
         """
-        observer = _Observer(callable_, entity_type, action, entity_id)
+        observer = _Observer(
+            callable_, entity_type, action, entity_id, predicate)
         self.observers[observer] = callable_
 
     def _watch(self):
@@ -517,7 +523,7 @@ class Model(object):
             by applying this delta.
 
         """
-        if not old_obj:
+        if new_obj and not old_obj:
             delta.type = 'add'
 
         log.debug(
@@ -525,10 +531,10 @@ class Model(object):
             delta.entity, delta.type, delta.get_id())
 
         for o in self.observers:
-            if o.cares_about(delta.entity, delta.type, delta.get_id()):
+            if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
-    async def _wait_for_new(self, entity_type, entity_id):
+    async def _wait_for_new(self, entity_type, entity_id, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
@@ -536,12 +542,13 @@ class Model(object):
         This coroutine blocks until the new object appears in the model.
 
         """
-        entity_added = asyncio.Event(loop=self.loop)
+        entity_added = asyncio.Queue(loop=self.loop)
 
         async def callback(delta, old, new, model):
-            entity_added.set()
-        self.add_observer(callback, entity_type, 'add', entity_id)
-        await entity_added.wait()
+            await entity_added.put(delta.get_id())
+
+        self.add_observer(callback, entity_type, 'add', entity_id, predicate)
+        entity_id = await entity_added.get()
         return self.state._live_entity_map(entity_type)[entity_id]
 
     def add_machine(
@@ -587,7 +594,15 @@ class Model(object):
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
 
-        return await app_facade.AddRelation([relation1, relation2])
+        result = await app_facade.AddRelation([relation1, relation2])
+
+        def predicate(delta):
+            endpoints = {}
+            for endpoint in delta.data['endpoints']:
+                endpoints[endpoint['application-name']] = endpoint['relation']
+            return endpoints == result.endpoints
+
+        return await self._wait_for_new('relation', None, predicate)
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
@@ -729,8 +744,6 @@ class Model(object):
 
         TODO::
 
-            - entity_url must have a revision; look up latest automatically if
-              not provided by caller
             - service_name is required; fill this in automatically if not
               provided by caller
             - series is required; how do we pick a default?