Expand integration tests to use stable/edge versions of juju (#155)
[osm/N2VC.git] / juju / model.py
index 3df2669..bd8709a 100644 (file)
@@ -1,5 +1,7 @@
 import asyncio
+import base64
 import collections
+import hashlib
 import json
 import logging
 import os
@@ -12,13 +14,15 @@ from concurrent.futures import CancelledError
 from functools import partial
 from pathlib import Path
 
+import websockets
 import yaml
 import theblues.charmstore
 import theblues.errors
 
+from . import tag, utils
 from .client import client
-from .client import watcher
 from .client import connection
+from .client.client import ConfigValue
 from .constraints import parse as parse_constraints, normalize_key
 from .delta import get_entity_delta
 from .delta import get_entity_class
@@ -74,6 +78,9 @@ class _Observer(object):
 
 
 class ModelObserver(object):
+    """
+    Base class for creating observers that react to changes in a model.
+    """
     async def __call__(self, delta, old, new, model):
         handler_name = 'on_{}_{}'.format(delta.entity, delta.type)
         method = getattr(self, handler_name, self.on_change)
@@ -82,6 +89,8 @@ class ModelObserver(object):
     async def on_change(self, delta, old, new, model):
         """Generic model-change handler.
 
+        This should be overridden in a subclass.
+
         :param delta: :class:`juju.client.overrides.Delta`
         :param old: :class:`juju.model.ModelEntity`
         :param new: :class:`juju.model.ModelEntity`
@@ -230,7 +239,14 @@ class ModelEntity(object):
         model.
 
         """
-        return self.safe_data[name]
+        try:
+            return self.safe_data[name]
+        except KeyError:
+            name = name.replace('_', '-')
+            if name in self.safe_data:
+                return self.safe_data[name]
+            else:
+                raise
 
     def __bool__(self):
         return bool(self.data)
@@ -365,22 +381,39 @@ class ModelEntity(object):
 
 
 class Model(object):
-    def __init__(self, loop=None):
+    """
+    The main API for interacting with a Juju model.
+    """
+    def __init__(self, loop=None,
+                 max_frame_size=connection.Connection.DEFAULT_FRAME_SIZE):
         """Instantiate a new connected Model.
 
         :param loop: an asyncio event loop
+        :param max_frame_size: See
+            `juju.client.connection.Connection.MAX_FRAME_SIZE`
 
         """
         self.loop = loop or asyncio.get_event_loop()
+        self.max_frame_size = max_frame_size
         self.connection = None
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
         self.info = None
-        self._watcher_task = None
-        self._watch_shutdown = asyncio.Event(loop=self.loop)
+        self._watch_stopping = asyncio.Event(loop=self.loop)
+        self._watch_stopped = asyncio.Event(loop=self.loop)
         self._watch_received = asyncio.Event(loop=self.loop)
         self._charmstore = CharmStore(self.loop)
 
+    async def __aenter__(self):
+        await self.connect_current()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.disconnect()
+
+        if exc_type is not None:
+            return False
+
     async def connect(self, *args, **kw):
         """Connect to an arbitrary Juju model.
 
@@ -389,6 +422,8 @@ class Model(object):
         """
         if 'loop' not in kw:
             kw['loop'] = self.loop
+        if 'max_frame_size' not in kw:
+            kw['max_frame_size'] = self.max_frame_size
         self.connection = await connection.Connection.connect(*args, **kw)
         await self._after_connect()
 
@@ -397,7 +432,7 @@ class Model(object):
 
         """
         self.connection = await connection.Connection.connect_current(
-            self.loop)
+            self.loop, max_frame_size=self.max_frame_size)
         await self._after_connect()
 
     async def connect_model(self, model_name):
@@ -406,8 +441,8 @@ class Model(object):
         :param model_name:  Format [controller:][user/]model
 
         """
-        self.connection = await connection.Connection.connect_model(model_name,
-                                                                    self.loop)
+        self.connection = await connection.Connection.connect_model(
+            model_name, self.loop, self.max_frame_size)
         await self._after_connect()
 
     async def _after_connect(self):
@@ -422,9 +457,10 @@ class Model(object):
         """Shut down the watcher task and close websockets.
 
         """
-        self._stop_watching()
         if self.connection and self.connection.is_open:
-            await self._watch_shutdown.wait()
+            log.debug('Stopping watcher task')
+            self._watch_stopping.set()
+            await self._watch_stopped.wait()
             log.debug('Closing model connection')
             await self.connection.close()
             self.connection = None
@@ -516,6 +552,8 @@ class Model(object):
         """
         async def _block():
             while not all(c() for c in conditions):
+                if not (self.connection and self.connection.is_open):
+                    raise websockets.ConnectionClosed(1006, 'no reason')
                 await asyncio.sleep(wait_period, loop=self.loop)
         await asyncio.wait_for(_block(), timeout, loop=self.loop)
 
@@ -556,8 +594,7 @@ class Model(object):
         explicit call to this method.
 
         """
-        facade = client.ClientFacade()
-        facade.connect(self.connection)
+        facade = client.ClientFacade.from_connection(self.connection)
 
         self.info = await facade.ModelInfo()
         log.debug('Got ModelInfo: %s', vars(self.info))
@@ -610,43 +647,64 @@ class Model(object):
         See :meth:`add_observer` to register an onchange callback.
 
         """
-        async def _start_watch():
-            self._watch_shutdown.clear()
+        async def _all_watcher():
             try:
-                allwatcher = watcher.AllWatcher()
-                self._watch_conn = await self.connection.clone()
-                allwatcher.connect(self._watch_conn)
-                while True:
-                    results = await allwatcher.Next()
+                allwatcher = client.AllWatcherFacade.from_connection(
+                    self.connection)
+                while not self._watch_stopping.is_set():
+                    try:
+                        results = await utils.run_with_interrupt(
+                            allwatcher.Next(),
+                            self._watch_stopping,
+                            self.loop)
+                    except JujuAPIError as e:
+                        if 'watcher was stopped' not in str(e):
+                            raise
+                        if self._watch_stopping.is_set():
+                            # this shouldn't ever actually happen, because
+                            # the event should trigger before the controller
+                            # has a chance to tell us the watcher is stopped
+                            # but handle it gracefully, just in case
+                            break
+                        # controller stopped our watcher for some reason
+                        # but we're not actually stopping, so just restart it
+                        log.warning(
+                            'Watcher: watcher stopped, restarting')
+                        del allwatcher.Id
+                        continue
+                    except websockets.ConnectionClosed:
+                        monitor = self.connection.monitor
+                        if monitor.status == monitor.ERROR:
+                            # closed unexpectedly, try to reopen
+                            log.warning(
+                                'Watcher: connection closed, reopening')
+                            await self.connection.reconnect()
+                            del allwatcher.Id
+                            continue
+                        else:
+                            # closed on request, go ahead and shutdown
+                            break
+                    if self._watch_stopping.is_set():
+                        await allwatcher.Stop()
+                        break
                     for delta in results.deltas:
                         delta = get_entity_delta(delta)
                         old_obj, new_obj = self.state.apply_delta(delta)
-                        # XXX: Might not want to shield at this level
-                        # We are shielding because when the watcher is
-                        # canceled (on disconnect()), we don't want all of
-                        # its children (every observer callback) to be
-                        # canceled with it. So we shield them. But this means
-                        # they can *never* be canceled.
-                        await asyncio.shield(
-                            self._notify_observers(delta, old_obj, new_obj),
-                            loop=self.loop)
+                        await self._notify_observers(delta, old_obj, new_obj)
                     self._watch_received.set()
             except CancelledError:
-                log.debug('Closing watcher connection')
-                await self._watch_conn.close()
-                self._watch_shutdown.set()
-                self._watch_conn = None
+                pass
+            except Exception:
+                log.exception('Error in watcher')
+                raise
+            finally:
+                self._watch_stopped.set()
 
         log.debug('Starting watcher task')
-        self._watcher_task = self.loop.create_task(_start_watch())
-
-    def _stop_watching(self):
-        """Stop the asynchronous watch against this model.
-
-        """
-        log.debug('Stopping watcher task')
-        if self._watcher_task:
-            self._watcher_task.cancel()
+        self._watch_received.clear()
+        self._watch_stopping.clear()
+        self._watch_stopped.clear()
+        self.loop.create_task(_all_watcher())
 
     async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
@@ -744,14 +802,34 @@ class Model(object):
                 'zone=us-east-1a' - starts a machine in zone us-east-1s on AWS
                 'maas2.name' - acquire machine maas2.name on MAAS
 
-        :param dict constraints: Machine constraints
+        :param dict constraints: Machine constraints, which can contain the
+            the following keys::
+
+                arch : str
+                container : str
+                cores : int
+                cpu_power : int
+                instance_type : str
+                mem : int
+                root_disk : int
+                spaces : list(str)
+                tags : list(str)
+                virt_type : str
+
             Example::
 
                 constraints={
                     'mem': 256 * MB,
+                    'tags': ['virtual'],
                 }
 
-        :param list disks: List of disk constraint dictionaries
+        :param list disks: List of disk constraint dictionaries, which can
+            contain the following keys::
+
+                count : int
+                pool : str
+                size : int
+
             Example::
 
                 disks=[{
@@ -787,12 +865,11 @@ class Model(object):
             params.series = series
 
         # Submit the request.
-        client_facade = client.ClientFacade()
-        client_facade.connect(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection)
         results = await client_facade.AddMachines([params])
         error = results.machines[0].error
         if error:
-            raise ValueError("Error adding machine: %s", error.message)
+            raise ValueError("Error adding machine: %s" % error.message)
         machine_id = results.machines[0].machine
         log.debug('Added new machine %s', machine_id)
         return await self._wait_for_new('machine', machine_id)
@@ -804,8 +881,7 @@ class Model(object):
         :param str relation2: '<application>[:<relation_name>]'
 
         """
-        app_facade = client.ApplicationFacade()
-        app_facade.connect(self.connection)
+        app_facade = client.ApplicationFacade.from_connection(self.connection)
 
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
@@ -841,13 +917,15 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def add_ssh_key(self, key):
+    async def add_ssh_key(self, user, key):
         """Add a public SSH key to this model.
 
+        :param str user: The username of the user
         :param str key: The public ssh key
 
         """
-        raise NotImplementedError()
+        key_facade = client.KeyManagerFacade.from_connection(self.connection)
+        return await key_facade.AddKeys([key], user)
     add_ssh_keys = add_ssh_key
 
     def add_subnet(self, cidr_or_id, space, *zones):
@@ -935,6 +1013,22 @@ class Model(object):
         """
         raise NotImplementedError()
 
+    def _get_series(self, entity_url, entity):
+        # try to get the series from the provided charm URL
+        if entity_url.startswith('cs:'):
+            parts = entity_url[3:].split('/')
+        else:
+            parts = entity_url.split('/')
+        if parts[0].startswith('~'):
+            parts.pop(0)
+        if len(parts) > 1:
+            # series was specified in the URL
+            return parts[0]
+        # series was not supplied at all, so use the newest
+        # supported series according to the charm store
+        ss = entity['Meta']['supported-series']
+        return ss['SupportedSeries'][0]
+
     async def deploy(
             self, entity_url, application_name=None, bind=None, budget=None,
             channel=None, config=None, constraints=None, force=False,
@@ -947,7 +1041,7 @@ class Model(object):
         :param dict bind: <charm endpoint>:<network space> pairs
         :param dict budget: <budget name>:<limit> pairs
         :param str channel: Charm store channel from which to retrieve
-            the charm or bundle, e.g. 'development'
+            the charm or bundle, e.g. 'edge'
         :param dict config: Charm configuration dictionary
         :param constraints: Service constraints
         :type constraints: :class:`juju.Constraints`
@@ -969,9 +1063,7 @@ class Model(object):
 
         TODO::
 
-            - application_name is required; fill this in automatically if not
-              provided by caller
-            - series is required; how do we pick a default?
+            - support local resources
 
         """
         if storage:
@@ -985,13 +1077,12 @@ class Model(object):
             os.path.isdir(entity_url)
         )
         if is_local:
-            entity_id = entity_url
+            entity_id = entity_url.replace('local:', '')
         else:
-            entity = await self.charmstore.entity(entity_url)
+            entity = await self.charmstore.entity(entity_url, channel=channel)
             entity_id = entity['Id']
 
-        client_facade = client.ClientFacade()
-        client_facade.connect(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection)
 
         is_bundle = ((is_local and
                       (Path(entity_id) / 'bundle.yaml').exists()) or
@@ -1018,25 +1109,14 @@ class Model(object):
             if not is_local:
                 if not application_name:
                     application_name = entity['Meta']['charm-metadata']['Name']
-                if not series and '/' in entity_url:
-                    # try to get the series from the provided charm URL
-                    if entity_url.startswith('cs:'):
-                        parts = entity_url[3:].split('/')
-                    else:
-                        parts = entity_url.split('/')
-                    if parts[0].startswith('~'):
-                        parts.pop(0)
-                    if len(parts) > 1:
-                        # series was specified in the URL
-                        series = parts[0]
                 if not series:
-                    # series was not supplied at all, so use the newest
-                    # supported series according to the charm store
-                    ss = entity['Meta']['supported-series']
-                    series = ss['SupportedSeries'][0]
-                if not channel:
-                    channel = 'stable'
+                    series = self._get_series(entity_url, entity)
                 await client_facade.AddCharm(channel, entity_id)
+                # XXX: we're dropping local resources here, but we don't
+                # actually support them yet anyway
+                resources = await self._add_store_resources(application_name,
+                                                            entity_id,
+                                                            entity)
             else:
                 # We have a local charm dir that needs to be uploaded
                 charm_dir = os.path.abspath(
@@ -1059,9 +1139,40 @@ class Model(object):
                 storage=storage,
                 channel=channel,
                 num_units=num_units,
-                placement=parse_placement(to),
+                placement=parse_placement(to)
             )
 
+    async def _add_store_resources(self, application, entity_url, entity=None):
+        if not entity:
+            # avoid extra charm store call if one was already made
+            entity = await self.charmstore.entity(entity_url)
+        resources = [
+            {
+                'description': resource['Description'],
+                'fingerprint': resource['Fingerprint'],
+                'name': resource['Name'],
+                'path': resource['Path'],
+                'revision': resource['Revision'],
+                'size': resource['Size'],
+                'type_': resource['Type'],
+                'origin': 'store',
+            } for resource in entity['Meta']['resources']
+        ]
+
+        if not resources:
+            return None
+
+        resources_facade = client.ResourcesFacade.from_connection(
+            self.connection)
+        response = await resources_facade.AddPendingResources(
+            tag.application(application),
+            entity_url,
+            [client.CharmResource(**resource) for resource in resources])
+        resource_map = {resource['name']: pid
+                        for resource, pid
+                        in zip(resources, response.pending_ids)}
+        return resource_map
+
     async def _deploy(self, charm_url, application, series, config,
                       constraints, endpoint_bindings, resources, storage,
                       channel=None, num_units=None, placement=None):
@@ -1074,8 +1185,8 @@ class Model(object):
         config = yaml.dump({application: config},
                            default_flow_style=False)
 
-        app_facade = client.ApplicationFacade()
-        app_facade.connect(self.connection)
+        app_facade = client.ApplicationFacade.from_connection(
+            self.connection)
 
         app = client.ApplicationDeploy(
             charm_url=charm_url,
@@ -1088,7 +1199,7 @@ class Model(object):
             num_units=num_units,
             resources=resources,
             storage=storage,
-            placement=placement,
+            placement=placement
         )
 
         result = await app_facade.Deploy([app])
@@ -1097,9 +1208,9 @@ class Model(object):
             raise JujuError('\n'.join(errors))
         return await self._wait_for_new('application', application)
 
-    def destroy(self):
+    async def destroy(self):
         """Terminate all machines and resources for this model.
-
+            Is already implemented in controller.py.
         """
         raise NotImplementedError()
 
@@ -1107,8 +1218,7 @@ class Model(object):
         """Destroy units by name.
 
         """
-        app_facade = client.ApplicationFacade()
-        app_facade.connect(self.connection)
+        app_facade = client.ApplicationFacade.from_connection(self.connection)
 
         log.debug(
             'Destroying unit%s %s',
@@ -1146,11 +1256,20 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def get_config(self):
+    async def get_config(self):
         """Return the configuration settings for this model.
 
+        :returns: A ``dict`` mapping keys to `ConfigValue` instances,
+            which have `source` and `value` attributes.
         """
-        raise NotImplementedError()
+        config_facade = client.ModelConfigFacade.from_connection(
+            self.connection
+        )
+        result = await config_facade.ModelGet()
+        config = result.config
+        for key, value in config.items():
+            config[key] = ConfigValue.from_json(value)
+        return config
 
     def get_constraints(self):
         """Return the machine constraints for this model.
@@ -1158,14 +1277,21 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def grant(self, username, acl='read'):
+    async def grant(self, username, acl='read'):
         """Grant a user access to this model.
 
         :param str username: Username
         :param str acl: Access control ('read' or 'write')
 
         """
-        raise NotImplementedError()
+        controller_conn = await self.connection.controller()
+        model_facade = client.ModelManagerFacade.from_connection(
+            controller_conn)
+        user = tag.user(username)
+        model = tag.model(self.info.uuid)
+        changes = client.ModifyModelAccess(acl, 'grant', model, user)
+        await self.revoke(username)
+        return await model_facade.ModifyModelAccess([changes])
 
     def import_ssh_key(self, identity):
         """Add a public SSH key from a trusted indentity source to this model.
@@ -1176,14 +1302,11 @@ class Model(object):
         raise NotImplementedError()
     import_ssh_keys = import_ssh_key
 
-    def get_machines(self, machine, utc=False):
+    async def get_machines(self):
         """Return list of machines in this model.
 
-        :param str machine: Machine id, e.g. '0'
-        :param bool utc: Display time as UTC in RFC3339 format
-
         """
-        raise NotImplementedError()
+        return list(self.state.machines.keys())
 
     def get_shares(self):
         """Return list of all users with access to this model.
@@ -1197,11 +1320,16 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def get_ssh_key(self):
+    async def get_ssh_key(self, raw_ssh=False):
         """Return known SSH keys for this model.
+        :param bool raw_ssh: if True, returns the raw ssh key,
+            else it's fingerprint
 
         """
-        raise NotImplementedError()
+        key_facade = client.KeyManagerFacade.from_connection(self.connection)
+        entity = {'tag': tag.model(self.info.uuid)}
+        entities = client.Entities([entity])
+        return await key_facade.ListKeys(entities, raw_ssh)
     get_ssh_keys = get_ssh_key
 
     def get_storage(self, filesystem=False, volume=False):
@@ -1264,13 +1392,18 @@ class Model(object):
         raise NotImplementedError()
     remove_machines = remove_machine
 
-    def remove_ssh_key(self, *keys):
+    async def remove_ssh_key(self, user, key):
         """Remove a public SSH key(s) from this model.
 
-        :param str \*keys: Keys to remove
+        :param str key: Full ssh key
+        :param str user: Juju user to which the key is registered
 
         """
-        raise NotImplementedError()
+        key_facade = client.KeyManagerFacade.from_connection(self.connection)
+        key = base64.b64decode(bytes(key.strip().split()[1].encode('ascii')))
+        key = hashlib.md5(key).hexdigest()
+        key = ':'.join(a+b for a, b in zip(key[::2], key[1::2]))
+        await key_facade.DeleteKeys([key], user)
     remove_ssh_keys = remove_ssh_key
 
     def restore_backup(
@@ -1294,14 +1427,19 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def revoke(self, username, acl='read'):
+    async def revoke(self, username):
         """Revoke a user's access to this model.
 
         :param str username: Username to revoke
-        :param str acl: Access control ('read' or 'write')
 
         """
-        raise NotImplementedError()
+        controller_conn = await self.connection.controller()
+        model_facade = client.ModelManagerFacade.from_connection(
+            controller_conn)
+        user = tag.user(username)
+        model = tag.model(self.info.uuid)
+        changes = client.ModifyModelAccess('read', 'revoke', model, user)
+        return await model_facade.ModifyModelAccess([changes])
 
     def run(self, command, timeout=None):
         """Run command on all machines in this model.
@@ -1312,13 +1450,19 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def set_config(self, **config):
+    async def set_config(self, config):
         """Set configuration keys on this model.
 
-        :param \*\*config: Config key/values
-
+        :param dict config: Mapping of config keys to either string values or
+            `ConfigValue` instances, as returned by `get_config`.
         """
-        raise NotImplementedError()
+        config_facade = client.ModelConfigFacade.from_connection(
+            self.connection
+        )
+        for key, value in config.items():
+            if isinstance(value, ConfigValue):
+                config[key] = value.value
+        await config_facade.ModelSet(config)
 
     def set_constraints(self, constraints):
         """Set machine constraints on this model.
@@ -1362,8 +1506,7 @@ class Model(object):
         :param bool utc: Display time as UTC in RFC3339 format
 
         """
-        client_facade = client.ClientFacade()
-        client_facade.connect(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection)
         return await client_facade.FullStatus(filters)
 
     def sync_tools(
@@ -1443,8 +1586,8 @@ class Model(object):
         log.debug("Retrieving metrics for %s",
                   ', '.join(tags) if tags else "all units")
 
-        metrics_facade = client.MetricsDebugFacade()
-        metrics_facade.connect(self.connection)
+        metrics_facade = client.MetricsDebugFacade.from_connection(
+            self.connection)
 
         entities = [client.Entity(tag) for tag in tags]
         metrics_result = await metrics_facade.GetMetrics(entities)
@@ -1494,12 +1637,12 @@ class BundleHandler(object):
         for unit_name, unit in model.units.items():
             app_units = self._units_by_app.setdefault(unit.application, [])
             app_units.append(unit_name)
-        self.client_facade = client.ClientFacade()
-        self.client_facade.connect(model.connection)
-        self.app_facade = client.ApplicationFacade()
-        self.app_facade.connect(model.connection)
-        self.ann_facade = client.AnnotationsFacade()
-        self.ann_facade.connect(model.connection)
+        self.client_facade = client.ClientFacade.from_connection(
+            model.connection)
+        self.app_facade = client.ApplicationFacade.from_connection(
+            model.connection)
+        self.ann_facade = client.AnnotationsFacade.from_connection(
+            model.connection)
 
     async def _handle_local_charms(self, bundle):
         """Search for references to local charms (i.e. filesystem paths)
@@ -1563,9 +1706,6 @@ class BundleHandler(object):
         self.plan = await self.client_facade.GetBundleChanges(
             yaml.dump(self.bundle))
 
-        if self.plan.errors:
-            raise JujuError('\n'.join(self.plan.errors))
-
     async def execute_plan(self):
         for step in self.plan.changes:
             method = getattr(self, step.method)
@@ -1645,7 +1785,7 @@ class BundleHandler(object):
         results = await self.client_facade.AddMachines([params])
         error = results.machines[0].error
         if error:
-            raise ValueError("Error adding machine: %s", error.message)
+            raise ValueError("Error adding machine: %s" % error.message)
         machine = results.machines[0].machine
         log.debug('Added new machine %s', machine)
         return machine
@@ -1701,6 +1841,11 @@ class BundleHandler(object):
         """
         # resolve indirect references
         charm = self.resolve(charm)
+        # the bundle plan doesn't actually do anything with resources, even
+        # though it ostensibly gives us something (None) for that param
+        if not charm.startswith('local:'):
+            resources = await self.model._add_store_resources(application,
+                                                              charm)
         await self.model._deploy(
             charm_url=charm,
             application=application,
@@ -1805,6 +1950,12 @@ class CharmStore(object):
 
 
 class CharmArchiveGenerator(object):
+    """
+    Create a Zip archive of a local charm directory for upload to a controller.
+
+    This is used automatically by
+    `Model.add_local_charm_dir <#juju.model.Model.add_local_charm_dir>`_.
+    """
     def __init__(self, path):
         self.path = os.path.abspath(os.path.expanduser(path))