Start getting docs updated
[osm/N2VC.git] / juju / model.py
index 5d436fd..0ba79df 100644 (file)
@@ -1,16 +1,20 @@
 import asyncio
 import collections
 import logging
 import asyncio
 import collections
 import logging
+import os
 import re
 import weakref
 from concurrent.futures import CancelledError
 from functools import partial
 import re
 import weakref
 from concurrent.futures import CancelledError
 from functools import partial
+from pathlib import Path
 
 
+import yaml
 from theblues import charmstore
 
 from .client import client
 from .client import watcher
 from .client import connection
 from theblues import charmstore
 
 from .client import client
 from .client import watcher
 from .client import connection
+from .constraints import parse as parse_constraints
 from .delta import get_entity_delta
 from .delta import get_entity_class
 from .exceptions import DeadEntityException
 from .delta import get_entity_delta
 from .delta import get_entity_class
 from .exceptions import DeadEntityException
@@ -26,12 +30,14 @@ class _Observer(object):
     callable so that it's only called for changes that meet the criteria.
 
     """
     callable so that it's only called for changes that meet the criteria.
 
     """
-    def __init__(self, callable_, entity_type, action, entity_id):
+    def __init__(self, callable_, entity_type, action, entity_id, predicate):
         self.callable_ = callable_
         self.entity_type = entity_type
         self.action = action
         self.entity_id = entity_id
         self.callable_ = callable_
         self.entity_type = entity_type
         self.action = action
         self.entity_id = entity_id
+        self.predicate = predicate
         if self.entity_id:
         if self.entity_id:
+            self.entity_id = str(self.entity_id)
             if not self.entity_id.startswith('^'):
                 self.entity_id = '^' + self.entity_id
             if not self.entity_id.endswith('$'):
             if not self.entity_id.startswith('^'):
                 self.entity_id = '^' + self.entity_id
             if not self.entity_id.endswith('$'):
@@ -40,20 +46,22 @@ class _Observer(object):
     async def __call__(self, delta, old, new, model):
         await self.callable_(delta, old, new, model)
 
     async def __call__(self, delta, old, new, model):
         await self.callable_(delta, old, new, model)
 
-    def cares_about(self, entity_type, action, entity_id):
+    def cares_about(self, delta):
         """Return True if this observer "cares about" (i.e. wants to be
         """Return True if this observer "cares about" (i.e. wants to be
-        called) for a change matching the entity_type, action, and
-        entity_id parameters.
+        called) for a this delta.
 
         """
 
         """
-        if (self.entity_id and entity_id and
-                not re.match(self.entity_id, str(entity_id))):
+        if (self.entity_id and delta.get_id() and
+                not re.match(self.entity_id, str(delta.get_id()))):
             return False
 
             return False
 
-        if self.entity_type and self.entity_type != entity_type:
+        if self.entity_type and self.entity_type != delta.entity:
             return False
 
             return False
 
-        if self.action and self.action != action:
+        if self.action and self.action != delta.type:
+            return False
+
+        if self.predicate and not self.predicate(delta):
             return False
 
         return True
             return False
 
         return True
@@ -78,9 +86,6 @@ class ModelState(object):
         self.model = model
         self.state = dict()
 
         self.model = model
         self.state = dict()
 
-    def clear(self):
-        self.state.clear()
-
     def _live_entity_map(self, entity_type):
         """Return an id:Entity map of all the living entities of
         type ``entity_type``.
     def _live_entity_map(self, entity_type):
         """Return an id:Entity map of all the living entities of
         type ``entity_type``.
@@ -157,14 +162,12 @@ class ModelState(object):
 
     def get_entity(
             self, entity_type, entity_id, history_index=-1, connected=True):
 
     def get_entity(
             self, entity_type, entity_id, history_index=-1, connected=True):
