Merge pull request #28 from petevg/bug/fixes-for-null-cases
[osm/N2VC.git] / juju / model.py
index 0922520..df72eb2 100644 (file)
@@ -345,6 +345,7 @@ class Model(object):
         self.connection = None
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
         self.connection = None
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
+        self.info = None
         self._watcher_task = None
         self._watch_shutdown = asyncio.Event(loop=loop)
         self._watch_received = asyncio.Event(loop=loop)
         self._watcher_task = None
         self._watch_shutdown = asyncio.Event(loop=loop)
         self._watch_received = asyncio.Event(loop=loop)
@@ -357,25 +358,31 @@ class Model(object):
 
         """
         self.connection = await connection.Connection.connect(*args, **kw)
 
         """
         self.connection = await connection.Connection.connect(*args, **kw)
-        self._watch()
-        await self._watch_received.wait()
+        await self._after_connect()
 
     async def connect_current(self):
         """Connect to the current Juju model.
 
         """
         self.connection = await connection.Connection.connect_current()
 
     async def connect_current(self):
         """Connect to the current Juju model.
 
         """
         self.connection = await connection.Connection.connect_current()
-        self._watch()
-        await self._watch_received.wait()
+        await self._after_connect()
 
 
-    async def connect_model(self, arg):
-        """Connect to a specific Juju model.
-        :param arg:  <controller>:<user/model>
+    async def connect_model(self, model_name):
+        """Connect to a specific Juju model by name.
+
+        :param model_name:  Format [controller:][user/]model
+
+        """
+        self.connection = await connection.Connection.connect_model(model_name)
+        await self._after_connect()
+
+    async def _after_connect(self):
+        """Run initialization steps after connecting to websocket.
 
         """
 
         """
-        self.connection = await connection.Connection.connect_model(arg)
         self._watch()
         await self._watch_received.wait()
         self._watch()
         await self._watch_received.wait()
+        await self.get_info()
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
@@ -449,6 +456,27 @@ class Model(object):
         """
         return self.state.units
 
         """
         return self.state.units
 
+    async def get_info(self):
+        """Return a client.ModelInfo object for this Model.
+
+        Retrieves latest info for this Model from the api server. The
+        return value is cached on the Model.info attribute so that the
+        valued may be accessed again without another api call, if
+        desired.
+
+        This method is called automatically when the Model is connected,
+        resulting in Model.info being initialized without requiring an
+        explicit call to this method.
+
+        """
+        facade = client.ClientFacade()
+        facade.connect(self.connection)
+
+        self.info = await facade.ModelInfo()
+        log.debug('Got ModelInfo: %s', vars(self.info))
+
+        return self.info
+
     def add_observer(
             self, callable_, entity_type=None, action=None, entity_id=None,
             predicate=None):
     def add_observer(
             self, callable_, entity_type=None, action=None, entity_id=None,
             predicate=None):
@@ -577,15 +605,24 @@ class Model(object):
         entity_id = await q.get()
         return self.state._live_entity_map(entity_type)[entity_id]
 
         entity_id = await q.get()
         return self.state._live_entity_map(entity_type)[entity_id]
 
-    async def _wait_for_new(self, entity_type, entity_id, predicate=None):
+    async def _wait_for_new(self, entity_type, entity_id=None, predicate=None):
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
         """Wait for a new object to appear in the Model and return it.
 
         Waits for an object of type ``entity_type`` with id ``entity_id``.
+        If ``entity_id`` is ``None``, it will wait for the first new entity
+        of the correct type.
 
         This coroutine blocks until the new object appears in the model.
 
         """
 
         This coroutine blocks until the new object appears in the model.
 
         """
-        return await self._wait(entity_type, entity_id, 'add', predicate)
+        # if the entity is already in the model, just return it
+        if entity_id in self.state._live_entity_map(entity_type):
+            return self.state._live_entity_map(entity_type)[entity_id]
+        # if we know the entity_id, we can trigger on any action that puts
+        # the enitty into the model; otherwise, we have to watch for the
+        # next "add" action on that entity_type
+        action = 'add' if entity_id is None else None
+        return await self._wait(entity_type, entity_id, action, predicate)
 
     async def wait_for_action(self, action_id):
         """Given an action, wait for it to complete."""
 
     async def wait_for_action(self, action_id):
         """Given an action, wait for it to complete."""
@@ -1198,6 +1235,36 @@ class Model(object):
     def charmstore(self):
         return self._charmstore
 
     def charmstore(self):
         return self._charmstore
 
+    async def get_metrics(self, *tags):
+        """Retrieve metrics.
+
+        :param str \*tags: Tags of entities from which to retrieve metrics.
+            No tags retrieves the metrics of all units in the model.
+        """
+        log.debug("Retrieving metrics for %s",
+                  ', '.join(tags) if tags else "all units")
+
+        metrics_facade = client.MetricsDebugFacade()
+        metrics_facade.connect(self.connection)
+
+        entities = [client.Entity(tag) for tag in tags]
+        metrics_result = await metrics_facade.GetMetrics(entities)
+
+        metrics = collections.defaultdict(list)
+
+        for entity_metrics in metrics_result.results:
+            error = entity_metrics.error
+            if error:
+                if "is not a valid tag" in error:
+                    raise ValueError(error.message)
+                else:
+                    raise Exception(error.message)
+
+            for metric in entity_metrics.metrics:
+                metrics[metric.unit].append(vars(metric))
+
+        return metrics
+
 
 class BundleHandler(object):
     """
 
 class BundleHandler(object):
     """
@@ -1256,16 +1323,17 @@ class BundleHandler(object):
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
         await self.client_facade.AddCharm(None, entity_id)
         return entity_id
 
-    async def addMachines(self, params):
+    async def addMachines(self, params=None):
         """
         :param params dict:
         """
         :param params dict:
-            Dictionary specifying the machine to add. Keys include:
+            Dictionary specifying the machine to add. All keys are optional.
+            Keys include:
 
             series: string specifying the machine OS series.
 
             series: string specifying the machine OS series.
-            constraints: string holding optional machine constraints. We'll
+            constraints: string holding machine constraints, if any. We'll
                 parse this into the json friendly dict that the juju api
                 expects.
                 parse this into the json friendly dict that the juju api
                 expects.
-            Container_type: string holding the type of the container (for
+            container_type: string holding the type of the container (for
                 instance ""lxc" or kvm"). It is not specified for top level
                 machines.
             parent_id: string holding a placeholder pointing to another
                 instance ""lxc" or kvm"). It is not specified for top level
                 machines.
             parent_id: string holding a placeholder pointing to another
@@ -1273,6 +1341,8 @@ class BundleHandler(object):
                 specified in the case this machine is a container, in
                 which case also ContainerType is set.
         """
                 specified in the case this machine is a container, in
                 which case also ContainerType is set.
         """
+        params = params or {}
+
         if 'parent_id' in params:
             params['parent_id'] = self.resolve(params['parent_id'])
 
         if 'parent_id' in params:
             params['parent_id'] = self.resolve(params['parent_id'])
 
@@ -1356,6 +1426,8 @@ class BundleHandler(object):
         # do the do
         log.info('Deploying %s', charm)
         await self.app_facade.Deploy([app])
         # do the do
         log.info('Deploying %s', charm)
         await self.app_facade.Deploy([app])
+        # ensure the app is in the model for future operations
+        await self.model._wait_for_new('application', application)
         return application
 
     async def addUnit(self, application, to):
         return application
 
     async def addUnit(self, application, to):