Fixed docstring for addMachines.
[osm/N2VC.git] / juju / model.py
index 5d436fd..8d6c666 100644 (file)
@@ -6,11 +6,13 @@ import weakref
 from concurrent.futures import CancelledError
 from functools import partial
 
 from concurrent.futures import CancelledError
 from functools import partial
 
+import yaml
 from theblues import charmstore
 
 from .client import client
 from .client import watcher
 from .client import connection
 from theblues import charmstore
 
 from .client import client
 from .client import watcher
 from .client import connection
+from .constraints import parse as parse_constraints
 from .delta import get_entity_delta
 from .delta import get_entity_class
 from .exceptions import DeadEntityException
 from .delta import get_entity_delta
 from .delta import get_entity_class
 from .exceptions import DeadEntityException
@@ -26,12 +28,14 @@ class _Observer(object):
     callable so that it's only called for changes that meet the criteria.
 
     """
     callable so that it's only called for changes that meet the criteria.
 
     """
-    def __init__(self, callable_, entity_type, action, entity_id):
+    def __init__(self, callable_, entity_type, action, entity_id, predicate):
         self.callable_ = callable_
         self.entity_type = entity_type
         self.action = action
         self.entity_id = entity_id
         self.callable_ = callable_
         self.entity_type = entity_type
         self.action = action
         self.entity_id = entity_id
+        self.predicate = predicate
         if self.entity_id:
         if self.entity_id:
+            self.entity_id = str(self.entity_id)
             if not self.entity_id.startswith('^'):
                 self.entity_id = '^' + self.entity_id
             if not self.entity_id.endswith('$'):
             if not self.entity_id.startswith('^'):
                 self.entity_id = '^' + self.entity_id
             if not self.entity_id.endswith('$'):
@@ -40,20 +44,22 @@ class _Observer(object):
     async def __call__(self, delta, old, new, model):
         await self.callable_(delta, old, new, model)
 
     async def __call__(self, delta, old, new, model):
         await self.callable_(delta, old, new, model)
 
-    def cares_about(self, entity_type, action, entity_id):
+    def cares_about(self, delta):
         """Return True if this observer "cares about" (i.e. wants to be
         """Return True if this observer "cares about" (i.e. wants to be
-        called) for a change matching the entity_type, action, and
-        entity_id parameters.
+        called) for a this delta.
 
         """
 
         """
-        if (self.entity_id and entity_id and
-                not re.match(self.entity_id, str(entity_id))):
+        if (self.entity_id and delta.get_id() and
+                not re.match(self.entity_id, str(delta.get_id()))):
             return False
 
             return False
 
-        if self.entity_type and self.entity_type != entity_type:
+        if self.entity_type and self.entity_type != delta.entity:
             return False
 
             return False
 
-        if self.action and self.action != action:
+        if self.action and self.action != delta.type:
+            return False
+
+        if self.predicate and not self.predicate(delta):
             return False
 
         return True
             return False
 
         return True
@@ -78,9 +84,6 @@ class ModelState(object):
         self.model = model
         self.state = dict()
 
         self.model = model
         self.state = dict()
 
-    def clear(self):
-        self.state.clear()
-
     def _live_entity_map(self, entity_type):
         """Return an id:Entity map of all the living entities of
         type ``entity_type``.
     def _live_entity_map(self, entity_type):
         """Return an id:Entity map of all the living entities of
         type ``entity_type``.
@@ -157,14 +160,12 @@ class ModelState(object):
 
     def get_entity(
             self, entity_type, entity_id, history_index=-1, connected=True):
 
     def get_entity(
             self, entity_type, entity_id, history_index=-1, connected=True):