-        """Return an object instance representing the entity created or
-        updated by ``delta``
+        """Return an object instance for the given entity_type and id.
+
+        By default the object state matches the most recent state from
+        Juju. To get an instance of the object in an older state, pass
+        history_index, an index into the history deque for the entity.
 
 
-        """
-        """
-        log.debug(
-            'Getting %s:%s at index %s',
-            entity_type, entity_id, history_index)
         """
 
         if history_index < 0 and history_index != -1:
         """
 
         if history_index < 0 and history_index != -1:
@@ -344,18 +347,44 @@ class Model(object):
         self.connection = None
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
         self.connection = None
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
+        self.info = None
         self._watcher_task = None
         self._watch_shutdown = asyncio.Event(loop=loop)
         self._watch_received = asyncio.Event(loop=loop)
         self._charmstore = CharmStore(self.loop)
 
         self._watcher_task = None
         self._watch_shutdown = asyncio.Event(loop=loop)
         self._watch_received = asyncio.Event(loop=loop)
         self._charmstore = CharmStore(self.loop)
 
+    async def connect(self, *args, **kw):
+        """Connect to an arbitrary Juju model.
+
+        args and kw are passed through to Connection.connect()
+
+        """
+        self.connection = await connection.Connection.connect(*args, **kw)
+        await self._after_connect()
+
     async def connect_current(self):
         """Connect to the current Juju model.
 
         """
         self.connection = await connection.Connection.connect_current()
     async def connect_current(self):
         """Connect to the current Juju model.
 
         """
         self.connection = await connection.Connection.connect_current()
+        await self._after_connect()
+
+    async def connect_model(self, model_name):
+        """Connect to a specific Juju model by name.
+
+        :param model_name:  Format [controller:][user/]model
+
+        """
+        self.connection = await connection.Connection.connect_model(model_name)
+        await self._after_connect()
+
+    async def _after_connect(self):
+        """Run initialization steps after connecting to websocket.
+
+        """
         self._watch()
         await self._watch_received.wait()
         self._watch()
         await self._watch_received.wait()
+        await self.get_info()
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
@@ -395,7 +424,6 @@ class Model(object):
         await self.block_until(
             lambda: len(self.machines) == 0
         )
         await self.block_until(
             lambda: len(self.machines) == 0
         )
-        self.state.clear()
 
     async def block_until(self, *conditions, timeout=None):
         """Return only after all conditions are true.
 
     async def block_until(self, *conditions, timeout=None):
         """Return only after all conditions are true.
@@ -430,8 +458,30 @@ class Model(object):
         """
         return self.state.units
 
         """
         return self.state.units
 
+    async def get_info(self):
+        """Return a client.ModelInfo object for this Model.
+
+        Retrieves latest info for this Model from the api server. The
+        return value is cached on the Model.info attribute so that the
+        valued may be accessed again without another api call, if
+        desired.
+
+        This method is called automatically when the Model is connected,
+        resulting in Model.info being initialized without requiring an
+        explicit call to this method.
+
+        """
+        facade = client.ClientFacade()
+        facade.connect(self.connection)
+
+        self.info = await facade.ModelInfo()
+        log.debug('Got ModelInfo: %s', vars(self.info))
+
+        return self.info
+
     def add_observer(
     def add_observer(
-            self, callable_, entity_type=None, action=None, entity_id=None):
+            self, callable_, entity_type=None, action=None, entity_id=None,
+            predicate=None):
         """Register an "on-model-change" callback
 
         Once the model is connected, ``callable_``
         """Register an "on-model-change" callback
 
         Once the model is connected, ``callable_``
@@ -459,8 +509,13 @@ class Model(object):
             add_observer(
                 myfunc, entity_type='application', action='add', id_='ubuntu')
 
             add_observer(
                 myfunc, entity_type='application', action='add', id_='ubuntu')
 
+        For more complex filtering conditions, pass a predicate function. It
+        will be called with a delta as its only argument. If the predicate
+        function returns True, the callable_ will be called.
+
         """
         """
