Bump rev for release
[osm/N2VC.git] / juju / model.py
index c15d07c..4db711b 100644 (file)
@@ -18,9 +18,8 @@ import yaml
 import theblues.charmstore
 import theblues.errors
 
-from . import tag
+from . import tag, utils
 from .client import client
-from .client import watcher
 from .client import connection
 from .constraints import parse as parse_constraints, normalize_key
 from .delta import get_entity_delta
@@ -77,6 +76,9 @@ class _Observer(object):
 
 
 class ModelObserver(object):
+    """
+    Base class for creating observers that react to changes in a model.
+    """
     async def __call__(self, delta, old, new, model):
         handler_name = 'on_{}_{}'.format(delta.entity, delta.type)
         method = getattr(self, handler_name, self.on_change)
@@ -85,6 +87,8 @@ class ModelObserver(object):
     async def on_change(self, delta, old, new, model):
         """Generic model-change handler.
 
+        This should be overridden in a subclass.
+
         :param delta: :class:`juju.client.overrides.Delta`
         :param old: :class:`juju.model.ModelEntity`
         :param new: :class:`juju.model.ModelEntity`
@@ -375,6 +379,9 @@ class ModelEntity(object):
 
 
 class Model(object):
+    """
+    The main API for interacting with a Juju model.
+    """
     def __init__(self, loop=None):
         """Instantiate a new connected Model.
 
@@ -386,11 +393,21 @@ class Model(object):
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
         self.info = None
-        self._watcher_task = None
-        self._watch_shutdown = asyncio.Event(loop=self.loop)
+        self._watch_stopping = asyncio.Event(loop=self.loop)
+        self._watch_stopped = asyncio.Event(loop=self.loop)
         self._watch_received = asyncio.Event(loop=self.loop)
         self._charmstore = CharmStore(self.loop)
 
+    async def __aenter__(self):
+        await self.connect_current()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.disconnect()
+
+        if exc_type is not None:
+            return False
+
     async def connect(self, *args, **kw):
         """Connect to an arbitrary Juju model.
 
@@ -432,9 +449,10 @@ class Model(object):
         """Shut down the watcher task and close websockets.
 
         """
-        self._stop_watching()
         if self.connection and self.connection.is_open:
-            await self._watch_shutdown.wait()
+            log.debug('Stopping watcher task')
+            self._watch_stopping.set()
+            await self._watch_stopped.wait()
             log.debug('Closing model connection')
             await self.connection.close()
             self.connection = None
@@ -566,8 +584,7 @@ class Model(object):
         explicit call to this method.
 
         """
-        facade = client.ClientFacade()
-        facade.connect(self.connection)
+        facade = client.ClientFacade.from_connection(self.connection)
 
         self.info = await facade.ModelInfo()
         log.debug('Got ModelInfo: %s', vars(self.info))
