Fixed docstring for addMachines.
[osm/N2VC.git] / juju / model.py
index 7781d34..8d6c666 100644 (file)
@@ -6,11 +6,13 @@ import weakref
 from concurrent.futures import CancelledError
 from functools import partial
 
 from concurrent.futures import CancelledError
 from functools import partial
 
+import yaml
 from theblues import charmstore
 
 from .client import client
 from .client import watcher
 from .client import connection
 from theblues import charmstore
 
 from .client import client
 from .client import watcher
 from .client import connection
+from .constraints import parse as parse_constraints
 from .delta import get_entity_delta
 from .delta import get_entity_class
 from .exceptions import DeadEntityException
 from .delta import get_entity_delta
 from .delta import get_entity_class
 from .exceptions import DeadEntityException
@@ -158,14 +160,12 @@ class ModelState(object):
 
     def get_entity(
             self, entity_type, entity_id, history_index=-1, connected=True):
 
     def get_entity(
             self, entity_type, entity_id, history_index=-1, connected=True):
-        """Return an object instance representing the entity created or
-        updated by ``delta``
+        """Return an object instance for the given entity_type and id.
+
+        By default the object state matches the most recent state from
+        Juju. To get an instance of the object in an older state, pass
+        history_index, an index into the history deque for the entity.
 
 
-        """
-        """
-        log.debug(
-            'Getting %s:%s at index %s',
-            entity_type, entity_id, history_index)
         """
 
         if history_index < 0 and history_index != -1:
         """
 
         if history_index < 0 and history_index != -1:
@@ -350,6 +350,16 @@ class Model(object):
         self._watch_received = asyncio.Event(loop=loop)
         self._charmstore = CharmStore(self.loop)
 
         self._watch_received = asyncio.Event(loop=loop)
         self._charmstore = CharmStore(self.loop)
 
+    async def connect(self, *args, **kw):
+        """Connect to an arbitrary Juju model.
+
+        args and kw are passed through to Connection.connect()
+
+        """
+        self.connection = await connection.Connection.connect(*args, **kw)
+        self._watch()
+        await self._watch_received.wait()
+
     async def connect_current(self):
         """Connect to the current Juju model.
 
     async def connect_current(self):
         """Connect to the current Juju model.
 
@@ -358,6 +368,15 @@ class Model(object):
         self._watch()
         await self._watch_received.wait()
 
         self._watch()
         await self._watch_received.wait()
 
+    async def connect_model(self, arg):
+        """Connect to a specific Juju model.
+        :param arg:  <controller>:<user/model>
+
+        """
+        self.connection = await connection.Connection.connect_model(arg)
+        self._watch()
+        await self._watch_received.wait()
+
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
@@ -534,6 +553,30 @@ class Model(object):
             if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
             if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
