Fixed docstring for addMachines.
[osm/N2VC.git] / juju / model.py
index 5d436fd..8d6c666 100644 (file)
@@ -6,11 +6,13 @@ import weakref
 from concurrent.futures import CancelledError
 from functools import partial
 
+import yaml
 from theblues import charmstore
 
 from .client import client
 from .client import watcher
 from .client import connection
+from .constraints import parse as parse_constraints
 from .delta import get_entity_delta
 from .delta import get_entity_class
 from .exceptions import DeadEntityException
@@ -26,12 +28,14 @@ class _Observer(object):
     callable so that it's only called for changes that meet the criteria.
 
     """
-    def __init__(self, callable_, entity_type, action, entity_id):
+    def __init__(self, callable_, entity_type, action, entity_id, predicate):
         self.callable_ = callable_
         self.entity_type = entity_type
         self.action = action
         self.entity_id = entity_id
+        self.predicate = predicate
         if self.entity_id:
+            self.entity_id = str(self.entity_id)
             if not self.entity_id.startswith('^'):
                 self.entity_id = '^' + self.entity_id
             if not self.entity_id.endswith('$'):
@@ -40,20 +44,22 @@ class _Observer(object):
     async def __call__(self, delta, old, new, model):
         await self.callable_(delta, old, new, model)
 
-    def cares_about(self, entity_type, action, entity_id):
+    def cares_about(self, delta):
         """Return True if this observer "cares about" (i.e. wants to be
-        called) for a change matching the entity_type, action, and
-        entity_id parameters.
+        called) for a this delta.
 
         """
-        if (self.entity_id and entity_id and
-                not re.match(self.entity_id, str(entity_id))):
+        if (self.entity_id and delta.get_id() and
+                not re.match(self.entity_id, str(delta.get_id()))):
             return False
 
-        if self.entity_type and self.entity_type != entity_type:
+        if self.entity_type and self.entity_type != delta.entity:
             return False
 
-        if self.action and self.action != action:
+        if self.action and self.action != delta.type:
+            return False
+
+        if self.predicate and not self.predicate(delta):
             return False
 
         return True
@@ -78,9 +84,6 @@ class ModelState(object):
         self.model = model
         self.state = dict()
 
-    def clear(self):
-        self.state.clear()
-
     def _live_entity_map(self, entity_type):
         """Return an id:Entity map of all the living entities of
         type ``entity_type``.
@@ -157,14 +160,12 @@ class ModelState(object):
 
     def get_entity(
             self, entity_type, entity_id, history_index=-1, connected=True):
-        """Return an object instance representing the entity created or
-        updated by ``delta``
+        """Return an object instance for the given entity_type and id.
+
+        By default the object state matches the most recent state from
+        Juju. To get an instance of the object in an older state, pass
+        history_index, an index into the history deque for the entity.
 
-        """
-        """
-        log.debug(
-            'Getting %s:%s at index %s',
-            entity_type, entity_id, history_index)
         """
 
         if history_index < 0 and history_index != -1:
@@ -349,6 +350,16 @@ class Model(object):
         self._watch_received = asyncio.Event(loop=loop)
         self._charmstore = CharmStore(self.loop)
 
+    async def connect(self, *args, **kw):
+        """Connect to an arbitrary Juju model.
+
+        args and kw are passed through to Connection.connect()
+
+        """
+        self.connection = await connection.Connection.connect(*args, **kw)
+        self._watch()
+        await self._watch_received.wait()
+
     async def connect_current(self):
         """Connect to the current Juju model.
 
@@ -357,6 +368,15 @@ class Model(object):
         self._watch()
         await self._watch_received.wait()
 
+    async def connect_model(self, arg):
+        """Connect to a specific Juju model.
+        :param arg:  <controller>:<user/model>
+
+        """
+        self.connection = await connection.Connection.connect_model(arg)
+        self._watch()
+        await self._watch_received.wait()
+
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
@@ -395,7 +415,6 @@ class Model(object):
         await self.block_until(
             lambda: len(self.machines) == 0
         )