@@ -621,42 +638,34 @@ class Model(object):
 
         """
         async def _start_watch():
-            self._watch_shutdown.clear()
             try:
-                allwatcher = watcher.AllWatcher()
-                self._watch_conn = await self.connection.clone()
-                allwatcher.connect(self._watch_conn)
-                while True:
-                    results = await allwatcher.Next()
+                allwatcher = client.AllWatcherFacade.from_connection(
+                    self.connection)
+                while not self._watch_stopping.is_set():
+                    results = await utils.run_with_interrupt(
+                        allwatcher.Next(),
+                        self._watch_stopping,
+                        self.loop)
+                    if self._watch_stopping.is_set():
+                        break
                     for delta in results.deltas:
                         delta = get_entity_delta(delta)
                         old_obj, new_obj = self.state.apply_delta(delta)
-                        # XXX: Might not want to shield at this level
-                        # We are shielding because when the watcher is
-                        # canceled (on disconnect()), we don't want all of
-                        # its children (every observer callback) to be
-                        # canceled with it. So we shield them. But this means
-                        # they can *never* be canceled.
-                        await asyncio.shield(
-                            self._notify_observers(delta, old_obj, new_obj),
-                            loop=self.loop)
+                        await self._notify_observers(delta, old_obj, new_obj)
                     self._watch_received.set()
             except CancelledError:
-                log.debug('Closing watcher connection')
-                await self._watch_conn.close()
-                self._watch_shutdown.set()
-                self._watch_conn = None
+                pass
+            except Exception:
+                log.exception('Error in watcher')
+                raise
+            finally:
+                self._watch_stopped.set()
 
         log.debug('Starting watcher task')
-        self._watcher_task = self.loop.create_task(_start_watch())
-
-    def _stop_watching(self):
-        """Stop the asynchronous watch against this model.
-
-        """
-        log.debug('Stopping watcher task')
-        if self._watcher_task:
-            self._watcher_task.cancel()
+        self._watch_received.clear()
+        self._watch_stopping.clear()
+        self._watch_stopped.clear()
+        self.loop.create_task(_start_watch())
 
     async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
@@ -754,14 +763,34 @@ class Model(object):
                 'zone=us-east-1a' - starts a machine in zone us-east-1s on AWS
                 'maas2.name' - acquire machine maas2.name on MAAS
 
-        :param dict constraints: Machine constraints
+        :param dict constraints: Machine constraints, which can contain the
+            the following keys::
+
+                arch : str
+                container : str
+                cores : int
+                cpu_power : int
+                instance_type : str
+                mem : int
+                root_disk : int
+                spaces : list(str)
+                tags : list(str)
+                virt_type : str
+
             Example::
 
                 constraints={
                     'mem': 256 * MB,
+                    'tags': ['virtual'],
                 }
 
-        :param list disks: List of disk constraint dictionaries
+        :param list disks: List of disk constraint dictionaries, which can
+            contain the following keys::
+
+                count : int
+                pool : str
+                size : int
+
             Example::
 
                 disks=[{
@@ -797,8 +826,7 @@ class Model(object):
             params.series = series
 
         # Submit the request.
-        client_facade = client.ClientFacade()
-        client_facade.connect(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection)
         results = await client_facade.AddMachines([params])
         error = results.machines[0].error
         if error:
@@ -814,8 +842,7 @@ class Model(object):
         :param str relation2: '<application>[:<relation_name>]'
 
         """
-        app_facade = client.ApplicationFacade()
-        app_facade.connect(self.connection)
+        app_facade = client.ApplicationFacade.from_connection(self.connection)
 
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
@@ -858,8 +885,7 @@ class Model(object):
         :param str key: The public ssh key
 
         """
-        key_facade = client.KeyManagerFacade()
-        key_facade.connect(self.connection)
+        key_facade = client.KeyManagerFacade.from_connection(self.connection)
         return await key_facade.AddKeys([key], user)
     add_ssh_keys = add_ssh_key
 
@@ -976,7 +1002,7 @@ class Model(object):
         :param dict bind: <charm endpoint>:<network space> pairs
         :param dict budget: <budget name>:<limit> pairs
         :param str channel: Charm store channel from which to retrieve
-            the charm or bundle, e.g. 'development'
+            the charm or bundle, e.g. 'edge'
         :param dict config: Charm configuration dictionary
         :param constraints: Service constraints
         :type constraints: :class:`juju.Constraints`
@@ -1012,13 +1038,12 @@ class Model(object):
             os.path.isdir(entity_url)
         )
         if is_local:
-            entity_id = entity_url
+            entity_id = entity_url.replace('local:', '')
         else:
-            entity = await self.charmstore.entity(entity_url)
+            entity = await self.charmstore.entity(entity_url, channel=channel)
             entity_id = entity['Id']
 
-        client_facade = client.ClientFacade()
-        client_facade.connect(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection)
 
         is_bundle = ((is_local and
                       (Path(entity_id) / 'bundle.yaml').exists()) or
@@ -1047,8 +1072,6 @@ class Model(object):
                     application_name = entity['Meta']['charm-metadata']['Name']
                 if not series:
                     series = self._get_series(entity_url, entity)
-                if not channel:
-                    channel = 'stable'
                 await client_facade.AddCharm(channel, entity_id)
                 # XXX: we're dropping local resources here, but we don't
                 # actually support them yet anyway
@@ -1100,8 +1123,8 @@ class Model(object):
         if not resources:
             return None
 
-        resources_facade = client.ResourcesFacade()
-        resources_facade.connect(self.connection)
+        resources_facade = client.ResourcesFacade.from_connection(
+            self.connection)
         response = await resources_facade.AddPendingResources(
             tag.application(application),
             entity_url,
@@ -1123,8 +1146,8 @@ class Model(object):
         config = yaml.dump({application: config},
                            default_flow_style=False)
 
-        app_facade = client.ApplicationFacade()
-        app_facade.connect(self.connection)
+        app_facade = client.ApplicationFacade.from_connection(
+            self.connection)
 
         app = client.ApplicationDeploy(
             charm_url=charm_url,
@@ -1156,8 +1179,7 @@ class Model(object):
         """Destroy units by name.
 
         """
-        app_facade = client.ApplicationFacade()
-        app_facade.connect(self.connection)
+        app_facade = client.ApplicationFacade.from_connection(self.connection)
 
         log.debug(
             'Destroying unit%s %s',
@@ -1214,9 +1236,9 @@ class Model(object):
         :param str acl: Access control ('read' or 'write')
 
         """