-        observer = _Observer(callable_, entity_type, action, entity_id)
+        observer = _Observer(
+            callable_, entity_type, action, entity_id, predicate)
         self.observers[observer] = callable_
 
     def _watch(self):
         self.observers[observer] = callable_
 
     def _watch(self):
@@ -517,7 +572,7 @@ class Model(object):
             by applying this delta.
 
         """
             by applying this delta.
 
         """
-        if not old_obj:
+        if new_obj and not old_obj:
             delta.type = 'add'
 
         log.debug(
             delta.type = 'add'
 
         log.debug(
@@ -525,24 +580,64 @@ class Model(object):
             delta.entity, delta.type, delta.get_id())
 
         for o in self.observers:
             delta.entity, delta.type, delta.get_id())
 
         for o in self.observers:
-            if o.cares_about(delta.entity, delta.type, delta.get_id()):
+            if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
-    async def _wait_for_new(self, entity_type, entity_id):
+    async def _wait(self, entity_type, entity_id, action, predicate=None):
+        """
+        Block the calling routine until a given action has happened to the
+        given entity
+
+        :param entity_type: The entity's type.
+        :param entity_id: The entity's id.
+        :param action: the type of action (e.g., 'add' or 'change')
+        :param predicate: optional callable that must take as an
+            argument a delta, and must return a boolean, indicating
+            whether the delta contains the specific action we're looking
+            for. For example, you might check to see whether a 'change'
+            has a 'completed' status. See the _Observer class for details.
+
+        """
+        q = asyncio.Queue(loop=self.loop)
+
+        async def callback(delta, old, new, model):
+            await q.put(delta.get_id())
+
+        self.add_observer(callback, entity_type, action, entity_id, predicate)
+        entity_id = await q.get()
+        return self.state._live_entity_map(entity_type)[entity_id]
+
+    async def _wait_for_new(self, entity_type, entity_id=None, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
+        If ``entity_id`` is ``None``, it will wait for the first new entity
+        of the correct type.
 
         This coroutine blocks until the new object appears in the model.
 
         """
 
         This coroutine blocks until the new object appears in the model.
 
         """
-        entity_added = asyncio.Event(loop=self.loop)
+        # if the entity is already in the model, just return it
+        if entity_id in self.state._live_entity_map(entity_type):
+            return self.state._live_entity_map(entity_type)[entity_id]
+        # if we know the entity_id, we can trigger on any action that puts
+        # the enitty into the model; otherwise, we have to watch for the
+        # next "add" action on that entity_type
+        action = 'add' if entity_id is None else None
+        return await self._wait(entity_type, entity_id, action, predicate)
 
 
-        async def callback(delta, old, new, model):
-            entity_added.set()
-        self.add_observer(callback, entity_type, 'add', entity_id)
-        await entity_added.wait()
-        return self.state._live_entity_map(entity_type)[entity_id]
+    async def wait_for_action(self, action_id):
+        """Given an action, wait for it to complete."""
+
+        if action_id.startswith("action-"):
+            # if we've been passed action.tag, transform it into the
+            # id that the api deltas will use.
+            action_id = action_id[7:]
+
+        def predicate(delta):
+            return delta.data['status'] in ('completed', 'failed')
+
+        return await self._wait('action', action_id, 'change', predicate)
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
@@ -587,7 +682,24 @@ class Model(object):
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
 
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
 
-        return await app_facade.AddRelation([relation1, relation2])
+        try:
+            result = await app_facade.AddRelation([relation1, relation2])
+        except JujuAPIError as e:
+            if 'relation already exists' not in e.message:
+                raise
+            log.debug(
+                'Relation %s <-> %s already exists', relation1, relation2)
+            # TODO: if relation already exists we should return the
+            # Relation ModelEntity here
+            return None
+
+        def predicate(delta):
+            endpoints = {}
+            for endpoint in delta.data['endpoints']:
+                endpoints[endpoint['application-name']] = endpoint['relation']
+            return endpoints == result.endpoints
+
+        return await self._wait_for_new('relation', None, predicate)
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
@@ -729,16 +841,11 @@ class Model(object):
 
         TODO::
 
 
         TODO::
 
-            - entity_url must have a revision; look up latest automatically if
-              not provided by caller
             - service_name is required; fill this in automatically if not
               provided by caller
             - series is required; how do we pick a default?
 
         """
             - service_name is required; fill this in automatically if not
               provided by caller
             - series is required; how do we pick a default?
 
         """
-        if constraints:
-            constraints = client.Value(**constraints)
-
         if to:
             placement = [
                 client.Placement(**p) for p in to
         if to:
             placement = [
                 client.Placement(**p) for p in to
@@ -752,17 +859,36 @@ class Model(object):
                 for k, v in storage.items()
             }
 
                 for k, v in storage.items()
             }
 
-        entity_id = await self.charmstore.entityId(entity_url)
+        is_local = not entity_url.startswith('cs:') and \
+            os.path.isdir(entity_url)
+        entity_id = await self.charmstore.entityId(entity_url) \
+            if not is_local else entity_url
 
         app_facade = client.ApplicationFacade()
         client_facade = client.ClientFacade()
         app_facade.connect(self.connection)
         client_facade.connect(self.connection)
 
 
         app_facade = client.ApplicationFacade()
         client_facade = client.ClientFacade()
         app_facade.connect(self.connection)
         client_facade.connect(self.connection)
 
-        if 'bundle/' in entity_id:
+        is_bundle = ((is_local and
+                      (Path(entity_id) / 'bundle.yaml').exists()) or
+                     (not is_local and 'bundle/' in entity_id))
+
+        if is_bundle:
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
+            extant_apps = {app for app in self.applications}
+            pending_apps = set(handler.applications) - extant_apps
+            if pending_apps:
+                # new apps will usually be in the model by now, but if some
+                # haven't made it yet we'll need to wait on them to be added
+                await asyncio.gather(*[
+                    asyncio.ensure_future(
+                        self.model._wait_for_new('application', app_name))
+                    for app_name in pending_apps
+                ])
+            return [app for name, app in self.applications.items()
+                    if name in handler.applications]
         else:
             log.debug(
                 'Deploying %s', entity_id)
         else:
             log.debug(
                 'Deploying %s', entity_id)
@@ -791,6 +917,21 @@ class Model(object):
         """
         pass
 
         """
         pass
 
+    async def destroy_unit(self, *unit_names):
+        """Destroy units by name.
+
+        """
+        app_facade = client.ApplicationFacade()
+        app_facade.connect(self.connection)
+
+        log.debug(
+            'Destroying unit%s %s',
+            's' if len(unit_names) == 1 else '',
+            ' '.join(unit_names))
+
+        return await app_facade.Destroy(self.name)
+    destroy_units = destroy_unit
+
     def get_backup(self, archive_id):
         """Download a backup archive file.
 
     def get_backup(self, archive_id):
         """Download a backup archive file.
 
@@ -1103,6 +1244,36 @@ class Model(object):
     def charmstore(self):
         return self._charmstore
 
     def charmstore(self):
         return self._charmstore
 
+    async def get_metrics(self, *tags):
+        """Retrieve metrics.
+
+        :param str \*tags: Tags of entities from which to retrieve metrics.
+            No tags retrieves the metrics of all units in the model.
+        """
+        log.debug("Retrieving metrics for %s",
+                  ', '.join(tags) if tags else "all units")
+
+        metrics_facade = client.MetricsDebugFacade()
+        metrics_facade.connect(self.connection)
+
+        entities = [client.Entity(tag) for tag in tags]
+        metrics_result = await metrics_facade.GetMetrics(entities)
+
+        metrics = collections.defaultdict(list)
+
+        for entity_metrics in metrics_result.results:
+            error = entity_metrics.error
+            if error:
+                if "is not a valid tag" in error:
+                    raise ValueError(error.message)
+                else:
+                    raise Exception(error.message)
+
+            for metric in entity_metrics.metrics:
+                metrics[metric.unit].append(vars(metric))
+
+        return metrics
+
 
 class BundleHandler(object):
     """
 
 class BundleHandler(object):
     """
@@ -1126,10 +1297,15 @@ class BundleHandler(object):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
-        yaml = await self.charmstore.files(entity_id,
-                                           filename='bundle.yaml',
-                                           read_file=True)
-        self.plan = await self.client_facade.GetBundleChanges(yaml)
+        is_local = not entity_id.startswith('cs:') and os.path.isdir(entity_id)
+        if is_local:
+            bundle_yaml = (Path(entity_id) / "bundle.yaml").read_text()
+        else:
+            bundle_yaml = await self.charmstore.files(entity_id,
+                                                      filename='bundle.yaml',
+                                                      read_file=True)
+        self.bundle = yaml.safe_load(bundle_yaml)
+        self.plan = await self.client_facade.GetBundleChanges(bundle_yaml)
 
     async def execute_plan(self):
         for step in self.plan.changes:
 
     async def execute_plan(self):
         for step in self.plan.changes:
@@ -1137,6 +1313,10 @@ class BundleHandler(object):
             result = await method(*step.args)
             self.references[step.id_] = result
 
             result = await method(*step.args)
             self.references[step.id_] = result
 
+    @property
+    def applications(self):
+        return list(self.bundle['services'].keys())
+
     def resolve(self, reference):
         if reference and reference.startswith('$'):
             reference = self.references[reference[1:]]
     def resolve(self, reference):
         if reference and reference.startswith('$'):
             reference = self.references[reference[1:]]
@@ -1156,35 +1336,41 @@ class BundleHandler(object):
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
-    async def addMachines(self, series, constraints, container_type,
-                          parent_id):
-        """
-        :param series string:
-            Series holds the optional machine OS series.
-
-        :param constraints string:
-            Constraints holds the optional machine constraints.
-
-        :param Container_type string:
-            ContainerType optionally holds the type of the container (for
-            instance ""lxc" or kvm"). It is not specified for top level
-            machines.
-
-        :param parent_id string:
-            ParentId optionally holds a placeholder pointing to another machine
-            change or to a unit change. This value is only specified in the
-            case this machine is a container, in which case also ContainerType
-            is set.
-        """
-        params = client.AddMachineParams(
-            series=series,
-            constraints=constraints,
-            container_type=container_type,
-            parent_id=self.resolve(parent_id),
-        )
-        results = await self.client_facade.AddMachines(params)
-        log.debug('Added new machine %s', results[0].machine)
-        return results[0].machine
+    async def addMachines(self, params=None):
+        """
+        :param params dict:
+            Dictionary specifying the machine to add. All keys are optional.
+            Keys include:
+
+            series: string specifying the machine OS series.
+            constraints: string holding machine constraints, if any. We'll
+                parse this into the json friendly dict that the juju api
+                expects.
+            container_type: string holding the type of the container (for
+                instance ""lxc" or kvm"). It is not specified for top level
+                machines.
+            parent_id: string holding a placeholder pointing to another
+                machine change or to a unit change. This value is only
+                specified in the case this machine is a container, in
+                which case also ContainerType is set.
+        """
+        params = params or {}
+
+        if 'parent_id' in params:
+            params['parent_id'] = self.resolve(params['parent_id'])
+
+        params['constraints'] = parse_constraints(
+            params.get('constraints'))
+        params['jobs'] = params.get('jobs', ['JobHostUnits'])
+
+        params = client.AddMachineParams(**params)
+        results = await self.client_facade.AddMachines([params])
+        error = results.machines[0].error
+        if error:
+            raise ValueError("Error adding machine: %s", error.message)
+        machine = results.machines[0].machine
+        log.debug('Added new machine %s', machine)
+        return machine
 
     async def addRelation(self, endpoint1, endpoint2):
         """
 
     async def addRelation(self, endpoint1, endpoint2):
         """
@@ -1201,14 +1387,9 @@ class BundleHandler(object):
             parts = endpoints[i].split(':')
             parts[0] = self.resolve(parts[0])
             endpoints[i] = ':'.join(parts)
             parts = endpoints[i].split(':')
             parts[0] = self.resolve(parts[0])
             endpoints[i] = ':'.join(parts)
-        try:
-            await self.app_facade.AddRelation(endpoints)
-            log.debug('Added relation %s <-> %s', *endpoints)
-        except JujuAPIError as e:
-            if 'relation already exists' not in e.message:
-                raise
-            log.debug('Relation %s <-> %s already exists', *endpoints)
-        return None
+
+        log.info('Relating %s <-> %s', *endpoints)
+        return await self.model.add_relation(*endpoints)
 
     async def deploy(self, charm, series, application, options, constraints,
                      storage, endpoint_bindings, resources):
 
     async def deploy(self, charm, series, application, options, constraints,
                      storage, endpoint_bindings, resources):
@@ -1256,8 +1437,10 @@ class BundleHandler(object):
             resources=resources,
         )
         # do the do
             resources=resources,
         )
         # do the do
