Fix livemodel example
[osm/N2VC.git] / juju / model.py
index 2ef903b..677657f 100644 (file)
@@ -1,6 +1,31 @@
+import asyncio
+import logging
+from concurrent.futures import CancelledError
+
+from .client import client
 from .client import watcher
+from .client import connection
 from .delta import get_entity_delta
 
+log = logging.getLogger(__name__)
+
+
+class ModelObserver(object):
+    async def __call__(self, delta, old, new, model):
+        if old is None and new is not None:
+            type_ = 'add'
+        else:
+            type_ = delta.type
+        handler_name = 'on_{}_{}'.format(delta.entity, type_)
+        method = getattr(self, handler_name, self.on_change)
+        log.debug(
+            'Model changed: %s %s %s',
+            delta.entity, delta.type, delta.get_id())
+        await method(delta, old, new, model)
+
+    async def on_change(self, delta, old, new, model):
+        pass
+
 
 class ModelEntity(object):
     """An object in the Model tree"""
@@ -15,25 +40,112 @@ class ModelEntity(object):
         """
         self.data = data
         self.model = model
+        self.connection = model.connection
+
+    def __getattr__(self, name):
+        return self.data[name]
 
 
 class Model(object):
-    def __init__(self, connection):
+    def __init__(self, loop=None):
         """Instantiate a new connected Model.
 
-        :param connection: `juju.client.connection.Connection` instance
+        :param loop: an asyncio event loop
 
         """
-        self.connection = connection
+        self.loop = loop or asyncio.get_event_loop()
+        self.connection = None
         self.observers = set()
         self.state = dict()
+        self._watcher_task = None
+        self._watch_shutdown = asyncio.Event(loop=loop)
+        self._watch_received = asyncio.Event(loop=loop)
+
+    async def connect_current(self):
+        """Connect to the current Juju model.
+
+        """
+        self.connection = await connection.Connection.connect_current()
+        self._watch()
+        await self._watch_received.wait()
+
+    async def disconnect(self):
+        """Shut down the watcher task and close websockets.
+
+        """
+        self._stop_watching()
+        if self.connection and self.connection.is_open:
+            await self._watch_shutdown.wait()
+            log.debug('Closing model connection')
+            await self.connection.close()
+            self.connection = None
+
+    def all_units_idle(self):
+        """Return True if all units are idle.
+
+        """
+        for unit in self.units.values():
+            unit_status = unit.data['agent-status']['current']
+            if unit_status != 'idle':
+                return False
+        return True
+
+    async def reset(self, force=False):
+        """Reset the model to a clean state.
+
+        :param bool force: Force-terminate machines.
+
+        This returns only after the model has reached a clean state. "Clean"
+        means no applications or machines exist in the model.
+
+        """
+        for app in self.applications.values():
+            await app.destroy()
+        for machine in self.machines.values():
+            await machine.destroy(force=force)
+        await self.block_until(
+            lambda: len(self.machines) == 0
+        )
+
+    async def block_until(self, *conditions, timeout=None):
+        """Return only after all conditions are true.
+
+        """
+        async def _block():
+            while not all(c() for c in conditions):
+                await asyncio.sleep(.1)
+        await asyncio.wait_for(_block(), timeout)
+
+    @property
+    def applications(self):
+        """Return a map of application-name:Application for all applications
+        currently in the model.
+
+        """
+        return self.state.get('application', {})
+
+    @property
+    def machines(self):
+        """Return a map of machine-id:Machine for all machines currently in
+        the model.
+
+        """
+        return self.state.get('machine', {})
+
+    @property
+    def units(self):
+        """Return a map of unit-id:Unit for all units currently in
+        the model.
+
+        """
+        return self.state.get('unit', {})
 
     def add_observer(self, callable_):
         """Register an "on-model-change" callback
 
         Once a watch is started (Model.watch() is called), ``callable_``
         will be called each time the model changes. callable_ should
-        accept the following positional arguments:
+        be Awaitable and accept the following positional arguments:
 
             delta - An instance of :class:`juju.delta.EntityDelta`
                 containing the raw delta data recv'd from the Juju