-        model_facade = client.ModelManagerFacade()
         controller_conn = await self.connection.controller()
-        model_facade.connect(controller_conn)
+        model_facade = client.ModelManagerFacade.from_connection(
+            controller_conn)
         user = tag.user(username)
         model = tag.model(self.info.uuid)
         changes = client.ModifyModelAccess(acl, 'grant', model, user)
@@ -1256,8 +1278,7 @@ class Model(object):
             else it's fingerprint
 
         """
-        key_facade = client.KeyManagerFacade()
-        key_facade.connect(self.connection)
+        key_facade = client.KeyManagerFacade.from_connection(self.connection)
         entity = {'tag': tag.model(self.info.uuid)}
         entities = client.Entities([entity])
         return await key_facade.ListKeys(entities, raw_ssh)
@@ -1330,8 +1351,7 @@ class Model(object):
         :param str user: Juju user to which the key is registered
 
         """
-        key_facade = client.KeyManagerFacade()
-        key_facade.connect(self.connection)
+        key_facade = client.KeyManagerFacade.from_connection(self.connection)
         key = base64.b64decode(bytes(key.strip().split()[1].encode('ascii')))
         key = hashlib.md5(key).hexdigest()
         key = ':'.join(a+b for a, b in zip(key[::2], key[1::2]))
@@ -1365,9 +1385,9 @@ class Model(object):
         :param str username: Username to revoke
 
         """
-        model_facade = client.ModelManagerFacade()
         controller_conn = await self.connection.controller()
-        model_facade.connect(controller_conn)
+        model_facade = client.ModelManagerFacade.from_connection(
+            controller_conn)
         user = tag.user(username)
         model = tag.model(self.info.uuid)
         changes = client.ModifyModelAccess('read', 'revoke', model, user)
@@ -1432,8 +1452,7 @@ class Model(object):
         :param bool utc: Display time as UTC in RFC3339 format
 
         """
-        client_facade = client.ClientFacade()
-        client_facade.connect(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection)
         return await client_facade.FullStatus(filters)
 
     def sync_tools(
@@ -1513,8 +1532,8 @@ class Model(object):
         log.debug("Retrieving metrics for %s",
                   ', '.join(tags) if tags else "all units")
 
-        metrics_facade = client.MetricsDebugFacade()
-        metrics_facade.connect(self.connection)
+        metrics_facade = client.MetricsDebugFacade.from_connection(
+            self.connection)
 
         entities = [client.Entity(tag) for tag in tags]
         metrics_result = await metrics_facade.GetMetrics(entities)
@@ -1564,12 +1583,12 @@ class BundleHandler(object):
         for unit_name, unit in model.units.items():
             app_units = self._units_by_app.setdefault(unit.application, [])
             app_units.append(unit_name)
-        self.client_facade = client.ClientFacade()
-        self.client_facade.connect(model.connection)
-        self.app_facade = client.ApplicationFacade()
-        self.app_facade.connect(model.connection)
-        self.ann_facade = client.AnnotationsFacade()
-        self.ann_facade.connect(model.connection)
+        self.client_facade = client.ClientFacade.from_connection(
+            model.connection)
+        self.app_facade = client.ApplicationFacade.from_connection(
+            model.connection)
+        self.ann_facade = client.AnnotationsFacade.from_connection(
+            model.connection)
 
     async def _handle_local_charms(self, bundle):
         """Search for references to local charms (i.e. filesystem paths)
@@ -1877,6 +1896,12 @@ class CharmStore(object):
 
 
 class CharmArchiveGenerator(object):
+    """
+    Create a Zip archive of a local charm directory for upload to a controller.
+
+    This is used automatically by
+    `Model.add_local_charm_dir <#juju.model.Model.add_local_charm_dir>`_.
+    """
     def __init__(self, path):
         self.path = os.path.abspath(os.path.expanduser(path))