-        """Return an object instance representing the entity created or
-        updated by ``delta``
+        """Return an object instance for the given entity_type and id.
+
+        By default the object state matches the most recent state from
+        Juju. To get an instance of the object in an older state, pass
+        history_index, an index into the history deque for the entity.
 
 
-        """
-        """
-        log.debug(
-            'Getting %s:%s at index %s',
-            entity_type, entity_id, history_index)
         """
 
         if history_index < 0 and history_index != -1:
         """
 
         if history_index < 0 and history_index != -1:
@@ -349,6 +350,16 @@ class Model(object):
         self._watch_received = asyncio.Event(loop=loop)
         self._charmstore = CharmStore(self.loop)
 
         self._watch_received = asyncio.Event(loop=loop)
         self._charmstore = CharmStore(self.loop)
 
+    async def connect(self, *args, **kw):
+        """Connect to an arbitrary Juju model.
+
+        args and kw are passed through to Connection.connect()
+
+        """
+        self.connection = await connection.Connection.connect(*args, **kw)
+        self._watch()
+        await self._watch_received.wait()
+
     async def connect_current(self):
         """Connect to the current Juju model.
 
     async def connect_current(self):
         """Connect to the current Juju model.
 
@@ -357,6 +368,15 @@ class Model(object):
         self._watch()
         await self._watch_received.wait()
 
         self._watch()
         await self._watch_received.wait()
 
+    async def connect_model(self, arg):
+        """Connect to a specific Juju model.
+        :param arg:  <controller>:<user/model>
+
+        """
+        self.connection = await connection.Connection.connect_model(arg)
+        self._watch()
+        await self._watch_received.wait()
+
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
@@ -395,7 +415,6 @@ class Model(object):
         await self.block_until(
             lambda: len(self.machines) == 0
         )
         await self.block_until(
             lambda: len(self.machines) == 0
         )
-        self.state.clear()
 
     async def block_until(self, *conditions, timeout=None):
         """Return only after all conditions are true.
 
     async def block_until(self, *conditions, timeout=None):
         """Return only after all conditions are true.
@@ -431,7 +450,8 @@ class Model(object):
         return self.state.units
 
     def add_observer(
         return self.state.units
 
     def add_observer(
-            self, callable_, entity_type=None, action=None, entity_id=None):
+            self, callable_, entity_type=None, action=None, entity_id=None,
+            predicate=None):
         """Register an "on-model-change" callback
 
         Once the model is connected, ``callable_``
         """Register an "on-model-change" callback
 
         Once the model is connected, ``callable_``
@@ -459,8 +479,13 @@ class Model(object):
             add_observer(
                 myfunc, entity_type='application', action='add', id_='ubuntu')
 
             add_observer(
                 myfunc, entity_type='application', action='add', id_='ubuntu')
 
+        For more complex filtering conditions, pass a predicate function. It
+        will be called with a delta as its only argument. If the predicate
+        function returns True, the callable_ will be called.
+
         """
         """
-        observer = _Observer(callable_, entity_type, action, entity_id)
+        observer = _Observer(
+            callable_, entity_type, action, entity_id, predicate)
         self.observers[observer] = callable_
 
     def _watch(self):
         self.observers[observer] = callable_
 
     def _watch(self):
@@ -517,7 +542,7 @@ class Model(object):
             by applying this delta.
 
         """
             by applying this delta.
 
         """
-        if not old_obj:
+        if new_obj and not old_obj:
             delta.type = 'add'
 
         log.debug(
             delta.type = 'add'
 
         log.debug(
@@ -525,10 +550,34 @@ class Model(object):
             delta.entity, delta.type, delta.get_id())
 
         for o in self.observers:
             delta.entity, delta.type, delta.get_id())
 
         for o in self.observers:
-            if o.cares_about(delta.entity, delta.type, delta.get_id()):
+            if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
-    async def _wait_for_new(self, entity_type, entity_id):
+    async def _wait(self, entity_type, entity_id, action, predicate=None):
+        """
+        Block the calling routine until a given action has happened to the
+        given entity
+
+        :param entity_type: The entity's type.
+        :param entity_id: The entity's id.
+        :param action: the type of action (e.g., 'add' or 'change')
+        :param predicate: optional callable that must take as an
+            argument a delta, and must return a boolean, indicating
+            whether the delta contains the specific action we're looking
+            for. For example, you might check to see whether a 'change'
+            has a 'completed' status. See the _Observer class for details.
+
+        """
+        q = asyncio.Queue(loop=self.loop)
+
+        async def callback(delta, old, new, model):
+            await q.put(delta.get_id())
+
+        self.add_observer(callback, entity_type, action, entity_id, predicate)
+        entity_id = await q.get()
+        return self.state._live_entity_map(entity_type)[entity_id]
+
+    async def _wait_for_new(self, entity_type, entity_id, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
@@ -536,13 +585,20 @@ class Model(object):
         This coroutine blocks until the new object appears in the model.
 
         """
         This coroutine blocks until the new object appears in the model.
 
         """
-        entity_added = asyncio.Event(loop=self.loop)
+        return await self._wait(entity_type, entity_id, 'add', predicate)
 
 
-        async def callback(delta, old, new, model):
-            entity_added.set()
-        self.add_observer(callback, entity_type, 'add', entity_id)
-        await entity_added.wait()
-        return self.state._live_entity_map(entity_type)[entity_id]
+    async def wait_for_action(self, action_id):
+        """Given an action, wait for it to complete."""
+
+        if action_id.startswith("action-"):
+            # if we've been passed action.tag, transform it into the
+            # id that the api deltas will use.
+            action_id = action_id[7:]
+
+        def predicate(delta):
+            return delta.data['status'] in ('completed', 'failed')
+
+        return await self._wait('action', action_id, 'change', predicate)
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
@@ -587,7 +643,24 @@ class Model(object):
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
 
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
 
-        return await app_facade.AddRelation([relation1, relation2])
+        try:
+            result = await app_facade.AddRelation([relation1, relation2])
+        except JujuAPIError as e:
+            if 'relation already exists' not in e.message:
+                raise
+            log.debug(
+                'Relation %s <-> %s already exists', relation1, relation2)
+            # TODO: if relation already exists we should return the
+            # Relation ModelEntity here
+            return None
+
+        def predicate(delta):
+            endpoints = {}
+            for endpoint in delta.data['endpoints']:
+                endpoints[endpoint['application-name']] = endpoint['relation']
+            return endpoints == result.endpoints
+
+        return await self._wait_for_new('relation', None, predicate)
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
@@ -729,16 +802,11 @@ class Model(object):
 
         TODO::
 
 
         TODO::
 
-            - entity_url must have a revision; look up latest automatically if
-              not provided by caller
             - service_name is required; fill this in automatically if not
               provided by caller
             - series is required; how do we pick a default?
 
         """
             - service_name is required; fill this in automatically if not
               provided by caller
             - series is required; how do we pick a default?
 
         """
-        if constraints:
-            constraints = client.Value(**constraints)
-
         if to:
             placement = [
                 client.Placement(**p) for p in to
         if to:
             placement = [
                 client.Placement(**p) for p in to
@@ -763,6 +831,18 @@ class Model(object):
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
+            extant_apps = {app for app in self.applications}
+            pending_apps = set(handler.applications) - extant_apps
+            if pending_apps:
+                # new apps will usually be in the model by now, but if some
+                # haven't made it yet we'll need to wait on them to be added
+                await asyncio.gather(*[
+                    asyncio.ensure_future(
+                        self.model._wait_for_new('application', app_name))
+                    for app_name in pending_apps
+                ])
+            return [app for name, app in self.applications.items()
+                    if name in handler.applications]
         else:
             log.debug(
                 'Deploying %s', entity_id)
         else:
             log.debug(
                 'Deploying %s', entity_id)
@@ -791,6 +871,21 @@ class Model(object):
         """
         pass
 
         """
         pass
 
+    async def destroy_unit(self, *unit_names):
+        """Destroy units by name.
+
+        """
+        app_facade = client.ApplicationFacade()
+        app_facade.connect(self.connection)
+
+        log.debug(
+            'Destroying unit%s %s',
+            's' if len(unit_names) == 1 else '',
+            ' '.join(unit_names))
+
+        return await app_facade.Destroy(self.name)
+    destroy_units = destroy_unit
+
     def get_backup(self, archive_id):
         """Download a backup archive file.
 
     def get_backup(self, archive_id):
         """Download a backup archive file.
 
@@ -1126,10 +1221,11 @@ class BundleHandler(object):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
-        yaml = await self.charmstore.files(entity_id,
-                                           filename='bundle.yaml',
-                                           read_file=True)
-        self.plan = await self.client_facade.GetBundleChanges(yaml)
+        bundle_yaml = await self.charmstore.files(entity_id,
+                                                  filename='bundle.yaml',
+                                                  read_file=True)
+        self.bundle = yaml.safe_load(bundle_yaml)
+        self.plan = await self.client_facade.GetBundleChanges(bundle_yaml)
 
     async def execute_plan(self):
         for step in self.plan.changes:
 
     async def execute_plan(self):
         for step in self.plan.changes:
@@ -1137,6 +1233,10 @@ class BundleHandler(object):
             result = await method(*step.args)
             self.references[step.id_] = result
 
             result = await method(*step.args)
             self.references[step.id_] = result
 
+    @property
+    def applications(self):
+        return list(self.bundle['services'].keys())
+
     def resolve(self, reference):
         if reference and reference.startswith('$'):
             reference = self.references[reference[1:]]
     def resolve(self, reference):
         if reference and reference.startswith('$'):
             reference = self.references[reference[1:]]
@@ -1156,35 +1256,41 @@ class BundleHandler(object):
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
-    async def addMachines(self, series, constraints, container_type,
-                          parent_id):
-        """
-        :param series string:
-            Series holds the optional machine OS series.
-
-        :param constraints string:
-            Constraints holds the optional machine constraints.
-
-        :param Container_type string:
-            ContainerType optionally holds the type of the container (for
-            instance ""lxc" or kvm"). It is not specified for top level
-            machines.
-
-        :param parent_id string:
-            ParentId optionally holds a placeholder pointing to another machine
-            change or to a unit change. This value is only specified in the
-            case this machine is a container, in which case also ContainerType
-            is set.
-        """
-        params = client.AddMachineParams(
-            series=series,
-            constraints=constraints,
-            container_type=container_type,
-            parent_id=self.resolve(parent_id),
-        )
-        results = await self.client_facade.AddMachines(params)
-        log.debug('Added new machine %s', results[0].machine)
-        return results[0].machine
+    async def addMachines(self, params=None):
+        """
+        :param params dict:
+            Dictionary specifying the machine to add. All keys are optional.
+            Keys include:
+
+            series: string specifying the machine OS series.
+            constraints: string holding machine constraints, if any. We'll
+                parse this into the json friendly dict that the juju api
+                expects.
+            container_type: string holding the type of the container (for
+                instance ""lxc" or kvm"). It is not specified for top level
+                machines.
+            parent_id: string holding a placeholder pointing to another
+                machine change or to a unit change. This value is only
+                specified in the case this machine is a container, in
+                which case also ContainerType is set.
+        """
+        params = params or {}
+
+        if 'parent_id' in params:
+            params['parent_id'] = self.resolve(params['parent_id'])
+
+        params['constraints'] = parse_constraints(
+            params.get('constraints'))
+        params['jobs'] = params.get('jobs', ['JobHostUnits'])
+
+        params = client.AddMachineParams(**params)
+        results = await self.client_facade.AddMachines([params])
+        error = results.machines[0].error
+        if error:
+            raise ValueError("Error adding machine: %s", error.message)
+        machine = results.machines[0].machine
+        log.debug('Added new machine %s', machine)
+        return machine
 
     async def addRelation(self, endpoint1, endpoint2):
         """
 
     async def addRelation(self, endpoint1, endpoint2):
         """
@@ -1201,14 +1307,9 @@ class BundleHandler(object):
             parts = endpoints[i].split(':')
             parts[0] = self.resolve(parts[0])
             endpoints[i] = ':'.join(parts)
             parts = endpoints[i].split(':')
             parts[0] = self.resolve(parts[0])
             endpoints[i] = ':'.join(parts)
-        try:
-            await self.app_facade.AddRelation(endpoints)
-            log.debug('Added relation %s <-> %s', *endpoints)
-        except JujuAPIError as e:
-            if 'relation already exists' not in e.message:
-                raise
-            log.debug('Relation %s <-> %s already exists', *endpoints)
-        return None
+
+        log.info('Relating %s <-> %s', *endpoints)
+        return await self.model.add_relation(*endpoints)
 
     async def deploy(self, charm, series, application, options, constraints,
                      storage, endpoint_bindings, resources):
 
     async def deploy(self, charm, series, application, options, constraints,
                      storage, endpoint_bindings, resources):
@@ -1256,7 +1357,7 @@ class BundleHandler(object):
             resources=resources,
         )
         # do the do
             resources=resources,
         )
         # do the do
-        log.debug('Deploying %s', charm)
+        log.info('Deploying %s', charm)
         await self.app_facade.Deploy([app])
         return application
 
         await self.app_facade.Deploy([app])
         return application
 
@@ -1279,16 +1380,14 @@ class BundleHandler(object):
             # doesn't, so we're not bothering, either
             unit_name = self._units_by_app[application].pop()
             log.debug('Reusing unit %s for %s', unit_name, application)
             # doesn't, so we're not bothering, either
             unit_name = self._units_by_app[application].pop()
             log.debug('Reusing unit %s for %s', unit_name, application)
-            return unit_name
-        log.debug('Adding unit of %s%s',
-                  application,
-                  (' to %s' % placement) if placement else '')
-        result = await self.app_facade.AddUnits(
-            application=application,
-            placement=placement,
-            num_units=1,
+            return self.model.units[unit_name]
+
+        log.debug('Adding new unit for %s%s', application,
+                  ' to %s' % placement if placement else '')
+        return await self.model.applications[application].add_unit(
+            count=1,
+            to=placement,
         )
         )
-        return result.units[0]
 
     async def expose(self, application):
         """
 
     async def expose(self, application):
         """
@@ -1297,9 +1396,8 @@ class BundleHandler(object):
             be exposed.
         """
         application = self.resolve(application)
             be exposed.
         """
         application = self.resolve(application)
-        log.debug('Exposing %s', application)
-        await self.app_facade.Expose(application)
-        return None
+        log.info('Exposing %s', application)
+        return await self.model.applications[application].expose()
 
     async def setAnnotations(self, id_, entity_type, annotations):
         """
 
     async def setAnnotations(self, id_, entity_type, annotations):
         """
@@ -1315,13 +1413,11 @@ class BundleHandler(object):
             Annotations holds the annotations as key/value pairs.
         """
         entity_id = self.resolve(id_)
             Annotations holds the annotations as key/value pairs.
         """
         entity_id = self.resolve(id_)
-        log.debug('Updating annotations of %s', entity_id)
-        ann = client.EntityAnnotations(
-            entity=entity_id,
-            annotations=annotations,
-        )
-        await self.ann_facade.Set([ann])
-        return None
+        try:
+            entity = self.model.state.get_entity(entity_type, entity_id)
+        except KeyError:
+            entity = await self.model._wait_for_new(entity_type, entity_id)
+        return await entity.set_annotations(annotations)
 
 
 class CharmStore(object):
 
 
 class CharmStore(object):