-        log.debug('Deploying %s', charm)
+        log.info('Deploying %s', charm)
         await self.app_facade.Deploy([app])
         await self.app_facade.Deploy([app])
+        # ensure the app is in the model for future operations
+        await self.model._wait_for_new('application', application)
         return application
 
     async def addUnit(self, application, to):
         return application
 
     async def addUnit(self, application, to):
@@ -1279,16 +1462,14 @@ class BundleHandler(object):
             # doesn't, so we're not bothering, either
             unit_name = self._units_by_app[application].pop()
             log.debug('Reusing unit %s for %s', unit_name, application)
             # doesn't, so we're not bothering, either
             unit_name = self._units_by_app[application].pop()
             log.debug('Reusing unit %s for %s', unit_name, application)
-            return unit_name
-        log.debug('Adding unit of %s%s',
-                  application,
-                  (' to %s' % placement) if placement else '')
-        result = await self.app_facade.AddUnits(
-            application=application,
-            placement=placement,
-            num_units=1,
+            return self.model.units[unit_name]
+
+        log.debug('Adding new unit for %s%s', application,
+                  ' to %s' % placement if placement else '')
+        return await self.model.applications[application].add_unit(
+            count=1,
+            to=placement,
         )
         )
-        return result.units[0]
 
     async def expose(self, application):
         """
 
     async def expose(self, application):
         """
