Remove incorrect docstring
[osm/N2VC.git] / juju / model.py
index 0922520..7897d42 100644 (file)
@@ -1,10 +1,12 @@
 import asyncio
 import collections
 import logging
 import asyncio
 import collections
 import logging
+import os
 import re
 import weakref
 from concurrent.futures import CancelledError
 from functools import partial
 import re
 import weakref
 from concurrent.futures import CancelledError
 from functools import partial
+from pathlib import Path
 
 import yaml
 from theblues import charmstore
 
 import yaml
 from theblues import charmstore
@@ -205,18 +207,16 @@ class ModelEntity(object):
         self.connected = connected
         self.connection = model.connection
 
         self.connected = connected
         self.connection = model.connection
 
+    def __repr__(self):
+        return '<{} entity_id="{}">'.format(type(self).__name__,
+                                            self.entity_id)
+
     def __getattr__(self, name):
         """Fetch object attributes from the underlying data dict held in the
         model.
 
         """
     def __getattr__(self, name):
         """Fetch object attributes from the underlying data dict held in the
         model.
 
         """
-        if self.data is None:
-            raise DeadEntityException(
-                "Entity {}:{} is dead - its attributes can no longer be "
-                "accessed. Use the .previous() method on this object to get "
-                "a copy of the object at its previous state.".format(
-                    self.entity_type, self.entity_id))
-        return self.data[name]
+        return self.safe_data[name]
 
     def __bool__(self):
         return bool(self.data)
 
     def __bool__(self):
         return bool(self.data)
@@ -283,6 +283,22 @@ class ModelEntity(object):
         return self.model.state.entity_data(
             self.entity_type, self.entity_id, self._history_index)
 
         return self.model.state.entity_data(
             self.entity_type, self.entity_id, self._history_index)
 
+    @property
+    def safe_data(self):
+        """The data dictionary for this entity.
+
+        If this `ModelEntity` points to the dead state, it will
+        raise `DeadEntityException`.
+
+        """
+        if self.data is None:
+            raise DeadEntityException(
+                "Entity {}:{} is dead - its attributes can no longer be "
+                "accessed. Use the .previous() method on this object to get "
+                "a copy of the object at its previous state.".format(
+                    self.entity_type, self.entity_id))
+        return self.data
+
     def previous(self):
         """Return a copy of this object as was at its previous state in
         history.
     def previous(self):
         """Return a copy of this object as was at its previous state in
         history.
@@ -345,6 +361,7 @@ class Model(object):
         self.connection = None
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
         self.connection = None
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
+        self.info = None
         self._watcher_task = None
         self._watch_shutdown = asyncio.Event(loop=loop)
         self._watch_received = asyncio.Event(loop=loop)
         self._watcher_task = None
         self._watch_shutdown = asyncio.Event(loop=loop)
         self._watch_received = asyncio.Event(loop=loop)
@@ -357,25 +374,31 @@ class Model(object):
 
         """
         self.connection = await connection.Connection.connect(*args, **kw)
 
         """
         self.connection = await connection.Connection.connect(*args, **kw)
-        self._watch()
-        await self._watch_received.wait()
+        await self._after_connect()
 
     async def connect_current(self):
         """Connect to the current Juju model.
 
         """
         self.connection = await connection.Connection.connect_current()
 
     async def connect_current(self):
         """Connect to the current Juju model.
 
         """
         self.connection = await connection.Connection.connect_current()
-        self._watch()
-        await self._watch_received.wait()
+        await self._after_connect()
+
+    async def connect_model(self, model_name):
+        """Connect to a specific Juju model by name.
+
+        :param model_name:  Format [controller:][user/]model
 
 
-    async def connect_model(self, arg):
-        """Connect to a specific Juju model.
-        :param arg:  <controller>:<user/model>
+        """
+        self.connection = await connection.Connection.connect_model(model_name)
+        await self._after_connect()
+
+    async def _after_connect(self):
+        """Run initialization steps after connecting to websocket.
 
         """
 
         """
-        self.connection = await connection.Connection.connect_model(arg)
         self._watch()
         await self._watch_received.wait()
         self._watch()
         await self._watch_received.wait()
