From 7fdf96510e9aa6960d6c956a7178585cb65caefe Mon Sep 17 00:00:00 2001 From: Tim Van Steenburgh Date: Mon, 17 Oct 2016 18:39:45 -0400 Subject: [PATCH] Add predicate support for observers and Model._wait_for_new() --- examples/relate.py | 12 ++++++--- juju/model.py | 61 ++++++++++++++++++++++++++++------------------ 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/examples/relate.py b/examples/relate.py index fa32900..241d1a6 100644 --- a/examples/relate.py +++ b/examples/relate.py @@ -30,7 +30,7 @@ class MyModelObserver(ModelObserver): _shutting_down = False async def on_change(self, delta, old, new, model): - if model.all_units_idle() and not self._shutting_down: + if model.units and model.all_units_idle() and not self._shutting_down: self._shutting_down = True logging.debug('All units idle, disconnecting') await model.reset(force=True) @@ -47,7 +47,7 @@ async def run(): model.add_observer(MyModelObserver()) ubuntu_app = await model.deploy( - 'ubuntu-0', + 'ubuntu', service_name='ubuntu', series='trusty', channel='stable', @@ -69,16 +69,20 @@ async def run(): print('Unit removed: {}'.format(old_unit.entity_id)) )) await model.deploy( - 'nrpe-11', + 'nrpe', service_name='nrpe', series='trusty', channel='stable', num_units=0, ) - await model.add_relation( + my_relation = await model.add_relation( 'ubuntu', 'nrpe', ) + my_relation.on_remove(asyncio.coroutine( + lambda delta, old_rel, new_rel, model: + print('Relation removed: {}'.format(old_rel.endpoints)) + )) logging.basicConfig(level=logging.DEBUG) ws_logger = logging.getLogger('websockets.protocol') diff --git a/juju/model.py b/juju/model.py index 5d436fd..49f07a9 100644 --- a/juju/model.py +++ b/juju/model.py @@ -26,12 +26,14 @@ class _Observer(object): callable so that it's only called for changes that meet the criteria. """ - def __init__(self, callable_, entity_type, action, entity_id): + def __init__(self, callable_, entity_type, action, entity_id, predicate): self.callable_ = callable_ self.entity_type = entity_type self.action = action self.entity_id = entity_id + self.predicate = predicate if self.entity_id: + self.entity_id = str(self.entity_id) if not self.entity_id.startswith('^'): self.entity_id = '^' + self.entity_id if not self.entity_id.endswith('$'): @@ -40,20 +42,22 @@ class _Observer(object): async def __call__(self, delta, old, new, model): await self.callable_(delta, old, new, model) - def cares_about(self, entity_type, action, entity_id): + def cares_about(self, delta): """Return True if this observer "cares about" (i.e. wants to be - called) for a change matching the entity_type, action, and - entity_id parameters. + called) for a this delta. """ - if (self.entity_id and entity_id and - not re.match(self.entity_id, str(entity_id))): + if (self.entity_id and delta.get_id() and + not re.match(self.entity_id, str(delta.get_id()))): return False - if self.entity_type and self.entity_type != entity_type: + if self.entity_type and self.entity_type != delta.entity: return False - if self.action and self.action != action: + if self.action and self.action != delta.type: + return False + + if self.predicate and not self.predicate(delta): return False return True @@ -78,9 +82,6 @@ class ModelState(object): self.model = model self.state = dict() - def clear(self): - self.state.clear() - def _live_entity_map(self, entity_type): """Return an id:Entity map of all the living entities of type ``entity_type``. @@ -395,7 +396,6 @@ class Model(object): await self.block_until( lambda: len(self.machines) == 0 ) - self.state.clear() async def block_until(self, *conditions, timeout=None): """Return only after all conditions are true. @@ -431,7 +431,8 @@ class Model(object): return self.state.units def add_observer( - self, callable_, entity_type=None, action=None, entity_id=None): + self, callable_, entity_type=None, action=None, entity_id=None, + predicate=None): """Register an "on-model-change" callback Once the model is connected, ``callable_`` @@ -459,8 +460,13 @@ class Model(object): add_observer( myfunc, entity_type='application', action='add', id_='ubuntu') + For more complex filtering conditions, pass a predicate function. It + will called with a delta as it's only argument. If the predicate + function returns True, the callable_ will be called. + """ - observer = _Observer(callable_, entity_type, action, entity_id) + observer = _Observer( + callable_, entity_type, action, entity_id, predicate) self.observers[observer] = callable_ def _watch(self): @@ -517,7 +523,7 @@ class Model(object): by applying this delta. """ - if not old_obj: + if new_obj and not old_obj: delta.type = 'add' log.debug( @@ -525,10 +531,10 @@ class Model(object): delta.entity, delta.type, delta.get_id()) for o in self.observers: - if o.cares_about(delta.entity, delta.type, delta.get_id()): + if o.cares_about(delta): asyncio.ensure_future(o(delta, old_obj, new_obj, self)) - async def _wait_for_new(self, entity_type, entity_id): + async def _wait_for_new(self, entity_type, entity_id, predicate=None): """Wait for a new object to appear in the Model and return it. Waits for an object of type ``entity_type`` with id ``entity_id``. @@ -536,12 +542,13 @@ class Model(object): This coroutine blocks until the new object appears in the model. """ - entity_added = asyncio.Event(loop=self.loop) + entity_added = asyncio.Queue(loop=self.loop) async def callback(delta, old, new, model): - entity_added.set() - self.add_observer(callback, entity_type, 'add', entity_id) - await entity_added.wait() + await entity_added.put(delta.get_id()) + + self.add_observer(callback, entity_type, 'add', entity_id, predicate) + entity_id = await entity_added.get() return self.state._live_entity_map(entity_type)[entity_id] def add_machine( @@ -587,7 +594,15 @@ class Model(object): log.debug( 'Adding relation %s <-> %s', relation1, relation2) - return await app_facade.AddRelation([relation1, relation2]) + result = await app_facade.AddRelation([relation1, relation2]) + + def predicate(delta): + endpoints = {} + for endpoint in delta.data['endpoints']: + endpoints[endpoint['application-name']] = endpoint['relation'] + return endpoints == result.endpoints + + return await self._wait_for_new('relation', None, predicate) def add_space(self, name, *cidrs): """Add a new network space. @@ -729,8 +744,6 @@ class Model(object): TODO:: - - entity_url must have a revision; look up latest automatically if - not provided by caller - service_name is required; fill this in automatically if not provided by caller - series is required; how do we pick a default? -- 2.25.1