Addressed PR comments.
[osm/N2VC.git] / juju / model.py
index cc9ee95..b74a6c8 100644 (file)
@@ -535,6 +535,30 @@ class Model(object):
             if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
+    async def _wait(self, entity_type, entity_id, action, predicate=None):
+        """
+        Block the calling routine until a given action has happened to the
+        given entity
+
+        :param entity_type: The entity's type.
+        :param entity_id: The entity's id.
+        :param action: the type of action (e.g., 'add' or 'change')
+        :param predicate: optional callable that must take as an
+            argument a delta, and must return a boolean, indicating
+            whether the delta contains the specific action we're looking
+            for. For example, you might check to see whether a 'change'
+            has a 'completed' status. See the _Observer class for details.
+
+        """
+        q = asyncio.Queue(loop=self.loop)
+
+        async def callback(delta, old, new, model):
+            await q.put(delta.get_id())
+
+        self.add_observer(callback, entity_type, action, entity_id, predicate)
+        entity_id = await q.get()
+        return self.state._live_entity_map(entity_type)[entity_id]
+
     async def _wait_for_new(self, entity_type, entity_id, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
@@ -543,14 +567,20 @@ class Model(object):
         This coroutine blocks until the new object appears in the model.
 
         """
-        entity_added = asyncio.Queue(loop=self.loop)
+        return await self._wait(entity_type, entity_id, predicate)
 
-        async def callback(delta, old, new, model):
-            await entity_added.put(delta.get_id())
+    async def wait_for_action(self, action_id):
+        """Given an action, wait for it to complete."""
 
-        self.add_observer(callback, entity_type, 'add', entity_id, predicate)
-        entity_id = await entity_added.get()
-        return self.state._live_entity_map(entity_type)[entity_id]
+        if action_id.startswith("action-"):
+            # if we've been passed action.tag, transform it into the
+            # id that the api deltas will use.
+            action_id = action_id[7:]
+
+        def predicate(delta):
+            return delta.data['status'] in ('completed', 'error')
+
+        return await self._wait('action', action_id, 'change', predicate)
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
@@ -759,9 +789,6 @@ class Model(object):
             - series is required; how do we pick a default?
 
         """
-        if constraints:
-            constraints = client.Value(**constraints)
-
         if to:
             placement = [
                 client.Placement(**p) for p in to
@@ -791,8 +818,11 @@ class Model(object):
             if pending_apps:
                 # new apps will usually be in the model by now, but if some
                 # haven't made it yet we'll need to wait on them to be added
-                await asyncio.wait([self._wait_for_new('application', app_name)
-                                    for app_name in pending_apps])
+                await asyncio.gather(*[
+                    asyncio.ensure_future(
+                        self.model._wait_for_new('application', app_name))
+                    for app_name in pending_apps
+                ])
             return [app for name, app in self.applications.items()
                     if name in handler.applications]
         else:
@@ -815,7 +845,7 @@ class Model(object):
             )
 
             await app_facade.Deploy([app])
-            return [await self._wait_for_new('application', service_name)]
+            return await self._wait_for_new('application', service_name)
 
     def destroy(self):
         """Terminate all machines and resources for this model.