Add relate example
[osm/N2VC.git] / juju / model.py
index e56bfb4..04f3437 100644 (file)
@@ -1,7 +1,10 @@
+import asyncio
 import logging
+from concurrent.futures import CancelledError
 
 from .client import client
 from .client import watcher
+from .client import connection
 from .delta import get_entity_delta
 
 log = logging.getLogger(__name__)
@@ -44,20 +47,63 @@ class ModelEntity(object):
 
 
 class Model(object):
-    def __init__(self, connection):
+    def __init__(self, loop=None):
         """Instantiate a new connected Model.
 
-        :param connection: `juju.client.connection.Connection` instance
+        :param loop: an asyncio event loop
 
         """
-        self.connection = connection
+        self.loop = loop or asyncio.get_event_loop()
+        self.connection = None
         self.observers = set()
         self.state = dict()
+        self._watcher_task = None
+        self._watch_shutdown = asyncio.Event(loop=loop)
+        self._watch_received = asyncio.Event(loop=loop)
+
+    async def connect_current(self):
+        self.connection = await connection.Connection.connect_current()
+        self._watch()
+        await self._watch_received.wait()
+
+    async def disconnect(self):
+        self._stop_watching()
+        if self.connection and self.connection.is_open:
+            await self._watch_shutdown.wait()
+            log.debug('Closing model connection')
+            await asyncio.wait_for(self.connection.close(), None)
+            self.connection = None
+
+    def all_units_idle(self):
+        """Return True if all units are idle.
+
+        """
+        for unit in self.units.values():
+            unit_status = unit.data['agent-status']['current']
+            if unit_status != 'idle':
+                return False
+        return True
+
+    async def reset(self, force=False):
+        for app in self.applications.values():
+            await app.destroy()
+        for machine in self.machines.values():
+            await machine.destroy(force=force)
+
+    async def block_until(self, func):
+        async def _block():
+            while not func():
+                await asyncio.sleep(.1)
+        await asyncio.wait_for(_block(), None)
 
     @property
     def applications(self):
         return self.state.get('application', {})
 
+    @property
+    def machines(self):
+        return self.state.get('machine', {})
+
     @property
     def units(self):
         return self.state.get('unit', {})
@@ -87,21 +133,41 @@ class Model(object):
         """
         self.observers.add(callable_)
 
-    async def watch(self):
+    def _watch(self):
         """Start an asynchronous watch against this model.
 
         See :meth:`add_observer` to register an onchange callback.
 
         """
-        self._watching = True
-        allwatcher = watcher.AllWatcher()
-        allwatcher.connect(await self.connection.clone())
-        while True:
-            results = await allwatcher.Next()
-            for delta in results.deltas:
-                delta = get_entity_delta(delta)
-                old_obj, new_obj = self._apply_delta(delta)
-                self._notify_observers(delta, old_obj, new_obj)
+        async def _start_watch():
+            self._watch_shutdown.clear()
+            try:
+                allwatcher = watcher.AllWatcher()
+                self._watch_conn = await self.connection.clone()
+                allwatcher.connect(self._watch_conn)
+                while True:
+                    results = await allwatcher.Next()
+                    for delta in results.deltas:
+                        delta = get_entity_delta(delta)
+                        old_obj, new_obj = self._apply_delta(delta)
+                        self._notify_observers(delta, old_obj, new_obj)
+                    self._watch_received.set()
+            except CancelledError:
+                log.debug('Closing watcher connection')
+                await asyncio.wait_for(self._watch_conn.close(), None)
+                self._watch_shutdown.set()
+                self._watch_conn = None
+
+        log.debug('Starting watcher task')
+        self._watcher_task = self.loop.create_task(_start_watch())
+
+    def _stop_watching(self):
+        """Stop the asynchronous watch against this model.
+
+        """
+        log.debug('Stopping watcher task')
+        if self._watcher_task:
+            self._watcher_task.cancel()
 
     def _apply_delta(self, delta):
         """Apply delta to our model state and return the a copy of the
@@ -182,14 +248,20 @@ class Model(object):
         pass
     add_machines = add_machine
 
-    def add_relation(self, relation1, relation2):
-        """Add a relation between two services.
+    async def add_relation(self, relation1, relation2):
+        """Add a relation between two applications.
 
-        :param str relation1: '<service>[:<relation_name>]'
-        :param str relation2: '<service>[:<relation_name>]'
+        :param str relation1: '<application>[:<relation_name>]'
+        :param str relation2: '<application>[:<relation_name>]'
 
         """
-        pass
+        app_facade = client.ApplicationFacade()
+        app_facade.connect(self.connection)
+
+        log.debug(
+            'Adding relation %s <-> %s', relation1, relation2)
+
+        return await app_facade.AddRelation([relation1, relation2])
 
     def add_space(self, name, *cidrs):
         """Add a new network space.