+    async def _wait(self, entity_type, entity_id, action, predicate=None):
+        """
+        Block the calling routine until a given action has happened to the
+        given entity
+
+        :param entity_type: The entity's type.
+        :param entity_id: The entity's id.
+        :param action: the type of action (e.g., 'add' or 'change')
+        :param predicate: optional callable that must take as an
+            argument a delta, and must return a boolean, indicating
+            whether the delta contains the specific action we're looking
+            for. For example, you might check to see whether a 'change'
+            has a 'completed' status. See the _Observer class for details.
+
+        """
+        q = asyncio.Queue(loop=self.loop)
+
+        async def callback(delta, old, new, model):
+            await q.put(delta.get_id())
+
+        self.add_observer(callback, entity_type, action, entity_id, predicate)
+        entity_id = await q.get()
+        return self.state._live_entity_map(entity_type)[entity_id]
+
     async def _wait_for_new(self, entity_type, entity_id, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
     async def _wait_for_new(self, entity_type, entity_id, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
@@ -542,14 +585,20 @@ class Model(object):
         This coroutine blocks until the new object appears in the model.
 
         """
         This coroutine blocks until the new object appears in the model.
 
         """
-        entity_added = asyncio.Queue(loop=self.loop)
+        return await self._wait(entity_type, entity_id, 'add', predicate)
 
 
-        async def callback(delta, old, new, model):
-            await entity_added.put(delta.get_id())
+    async def wait_for_action(self, action_id):
+        """Given an action, wait for it to complete."""
 
 
-        self.add_observer(callback, entity_type, 'add', entity_id, predicate)
-        entity_id = await entity_added.get()
-        return self.state._live_entity_map(entity_type)[entity_id]
+        if action_id.startswith("action-"):
+            # if we've been passed action.tag, transform it into the
+            # id that the api deltas will use.
+            action_id = action_id[7:]
+
+        def predicate(delta):
+            return delta.data['status'] in ('completed', 'failed')
+
+        return await self._wait('action', action_id, 'change', predicate)
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
@@ -758,9 +807,6 @@ class Model(object):
             - series is required; how do we pick a default?
 
         """
             - series is required; how do we pick a default?
 
         """
-        if constraints:
-            constraints = client.Value(**constraints)
-
         if to:
             placement = [
                 client.Placement(**p) for p in to
         if to:
             placement = [
                 client.Placement(**p) for p in to
@@ -785,6 +831,18 @@ class Model(object):
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
+            extant_apps = {app for app in self.applications}
+            pending_apps = set(handler.applications) - extant_apps
+            if pending_apps:
+                # new apps will usually be in the model by now, but if some
+                # haven't made it yet we'll need to wait on them to be added
+                await asyncio.gather(*[
+                    asyncio.ensure_future(
+                        self.model._wait_for_new('application', app_name))
+                    for app_name in pending_apps
+                ])
+            return [app for name, app in self.applications.items()
+                    if name in handler.applications]
         else:
             log.debug(
                 'Deploying %s', entity_id)
         else:
             log.debug(
                 'Deploying %s', entity_id)
@@ -813,6 +871,21 @@ class Model(object):
         """
         pass
 
         """
         pass
 
+    async def destroy_unit(self, *unit_names):
+        """Destroy units by name.
+
+        """
+        app_facade = client.ApplicationFacade()
+        app_facade.connect(self.connection)
+
+        log.debug(
+            'Destroying unit%s %s',
+            's' if len(unit_names) == 1 else '',
+            ' '.join(unit_names))
+
+        return await app_facade.Destroy(self.name)
+    destroy_units = destroy_unit
+
     def get_backup(self, archive_id):
         """Download a backup archive file.
 
     def get_backup(self, archive_id):
         """Download a backup archive file.
 
@@ -1148,10 +1221,11 @@ class BundleHandler(object):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
-        yaml = await self.charmstore.files(entity_id,
-                                           filename='bundle.yaml',
-                                           read_file=True)
-        self.plan = await self.client_facade.GetBundleChanges(yaml)
+        bundle_yaml = await self.charmstore.files(entity_id,
+                                                  filename='bundle.yaml',
+                                                  read_file=True)
+        self.bundle = yaml.safe_load(bundle_yaml)
+        self.plan = await self.client_facade.GetBundleChanges(bundle_yaml)
 
     async def execute_plan(self):
         for step in self.plan.changes:
 
     async def execute_plan(self):
         for step in self.plan.changes:
@@ -1159,6 +1233,10 @@ class BundleHandler(object):
             result = await method(*step.args)
             self.references[step.id_] = result
 
             result = await method(*step.args)
             self.references[step.id_] = result
 
+    @property
+    def applications(self):
+        return list(self.bundle['services'].keys())
+
     def resolve(self, reference):
         if reference and reference.startswith('$'):
             reference = self.references[reference[1:]]
     def resolve(self, reference):
         if reference and reference.startswith('$'):
             reference = self.references[reference[1:]]
@@ -1178,35 +1256,41 @@ class BundleHandler(object):
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
-    async def addMachines(self, series, constraints, container_type,
-                          parent_id):
-        """
-        :param series string:
-            Series holds the optional machine OS series.
-
-        :param constraints string:
-            Constraints holds the optional machine constraints.
-
-        :param Container_type string:
-            ContainerType optionally holds the type of the container (for
-            instance ""lxc" or kvm"). It is not specified for top level
-            machines.
-
-        :param parent_id string:
-            ParentId optionally holds a placeholder pointing to another machine
-            change or to a unit change. This value is only specified in the
-            case this machine is a container, in which case also ContainerType
-            is set.
-        """
-        params = client.AddMachineParams(
-            series=series,
-            constraints=constraints,
-            container_type=container_type,
-            parent_id=self.resolve(parent_id),
-        )
-        results = await self.client_facade.AddMachines(params)
-        log.debug('Added new machine %s', results[0].machine)
-        return results[0].machine
+    async def addMachines(self, params=None):
+        """
+        :param params dict:
+            Dictionary specifying the machine to add. All keys are optional.
+            Keys include:
+
+            series: string specifying the machine OS series.
+            constraints: string holding machine constraints, if any. We'll
+                parse this into the json friendly dict that the juju api
+                expects.
+            container_type: string holding the type of the container (for
+                instance ""lxc" or kvm"). It is not specified for top level
+                machines.
+            parent_id: string holding a placeholder pointing to another
+                machine change or to a unit change. This value is only
+                specified in the case this machine is a container, in
+                which case also ContainerType is set.
+        """
+        params = params or {}
+
+        if 'parent_id' in params:
+            params['parent_id'] = self.resolve(params['parent_id'])
+
+        params['constraints'] = parse_constraints(
+            params.get('constraints'))
+        params['jobs'] = params.get('jobs', ['JobHostUnits'])
+
+        params = client.AddMachineParams(**params)
+        results = await self.client_facade.AddMachines([params])
+        error = results.machines[0].error
+        if error:
+            raise ValueError("Error adding machine: %s", error.message)
+        machine = results.machines[0].machine
+        log.debug('Added new machine %s', machine)
+        return machine
 
     async def addRelation(self, endpoint1, endpoint2):
         """
 
     async def addRelation(self, endpoint1, endpoint2):
         """
@@ -1224,6 +1308,7 @@ class BundleHandler(object):
             parts[0] = self.resolve(parts[0])
             endpoints[i] = ':'.join(parts)
 
             parts[0] = self.resolve(parts[0])
             endpoints[i] = ':'.join(parts)
 
+        log.info('Relating %s <-> %s', *endpoints)
         return await self.model.add_relation(*endpoints)
 
     async def deploy(self, charm, series, application, options, constraints,
         return await self.model.add_relation(*endpoints)
 
     async def deploy(self, charm, series, application, options, constraints,
@@ -1272,7 +1357,7 @@ class BundleHandler(object):
             resources=resources,
         )
         # do the do
             resources=resources,
         )
         # do the do
-        log.debug('Deploying %s', charm)
+        log.info('Deploying %s', charm)
         await self.app_facade.Deploy([app])
         return application
 
         await self.app_facade.Deploy([app])
         return application
 
@@ -1297,6 +1382,8 @@ class BundleHandler(object):
             log.debug('Reusing unit %s for %s', unit_name, application)
             return self.model.units[unit_name]
 
             log.debug('Reusing unit %s for %s', unit_name, application)
             return self.model.units[unit_name]
 
+        log.debug('Adding new unit for %s%s', application,
+                  ' to %s' % placement if placement else '')
         return await self.model.applications[application].add_unit(
             count=1,
             to=placement,
         return await self.model.applications[application].add_unit(
             count=1,
             to=placement,
@@ -1309,6 +1396,7 @@ class BundleHandler(object):
             be exposed.
         """
         application = self.resolve(application)
             be exposed.
         """
         application = self.resolve(application)
+        log.info('Exposing %s', application)
         return await self.model.applications[application].expose()
 
     async def setAnnotations(self, id_, entity_type, annotations):
         return await self.model.applications[application].expose()
 
     async def setAnnotations(self, id_, entity_type, annotations):
@@ -1325,7 +1413,10 @@ class BundleHandler(object):
             Annotations holds the annotations as key/value pairs.
         """
         entity_id = self.resolve(id_)
             Annotations holds the annotations as key/value pairs.
         """
         entity_id = self.resolve(id_)
-        entity = self.model.state.get_entity(entity_type, entity_id)
+        try:
+            entity = self.model.state.get_entity(entity_type, entity_id)
+        except KeyError:
+            entity = await self.model._wait_for_new(entity_type, entity_id)
         return await entity.set_annotations(annotations)
 
 
         return await entity.set_annotations(annotations)