@@ -1297,9 +1478,8 @@ class BundleHandler(object):
             be exposed.
         """
         application = self.resolve(application)
             be exposed.
         """
         application = self.resolve(application)
-        log.debug('Exposing %s', application)
-        await self.app_facade.Expose(application)
-        return None
+        log.info('Exposing %s', application)
+        return await self.model.applications[application].expose()
 
     async def setAnnotations(self, id_, entity_type, annotations):
         """
 
     async def setAnnotations(self, id_, entity_type, annotations):
         """
@@ -1315,13 +1495,11 @@ class BundleHandler(object):
             Annotations holds the annotations as key/value pairs.
         """
         entity_id = self.resolve(id_)
             Annotations holds the annotations as key/value pairs.
         """
         entity_id = self.resolve(id_)
-        log.debug('Updating annotations of %s', entity_id)
-        ann = client.EntityAnnotations(
-            entity=entity_id,
-            annotations=annotations,
-        )
-        await self.ann_facade.Set([ann])
-        return None
+        try:
+            entity = self.model.state.get_entity(entity_type, entity_id)
+        except KeyError:
+            entity = await self.model._wait_for_new(entity_type, entity_id)
+        return await entity.set_annotations(annotations)
 
 
 class CharmStore(object):
 
 
 class CharmStore(object):