@@ -53,20 +165,48 @@ class Model(object):
         """
         self.observers.add(callable_)
 
-    async def watch(self):
+    def _watch(self):
         """Start an asynchronous watch against this model.
 
         See :meth:`add_observer` to register an onchange callback.
 
         """
-        allwatcher = watcher.AllWatcher()
-        allwatcher.connect(self.connection)
-        while True:
-            results = await allwatcher.Next()
-            for delta in results.deltas:
-                delta = get_entity_delta(delta)
-                old_obj, new_obj = self._apply_delta(delta)
-                self._notify_observers(delta, old_obj, new_obj)
+        async def _start_watch():
+            self._watch_shutdown.clear()
+            try:
+                allwatcher = watcher.AllWatcher()
+                self._watch_conn = await self.connection.clone()
+                allwatcher.connect(self._watch_conn)
+                while True:
+                    results = await allwatcher.Next()
+                    for delta in results.deltas:
+                        delta = get_entity_delta(delta)
+                        old_obj, new_obj = self._apply_delta(delta)
+                        # XXX: Might not want to shield at this level
+                        # We are shielding because when the watcher is
+                        # canceled (on disconnect()), we don't want all of
+                        # its children (every observer callback) to be
+                        # canceled with it. So we shield them. But this means
+                        # they can *never* be canceled.
+                        await asyncio.shield(
+                            self._notify_observers(delta, old_obj, new_obj))
+                    self._watch_received.set()
+            except CancelledError:
+                log.debug('Closing watcher connection')
+                await self._watch_conn.close()
+                self._watch_shutdown.set()
+                self._watch_conn = None
+
+        log.debug('Starting watcher task')
+        self._watcher_task = self.loop.create_task(_start_watch())
+
+    def _stop_watching(self):
+        """Stop the asynchronous watch against this model.
+
+        """
+        log.debug('Stopping watcher task')
+        if self._watcher_task:
+            self._watcher_task.cancel()
 
     def _apply_delta(self, delta):
         """Apply delta to our model state and return the a copy of the
@@ -77,8 +217,8 @@ class Model(object):
         old_obj may be None if the delta is for the creation of a new object,
         e.g. a new application or unit is deployed.
 
-        new_obj may be if no object was created or updated, or if an object
-        was deleted as a result of the delta being applied.
+        new_obj may be None if no object was created or updated, or if an
+        object was deleted as a result of the delta being applied.
 
         """
         old_obj, new_obj = None, None
@@ -103,7 +243,7 @@ class Model(object):
         entity_class = delta.get_entity_class()
         return entity_class(delta.data, self)
 
-    def _notify_observers(self, delta, old_obj, new_obj):
+    async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
 
         :param delta: The raw change from the watcher
@@ -115,7 +255,7 @@ class Model(object):
 
         """
         for o in self.observers:
-            o(delta, old_obj, new_obj, self)
+            asyncio.ensure_future(o(delta, old_obj, new_obj, self))
 
     def add_machine(
             self, spec=None, constraints=None, disks=None, series=None,
@@ -147,14 +287,20 @@ class Model(object):
         pass
     add_machines = add_machine
 
-    def add_relation(self, relation1, relation2):
-        """Add a relation between two services.
+    async def add_relation(self, relation1, relation2):
+        """Add a relation between two applications.
 
-        :param str relation1: '<service>[:<relation_name>]'
-        :param str relation2: '<service>[:<relation_name>]'
+        :param str relation1: '<application>[:<relation_name>]'
+        :param str relation2: '<application>[:<relation_name>]'
 
         """
-        pass
+        app_facade = client.ApplicationFacade()
+        app_facade.connect(self.connection)
+
+        log.debug(
+            'Adding relation %s <-> %s', relation1, relation2)
+
+        return await app_facade.AddRelation([relation1, relation2])
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
@@ -262,10 +408,10 @@ class Model(object):
         """
         pass
 
-    def deploy(
+    async def deploy(
             self, entity_url, service_name=None, bind=None, budget=None,
             channel=None, config=None, constraints=None, force=False,
-            num_units=1, plan=None, resource=None, series=None, storage=None,
+            num_units=1, plan=None, resources=None, series=None, storage=None,
             to=None):
         """Deploy a new service or bundle.
 
@@ -282,7 +428,7 @@ class Model(object):
             an unsupported series
         :param int num_units: Number of units to deploy
         :param str plan: Plan under which to deploy charm
-        :param dict resource: <resource name>:<file path> pairs
+        :param dict resources: <resource name>:<file path> pairs
         :param str series: Series on which to deploy
         :param dict storage: Storage constraints TODO how do these look?
         :param str to: Placement directive, e.g.::
@@ -293,8 +439,56 @@ class Model(object):
 
             If None, a new machine is provisioned.
 
-        """
-        pass
+
+        TODO::
+
+            - entity_url must have a revision; look up latest automatically if
+              not provided by caller
+            - service_name is required; fill this in automatically if not
+              provided by caller
+            - series is required; how do we pick a default?
+
+        """
+        if constraints:
+            constraints = client.Value(**constraints)
+
+        if to:
+            placement = [
+                client.Placement(**p) for p in to
+            ]
+        else:
+            placement = []
+
+        if storage:
+            storage = {
+                k: client.Constraints(**v)
+                for k, v in storage.items()
+            }
+
+        app_facade = client.ApplicationFacade()
+        client_facade = client.ClientFacade()
+        app_facade.connect(self.connection)
+        client_facade.connect(self.connection)
+
+        log.debug(
+            'Deploying %s', entity_url)
+
+        await client_facade.AddCharm(channel, entity_url)
+        app = client.ApplicationDeploy(
+            application=service_name,
+            channel=channel,
+            charm_url=entity_url,
+            config=config,
+            constraints=constraints,
+            endpoint_bindings=bind,
+            num_units=num_units,
+            placement=placement,
+            resources=resources,
+            series=series,
+            storage=storage,
+        )
+
+        return await app_facade.Deploy([app])
 
     def destroy(self):
         """Terminate all machines and resources for this model.