+        await self.get_info()
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
@@ -416,13 +439,13 @@ class Model(object):
             lambda: len(self.machines) == 0
         )
 
             lambda: len(self.machines) == 0
         )
 
-    async def block_until(self, *conditions, timeout=None):
+    async def block_until(self, *conditions, timeout=None, wait_period=0.5):
         """Return only after all conditions are true.
 
         """
         async def _block():
             while not all(c() for c in conditions):
         """Return only after all conditions are true.
 
         """
         async def _block():
             while not all(c() for c in conditions):
-                await asyncio.sleep(0)
+                await asyncio.sleep(wait_period)
         await asyncio.wait_for(_block(), timeout)
 
     @property
         await asyncio.wait_for(_block(), timeout)
 
     @property
@@ -449,13 +472,34 @@ class Model(object):
         """
         return self.state.units
 
         """
         return self.state.units
 
+    async def get_info(self):
+        """Return a client.ModelInfo object for this Model.
+
+        Retrieves latest info for this Model from the api server. The
+        return value is cached on the Model.info attribute so that the
+        valued may be accessed again without another api call, if
+        desired.
+
+        This method is called automatically when the Model is connected,
+        resulting in Model.info being initialized without requiring an
+        explicit call to this method.
+
+        """
+        facade = client.ClientFacade()
+        facade.connect(self.connection)
+
+        self.info = await facade.ModelInfo()
+        log.debug('Got ModelInfo: %s', vars(self.info))
+
+        return self.info
+
     def add_observer(
             self, callable_, entity_type=None, action=None, entity_id=None,
             predicate=None):
         """Register an "on-model-change" callback
 
         Once the model is connected, ``callable_``
     def add_observer(
             self, callable_, entity_type=None, action=None, entity_id=None,
             predicate=None):
         """Register an "on-model-change" callback
 
         Once the model is connected, ``callable_``
-        will be called each time the model changes. callable_ should
+        will be called each time the model changes. ``callable_`` should
         be Awaitable and accept the following positional arguments:
 
             delta - An instance of :class:`juju.delta.EntityDelta`
         be Awaitable and accept the following positional arguments:
 
             delta - An instance of :class:`juju.delta.EntityDelta`
@@ -474,14 +518,15 @@ class Model(object):
             model - The :class:`Model` itself.
 
         Events for which ``callable_`` is called can be specified by passing
             model - The :class:`Model` itself.
 
         Events for which ``callable_`` is called can be specified by passing
-        entity_type, action, and/or id_ filter criteria, e.g.:
+        entity_type, action, and/or entitiy_id filter criteria, e.g.::
 
             add_observer(
 
             add_observer(
-                myfunc, entity_type='application', action='add', id_='ubuntu')
+                myfunc,
+                entity_type='application', action='add', entity_id='ubuntu')
 
         For more complex filtering conditions, pass a predicate function. It
         will be called with a delta as its only argument. If the predicate
 
         For more complex filtering conditions, pass a predicate function. It
         will be called with a delta as its only argument. If the predicate