-        self.state.clear()
 
     async def block_until(self, *conditions, timeout=None):
         """Return only after all conditions are true.
@@ -431,7 +450,8 @@ class Model(object):
         return self.state.units
 
     def add_observer(
-            self, callable_, entity_type=None, action=None, entity_id=None):
+            self, callable_, entity_type=None, action=None, entity_id=None,
+            predicate=None):
         """Register an "on-model-change" callback
 
         Once the model is connected, ``callable_``
@@ -459,8 +479,13 @@ class Model(object):
             add_observer(
                 myfunc, entity_type='application', action='add', id_='ubuntu')
 
+        For more complex filtering conditions, pass a predicate function. It
+        will be called with a delta as its only argument. If the predicate
+        function returns True, the callable_ will be called.
+
         """
-        observer = _Observer(callable_, entity_type, action, entity_id)
+        observer = _Observer(
+            callable_, entity_type, action, entity_id, predicate)
         self.observers[observer] = callable_
 
     def _watch(self):
@@ -517,7 +542,7 @@ class Model(object):
             by applying this delta.
 
         """
-        if not old_obj:
+        if new_obj and not old_obj:
             delta.type = 'add'
 
         log.debug(
@@ -525,10 +550,34 @@ class Model(object):
             delta.entity, delta.type, delta.get_id())
 
         for o in self.observers:
-            if o.cares_about(delta.entity, delta.type, delta.get_id()):
+            if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
-    async def _wait_for_new(self, entity_type, entity_id):
+    async def _wait(self, entity_type, entity_id, action, predicate=None):
+        """
+        Block the calling routine until a given action has happened to the
+        given entity
+
+        :param entity_type: The entity's type.
+        :param entity_id: The entity's id.
+        :param action: the type of action (e.g., 'add' or 'change')
+        :param predicate: optional callable that must take as an
+            argument a delta, and must return a boolean, indicating
+            whether the delta contains the specific action we're looking
+            for. For example, you might check to see whether a 'change'
+            has a 'completed' status. See the _Observer class for details.
+
+        """
+        q = asyncio.Queue(loop=self.loop)
+
+        async def callback(delta, old, new, model):
+            await q.put(delta.get_id())
+
+        self.add_observer(callback, entity_type, action, entity_id, predicate)
+        entity_id = await q.get()
+        return self.state._live_entity_map(entity_type)[entity_id]
+
+    async def _wait_for_new(self, entity_type, entity_id, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
@@ -536,13 +585,20 @@ class Model(object):
         This coroutine blocks until the new object appears in the model.
 
         """
-        entity_added = asyncio.Event(loop=self.loop)
+        return await self._wait(entity_type, entity_id, 'add', predicate)
 
-        async def callback(delta, old, new, model):
-            entity_added.set()
-        self.add_observer(callback, entity_type, 'add', entity_id)
-        await entity_added.wait()
-        return self.state._live_entity_map(entity_type)[entity_id]
+    async def wait_for_action(self, action_id):
+        """Given an action, wait for it to complete."""
+
+        if action_id.startswith("action-"):
+            # if we've been passed action.tag, transform it into the
+            # id that the api deltas will use.
+            action_id = action_id[7:]
+
+        def predicate(delta):
+            return delta.data['status'] in ('completed', 'failed')
+
+        return await self._wait('action', action_id, 'change', predicate)
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
@@ -587,7 +643,24 @@ class Model(object):
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
 
-        return await app_facade.AddRelation([relation1, relation2])
+        try:
+            result = await app_facade.AddRelation([relation1, relation2])
+        except JujuAPIError as e:
+            if 'relation already exists' not in e.message:
+                raise
+            log.debug(
+                'Relation %s <-> %s already exists', relation1, relation2)
+            # TODO: if relation already exists we should return the
+            # Relation ModelEntity here
+            return None
+
+        def predicate(delta):
+            endpoints = {}
+            for endpoint in delta.data['endpoints']:
+                endpoints[endpoint['application-name']] = endpoint['relation']
+            return endpoints == result.endpoints
+
+        return await self._wait_for_new('relation', None, predicate)
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
@@ -729,16 +802,11 @@ class Model(object):
 
         TODO::
 
-            - entity_url must have a revision; look up latest automatically if
-              not provided by caller
             - service_name is required; fill this in automatically if not
               provided by caller
             - series is required; how do we pick a default?
 
         """
-        if constraints:
-            constraints = client.Value(**constraints)
-
         if to:
             placement = [
                 client.Placement(**p) for p in to
@@ -763,6 +831,18 @@ class Model(object):
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
+            extant_apps = {app for app in self.applications}
+            pending_apps = set(handler.applications) - extant_apps
+            if pending_apps:
+                # new apps will usually be in the model by now, but if some
+                # haven't made it yet we'll need to wait on them to be added
+                await asyncio.gather(*[
+                    asyncio.ensure_future(
+                        self.model._wait_for_new('application', app_name))
+                    for app_name in pending_apps
+                ])
+            return [app for name, app in self.applications.items()
+                    if name in handler.applications]
         else:
             log.debug(
                 'Deploying %s', entity_id)
@@ -791,6 +871,21 @@ class Model(object):
         """
         pass
 
+    async def destroy_unit(self, *unit_names):
+        """Destroy units by name.
+
+        """
+        app_facade = client.ApplicationFacade()
+        app_facade.connect(self.connection)
+
+        log.debug(
+            'Destroying unit%s %s',
+            's' if len(unit_names) == 1 else '',
+            ' '.join(unit_names))
+
+        return await app_facade.Destroy(self.name)
+    destroy_units = destroy_unit
+
     def get_backup(self, archive_id):
         """Download a backup archive file.
 
@@ -1126,10 +1221,11 @@ class BundleHandler(object):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
-        yaml = await self.charmstore.files(entity_id,
-                                           filename='bundle.yaml',
-                                           read_file=True)
-        self.plan = await self.client_facade.GetBundleChanges(yaml)
+        bundle_yaml = await self.charmstore.files(entity_id,
+                                                  filename='bundle.yaml',
+                                                  read_file=True)
+        self.bundle = yaml.safe_load(bundle_yaml)
+        self.plan = await self.client_facade.GetBundleChanges(bundle_yaml)
 
     async def execute_plan(self):
         for step in self.plan.changes:
@@ -1137,6 +1233,10 @@ class BundleHandler(object):
             result = await method(*step.args)
             self.references[step.id_] = result
 
+    @property
+    def applications(self):
+        return list(self.bundle['services'].keys())
+
     def resolve(self, reference):
         if reference and reference.startswith('$'):
             reference = self.references[reference[1:]]
@@ -1156,35 +1256,41 @@ class BundleHandler(object):
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
-    async def addMachines(self, series, constraints, container_type,
-                          parent_id):
-        """
-        :param series string:
-            Series holds the optional machine OS series.
-
-        :param constraints string:
-            Constraints holds the optional machine constraints.
-
-        :param Container_type string:
-            ContainerType optionally holds the type of the container (for
-            instance ""lxc" or kvm"). It is not specified for top level
-            machines.
-
-        :param parent_id string:
-            ParentId optionally holds a placeholder pointing to another machine
-            change or to a unit change. This value is only specified in the
-            case this machine is a container, in which case also ContainerType
-            is set.
-        """
-        params = client.AddMachineParams(
-            series=series,
-            constraints=constraints,
-            container_type=container_type,
-            parent_id=self.resolve(parent_id),
-        )
-        results = await self.client_facade.AddMachines(params)
-        log.debug('Added new machine %s', results[0].machine)
-        return results[0].machine
+    async def addMachines(self, params=None):
+        """
+        :param params dict:
+            Dictionary specifying the machine to add. All keys are optional.
+            Keys include:
+
+            series: string specifying the machine OS series.
+            constraints: string holding machine constraints, if any. We'll
+                parse this into the json friendly dict that the juju api
+                expects.
+            container_type: string holding the type of the container (for
+                instance ""lxc" or kvm"). It is not specified for top level
+                machines.
+            parent_id: string holding a placeholder pointing to another
+                machine change or to a unit change. This value is only
+                specified in the case this machine is a container, in
+                which case also ContainerType is set.
+        """
+        params = params or {}
+
+        if 'parent_id' in params:
+            params['parent_id'] = self.resolve(params['parent_id'])
+
+        params['constraints'] = parse_constraints(
+            params.get('constraints'))
+        params['jobs'] = params.get('jobs', ['JobHostUnits'])
+
+        params = client.AddMachineParams(**params)
+        results = await self.client_facade.AddMachines([params])
+        error = results.machines[0].error
+        if error:
+            raise ValueError("Error adding machine: %s", error.message)
+        machine = results.machines[0].machine
+        log.debug('Added new machine %s', machine)
+        return machine
 
     async def addRelation(self, endpoint1, endpoint2):
         """
@@ -1201,14 +1307,9 @@ class BundleHandler(object):
             parts = endpoints[i].split(':')
             parts[0] = self.resolve(parts[0])
             endpoints[i] = ':'.join(parts)
-        try:
-            await self.app_facade.AddRelation(endpoints)
-            log.debug('Added relation %s <-> %s', *endpoints)
-        except JujuAPIError as e:
-            if 'relation already exists' not in e.message:
-                raise
-            log.debug('Relation %s <-> %s already exists', *endpoints)
-        return None
+
+        log.info('Relating %s <-> %s', *endpoints)
+        return await self.model.add_relation(*endpoints)
 
     async def deploy(self, charm, series, application, options, constraints,
                      storage, endpoint_bindings, resources):
@@ -1256,7 +1357,7 @@ class BundleHandler(object):
             resources=resources,
         )
         # do the do
-        log.debug('Deploying %s', charm)
+        log.info('Deploying %s', charm)
         await self.app_facade.Deploy([app])
         return application
 
@@ -1279,16 +1380,14 @@ class BundleHandler(object):
             # doesn't, so we're not bothering, either
             unit_name = self._units_by_app[application].pop()
             log.debug('Reusing unit %s for %s', unit_name, application)
-            return unit_name
-        log.debug('Adding unit of %s%s',
-                  application,
-                  (' to %s' % placement) if placement else '')
-        result = await self.app_facade.AddUnits(
-            application=application,
-            placement=placement,
-            num_units=1,
+            return self.model.units[unit_name]
+
+        log.debug('Adding new unit for %s%s', application,
+                  ' to %s' % placement if placement else '')
+        return await self.model.applications[application].add_unit(
+            count=1,
+            to=placement,
         )
-        return result.units[0]
 
     async def expose(self, application):
         """
@@ -1297,9 +1396,8 @@ class BundleHandler(object):
             be exposed.
         """
         application = self.resolve(application)
-        log.debug('Exposing %s', application)
-        await self.app_facade.Expose(application)
-        return None
+        log.info('Exposing %s', application)
+        return await self.model.applications[application].expose()
 
     async def setAnnotations(self, id_, entity_type, annotations):
         """
@@ -1315,13 +1413,11 @@ class BundleHandler(object):
             Annotations holds the annotations as key/value pairs.
         """
         entity_id = self.resolve(id_)
-        log.debug('Updating annotations of %s', entity_id)
-        ann = client.EntityAnnotations(
-            entity=entity_id,
-            annotations=annotations,
-        )
-        await self.ann_facade.Set([ann])
-        return None
+        try:
+            entity = self.model.state.get_entity(entity_type, entity_id)
+        except KeyError:
+            entity = await self.model._wait_for_new(entity_type, entity_id)
+        return await entity.set_annotations(annotations)
 
 
 class CharmStore(object):