-        function returns True, the callable_ will be called.
+        function returns True, the ``callable_`` will be called.
 
         """
         observer = _Observer(
 
         """
         observer = _Observer(
@@ -577,15 +622,24 @@ class Model(object):
         entity_id = await q.get()
         return self.state._live_entity_map(entity_type)[entity_id]
 
         entity_id = await q.get()
         return self.state._live_entity_map(entity_type)[entity_id]
 
-    async def _wait_for_new(self, entity_type, entity_id, predicate=None):
+    async def _wait_for_new(self, entity_type, entity_id=None, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
+        If ``entity_id`` is ``None``, it will wait for the first new entity
+        of the correct type.
 
         This coroutine blocks until the new object appears in the model.
 
         """
 
         This coroutine blocks until the new object appears in the model.
 
         """
-        return await self._wait(entity_type, entity_id, 'add', predicate)
+        # if the entity is already in the model, just return it
+        if entity_id in self.state._live_entity_map(entity_type):
+            return self.state._live_entity_map(entity_type)[entity_id]
+        # if we know the entity_id, we can trigger on any action that puts
+        # the enitty into the model; otherwise, we have to watch for the
+        # next "add" action on that entity_type
+        action = 'add' if entity_id is None else None
+        return await self._wait(entity_type, entity_id, action, predicate)
 
     async def wait_for_action(self, action_id):
         """Given an action, wait for it to complete."""
 
     async def wait_for_action(self, action_id):
         """Given an action, wait for it to complete."""
@@ -769,14 +823,14 @@ class Model(object):
         pass
 
     async def deploy(
         pass
 
     async def deploy(
-            self, entity_url, service_name=None, bind=None, budget=None,
+            self, entity_url, application_name=None, bind=None, budget=None,
             channel=None, config=None, constraints=None, force=False,
             num_units=1, plan=None, resources=None, series=None, storage=None,
             to=None):
         """Deploy a new service or bundle.
 
         :param str entity_url: Charm or bundle url
             channel=None, config=None, constraints=None, force=False,
             num_units=1, plan=None, resources=None, series=None, storage=None,
             to=None):
         """Deploy a new service or bundle.
 
         :param str entity_url: Charm or bundle url
-        :param str service_name: Name to give the service
+        :param str application_name: Name to give the service
         :param dict bind: <charm endpoint>:<network space> pairs
         :param dict budget: <budget name>:<limit> pairs
         :param str channel: Charm store channel from which to retrieve
         :param dict bind: <charm endpoint>:<network space> pairs
         :param dict budget: <budget name>:<limit> pairs
         :param str channel: Charm store channel from which to retrieve
@@ -802,7 +856,7 @@ class Model(object):
 
         TODO::
 
 
         TODO::
 
-            - service_name is required; fill this in automatically if not
+            - application_name is required; fill this in automatically if not
               provided by caller
             - series is required; how do we pick a default?
 
               provided by caller
             - series is required; how do we pick a default?
 
@@ -820,14 +874,21 @@ class Model(object):
                 for k, v in storage.items()
             }
 
                 for k, v in storage.items()
             }
 
-        entity_id = await self.charmstore.entityId(entity_url)
+        is_local = not entity_url.startswith('cs:') and \
+            os.path.isdir(entity_url)
+        entity_id = await self.charmstore.entityId(entity_url) \
+            if not is_local else entity_url
 
         app_facade = client.ApplicationFacade()
         client_facade = client.ClientFacade()
         app_facade.connect(self.connection)
         client_facade.connect(self.connection)
 
 
         app_facade = client.ApplicationFacade()
         client_facade = client.ClientFacade()
         app_facade.connect(self.connection)
         client_facade.connect(self.connection)
 
-        if 'bundle/' in entity_id:
+        is_bundle = ((is_local and
+                      (Path(entity_id) / 'bundle.yaml').exists()) or
+                     (not is_local and 'bundle/' in entity_id))
+
+        if is_bundle:
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
             handler = BundleHandler(self)
             await handler.fetch_plan(entity_id)
             await handler.execute_plan()
@@ -838,7 +899,7 @@ class Model(object):
                 # haven't made it yet we'll need to wait on them to be added
                 await asyncio.gather(*[
                     asyncio.ensure_future(
                 # haven't made it yet we'll need to wait on them to be added
                 await asyncio.gather(*[
                     asyncio.ensure_future(
-                        self.model._wait_for_new('application', app_name))
+                        self._wait_for_new('application', app_name))
                     for app_name in pending_apps
                 ])
             return [app for name, app in self.applications.items()
                     for app_name in pending_apps
                 ])
             return [app for name, app in self.applications.items()
@@ -849,7 +910,7 @@ class Model(object):
 
             await client_facade.AddCharm(channel, entity_id)
             app = client.ApplicationDeploy(
 
             await client_facade.AddCharm(channel, entity_id)
             app = client.ApplicationDeploy(
-                application=service_name,
+                application=application_name,
                 channel=channel,
                 charm_url=entity_id,
                 config=config,
                 channel=channel,
                 charm_url=entity_id,
                 config=config,
@@ -863,7 +924,7 @@ class Model(object):
             )
 
             await app_facade.Deploy([app])
             )
 
             await app_facade.Deploy([app])
-            return await self._wait_for_new('application', service_name)
+            return await self._wait_for_new('application', application_name)
 
     def destroy(self):
         """Terminate all machines and resources for this model.
 
     def destroy(self):
         """Terminate all machines and resources for this model.
@@ -1198,6 +1259,36 @@ class Model(object):
     def charmstore(self):
         return self._charmstore
 
     def charmstore(self):
         return self._charmstore
 
+    async def get_metrics(self, *tags):
+        """Retrieve metrics.
+
+        :param str \*tags: Tags of entities from which to retrieve metrics.
+            No tags retrieves the metrics of all units in the model.
+        """
+        log.debug("Retrieving metrics for %s",
+                  ', '.join(tags) if tags else "all units")
+
+        metrics_facade = client.MetricsDebugFacade()
+        metrics_facade.connect(self.connection)
+
+        entities = [client.Entity(tag) for tag in tags]
+        metrics_result = await metrics_facade.GetMetrics(entities)
+
+        metrics = collections.defaultdict(list)
+
+        for entity_metrics in metrics_result.results:
+            error = entity_metrics.error
+            if error:
+                if "is not a valid tag" in error:
+                    raise ValueError(error.message)
+                else:
+                    raise Exception(error.message)
+
+            for metric in entity_metrics.metrics:
+                metrics[metric.unit].append(vars(metric))
+
+        return metrics
+
 
 class BundleHandler(object):
     """
 
 class BundleHandler(object):
     """
@@ -1221,9 +1312,13 @@ class BundleHandler(object):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
         self.ann_facade.connect(model.connection)
 
     async def fetch_plan(self, entity_id):
-        bundle_yaml = await self.charmstore.files(entity_id,
-                                                  filename='bundle.yaml',
-                                                  read_file=True)
+        is_local = not entity_id.startswith('cs:') and os.path.isdir(entity_id)
+        if is_local:
+            bundle_yaml = (Path(entity_id) / "bundle.yaml").read_text()
+        else:
+            bundle_yaml = await self.charmstore.files(entity_id,
+                                                      filename='bundle.yaml',
+                                                      read_file=True)
         self.bundle = yaml.safe_load(bundle_yaml)
         self.plan = await self.client_facade.GetBundleChanges(bundle_yaml)
 
         self.bundle = yaml.safe_load(bundle_yaml)
         self.plan = await self.client_facade.GetBundleChanges(bundle_yaml)
 
@@ -1256,23 +1351,30 @@ class BundleHandler(object):
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
-    async def addMachines(self, params):
+    async def addMachines(self, params=None):
         """
         :param params dict:
         """
         :param params dict:
-            Dictionary specifying the machine to add. Keys include:
+            Dictionary specifying the machine to add. All keys are optional.
+            Keys include:
 
             series: string specifying the machine OS series.
 
             series: string specifying the machine OS series.
-            constraints: string holding optional machine constraints. We'll
+
+            constraints: string holding machine constraints, if any. We'll
                 parse this into the json friendly dict that the juju api
                 expects.
                 parse this into the json friendly dict that the juju api
                 expects.
-            Container_type: string holding the type of the container (for
+
+            container_type: string holding the type of the container (for
                 instance ""lxc" or kvm"). It is not specified for top level
                 machines.
                 instance ""lxc" or kvm"). It is not specified for top level
                 machines.
+
             parent_id: string holding a placeholder pointing to another
                 machine change or to a unit change. This value is only
                 specified in the case this machine is a container, in
                 which case also ContainerType is set.
             parent_id: string holding a placeholder pointing to another
                 machine change or to a unit change. This value is only
                 specified in the case this machine is a container, in
                 which case also ContainerType is set.
+
         """
         """
+        params = params or {}
+
         if 'parent_id' in params:
             params['parent_id'] = self.resolve(params['parent_id'])
 
         if 'parent_id' in params:
             params['parent_id'] = self.resolve(params['parent_id'])
 
@@ -1356,6 +1458,8 @@ class BundleHandler(object):
         # do the do
         log.info('Deploying %s', charm)
         await self.app_facade.Deploy([app])
         # do the do
         log.info('Deploying %s', charm)
         await self.app_facade.Deploy([app])
+        # ensure the app is in the model for future operations
+        await self.model._wait_for_new('application', application)
         return application
 
     async def addUnit(self, application, to):
         return application
 
     async def addUnit(self, application, to):