Improved Primitive support and better testing
[osm/N2VC.git] / modules / libjuju / juju / model.py
index bd8709a..37e8cd6 100644 (file)
@@ -14,26 +14,28 @@ from concurrent.futures import CancelledError
 from functools import partial
 from pathlib import Path
 
 from functools import partial
 from pathlib import Path
 
-import websockets
-import yaml
 import theblues.charmstore
 import theblues.errors
 import theblues.charmstore
 import theblues.errors
+import websockets
+import yaml
 
 from . import tag, utils
 
 from . import tag, utils
-from .client import client
-from .client import connection
+from .client import client, connector
 from .client.client import ConfigValue
 from .client.client import ConfigValue
-from .constraints import parse as parse_constraints, normalize_key
-from .delta import get_entity_delta
-from .delta import get_entity_class
+from .client.client import Value
+from .constraints import parse as parse_constraints
+from .constraints import normalize_key
+from .delta import get_entity_class, get_entity_delta
+from .errors import JujuAPIError, JujuError
 from .exceptions import DeadEntityException
 from .exceptions import DeadEntityException
-from .errors import JujuError, JujuAPIError
 from .placement import parse as parse_placement
 from .placement import parse as parse_placement
+from . import provisioner
+
 
 log = logging.getLogger(__name__)
 
 
 
 log = logging.getLogger(__name__)
 
 
-class _Observer(object):
+class _Observer:
     """Wrapper around an observer callable.
 
     This wrapper allows filter criteria to be associated with the
     """Wrapper around an observer callable.
 
     This wrapper allows filter criteria to be associated with the
@@ -77,7 +79,7 @@ class _Observer(object):
         return True
 
 
         return True
 
 
-class ModelObserver(object):
+class ModelObserver:
     """
     Base class for creating observers that react to changes in a model.
     """
     """
     Base class for creating observers that react to changes in a model.
     """
@@ -100,7 +102,7 @@ class ModelObserver(object):
         pass
 
 
         pass
 
 
-class ModelState(object):
+class ModelState:
     """Holds the state of the model, including the delta history of all
     entities in the model.
 
     """Holds the state of the model, including the delta history of all
     entities in the model.
 
@@ -144,6 +146,14 @@ class ModelState(object):
         """
         return self._live_entity_map('unit')
 
         """
         return self._live_entity_map('unit')
 
+    @property
+    def relations(self):
+        """Return a map of relation-id:Relation for all relations currently in
+        the model.
+
+        """
+        return self._live_entity_map('relation')
+
     def entity_history(self, entity_type, entity_id):
         """Return the history deque for an entity.
 
     def entity_history(self, entity_type, entity_id):
         """Return the history deque for an entity.
 
@@ -209,7 +219,7 @@ class ModelState(object):
             connected=connected)
 
 
             connected=connected)
 
 
-class ModelEntity(object):
+class ModelEntity:
     """An object in the Model tree"""
 
     def __init__(self, entity_id, model, history_index=-1, connected=True):
     """An object in the Model tree"""
 
     def __init__(self, entity_id, model, history_index=-1, connected=True):
@@ -228,7 +238,7 @@ class ModelEntity(object):
         self.model = model
         self._history_index = history_index
         self.connected = connected
         self.model = model
         self._history_index = history_index
         self.connected = connected
-        self.connection = model.connection
+        self.connection = model.connection()
 
     def __repr__(self):
         return '<{} entity_id="{}">'.format(type(self).__name__,
 
     def __repr__(self):
         return '<{} entity_id="{}">'.format(type(self).__name__,
@@ -380,90 +390,207 @@ class ModelEntity(object):
         return self.model.state.get_entity(self.entity_type, self.entity_id)
 
 
         return self.model.state.get_entity(self.entity_type, self.entity_id)
 
 
-class Model(object):
+class Model:
     """
     The main API for interacting with a Juju model.
     """
     """
     The main API for interacting with a Juju model.
     """
-    def __init__(self, loop=None,
-                 max_frame_size=connection.Connection.DEFAULT_FRAME_SIZE):
-        """Instantiate a new connected Model.
+    def __init__(
+        self,
+        loop=None,
+        max_frame_size=None,
+        bakery_client=None,
+        jujudata=None,
+    ):
+        """Instantiate a new Model.
+
+        The connect method will need to be called before this
+        object can be used for anything interesting.
+
+        If jujudata is None, jujudata.FileJujuData will be used.
 
         :param loop: an asyncio event loop
         :param max_frame_size: See
             `juju.client.connection.Connection.MAX_FRAME_SIZE`
 
         :param loop: an asyncio event loop
         :param max_frame_size: See
             `juju.client.connection.Connection.MAX_FRAME_SIZE`
+        :param bakery_client httpbakery.Client: The bakery client to use
+            for macaroon authorization.
+        :param jujudata JujuData: The source for current controller information
+        """
+        self._connector = connector.Connector(
+            loop=loop,
+            max_frame_size=max_frame_size,
+            bakery_client=bakery_client,
+            jujudata=jujudata,
+        )
+        self._observers = weakref.WeakValueDictionary()
+        self.state = ModelState(self)
+        self._info = None
+        self._watch_stopping = asyncio.Event(loop=self._connector.loop)
+        self._watch_stopped = asyncio.Event(loop=self._connector.loop)
+        self._watch_received = asyncio.Event(loop=self._connector.loop)
+        self._watch_stopped.set()
+        self._charmstore = CharmStore(self._connector.loop)
+
+    def is_connected(self):
+        """Reports whether the Model is currently connected."""
+        return self._connector.is_connected()
+
+    @property
+    def loop(self):
+        return self._connector.loop
+
+    def connection(self):
+        """Return the current Connection object. It raises an exception
+        if the Model is disconnected"""
+        return self._connector.connection()
 
 
+    async def get_controller(self):
+        """Return a Controller instance for the currently connected model.
+        :return Controller:
         """
         """
-        self.loop = loop or asyncio.get_event_loop()
-        self.max_frame_size = max_frame_size
-        self.connection = None
-        self.observers = weakref.WeakValueDictionary()
-        self.state = ModelState(self)
-        self.info = None
-        self._watch_stopping = asyncio.Event(loop=self.loop)
-        self._watch_stopped = asyncio.Event(loop=self.loop)
-        self._watch_received = asyncio.Event(loop=self.loop)
-        self._charmstore = CharmStore(self.loop)
+        from juju.controller import Controller
+        controller = Controller(jujudata=self._connector.jujudata)
+        kwargs = self.connection().connect_params()
+        kwargs.pop('uuid')
+        await controller._connect_direct(**kwargs)
+        return controller
 
     async def __aenter__(self):
 
     async def __aenter__(self):
-        await self.connect_current()
+        await self.connect()
         return self
 
     async def __aexit__(self, exc_type, exc, tb):
         await self.disconnect()
 
         return self
 
     async def __aexit__(self, exc_type, exc, tb):
         await self.disconnect()
 
-        if exc_type is not None:
-            return False
+    async def connect(self, *args, **kwargs):
+        """Connect to a juju model.
 
 
-    async def connect(self, *args, **kw):
-        """Connect to an arbitrary Juju model.
+        This supports two calling conventions:
 
 
-        args and kw are passed through to Connection.connect()
+        The model and (optionally) authentication information can be taken
+        from the data files created by the Juju CLI.  This convention will
+        be used if a ``model_name`` is specified, or if the ``endpoint``
+        and ``uuid`` are not.
 
 
-        """
-        if 'loop' not in kw:
-            kw['loop'] = self.loop
-        if 'max_frame_size' not in kw:
-            kw['max_frame_size'] = self.max_frame_size
-        self.connection = await connection.Connection.connect(*args, **kw)
-        await self._after_connect()
+        Otherwise, all of the ``endpoint``, ``uuid``, and authentication
+        information (``username`` and ``password``, or ``bakery_client`` and/or
+        ``macaroons``) are required.
 
 
-    async def connect_current(self):
-        """Connect to the current Juju model.
+        If a single positional argument is given, it will be assumed to be
+        the ``model_name``.  Otherwise, the first positional argument, if any,
+        must be the ``endpoint``.
 
 
+        Available parameters are:
+
+        :param model_name:  Format [controller:][user/]model
+        :param str endpoint: The hostname:port of the controller to connect to.
+        :param str uuid: The model UUID to connect to.
+        :param str username: The username for controller-local users (or None
+            to use macaroon-based login.)
+        :param str password: The password for controller-local users.
+        :param str cacert: The CA certificate of the controller
+            (PEM formatted).
+        :param httpbakery.Client bakery_client: The macaroon bakery client to
+            to use when performing macaroon-based login. Macaroon tokens
+            acquired when logging will be saved to bakery_client.cookies.
+            If this is None, a default bakery_client will be used.
+        :param list macaroons: List of macaroons to load into the
+            ``bakery_client``.
+        :param asyncio.BaseEventLoop loop: The event loop to use for async
+            operations.
+        :param int max_frame_size: The maximum websocket frame size to allow.
         """
         """
-        self.connection = await connection.Connection.connect_current(
-            self.loop, max_frame_size=self.max_frame_size)
+        await self.disconnect()
+        if 'endpoint' not in kwargs and len(args) < 2:
+            if args and 'model_name' in kwargs:
+                raise TypeError('connect() got multiple values for model_name')
+            elif args:
+                model_name = args[0]
+            else:
+                model_name = kwargs.pop('model_name', None)
+            await self._connector.connect_model(model_name, **kwargs)
+        else:
+            if 'model_name' in kwargs:
+                raise TypeError('connect() got values for both '
+                                'model_name and endpoint')
+            if args and 'endpoint' in kwargs:
+                raise TypeError('connect() got multiple values for endpoint')
+            if len(args) < 2 and 'uuid' not in kwargs:
+                raise TypeError('connect() missing value for uuid')
+            has_userpass = (len(args) >= 4 or
+                            {'username', 'password'}.issubset(kwargs))
+            has_macaroons = (len(args) >= 6 or not
+                             {'bakery_client', 'macaroons'}.isdisjoint(kwargs))
+            if not (has_userpass or has_macaroons):
+                raise TypeError('connect() missing auth params')
+            arg_names = [
+                'endpoint',
+                'uuid',
+                'username',
+                'password',
+                'cacert',
+                'bakery_client',
+                'macaroons',
+                'loop',
+                'max_frame_size',
+            ]
+            for i, arg in enumerate(args):
+                kwargs[arg_names[i]] = arg
+            if not {'endpoint', 'uuid'}.issubset(kwargs):
+                raise ValueError('endpoint and uuid are required '
+                                 'if model_name not given')
+            if not ({'username', 'password'}.issubset(kwargs) or
+                    {'bakery_client', 'macaroons'}.intersection(kwargs)):
+                raise ValueError('Authentication parameters are required '
+                                 'if model_name not given')
+            await self._connector.connect(**kwargs)
         await self._after_connect()
 
     async def connect_model(self, model_name):
         await self._after_connect()
 
     async def connect_model(self, model_name):
-        """Connect to a specific Juju model by name.
-
-        :param model_name:  Format [controller:][user/]model
+        """
+        .. deprecated:: 0.6.2
+           Use ``connect(model_name=model_name)`` instead.
+        """
+        return await self.connect(model_name=model_name)
 
 
+    async def connect_current(self):
+        """
+        .. deprecated:: 0.6.2
+           Use ``connect()`` instead.
         """
         """
-        self.connection = await connection.Connection.connect_model(
-            model_name, self.loop, self.max_frame_size)
+        return await self.connect()
+
+    async def _connect_direct(self, **kwargs):
+        await self.disconnect()
+        await self._connector.connect(**kwargs)
         await self._after_connect()
 
     async def _after_connect(self):
         await self._after_connect()
 
     async def _after_connect(self):
-        """Run initialization steps after connecting to websocket.
-
-        """
         self._watch()
         self._watch()
+
+        # Wait for the first packet of data from the AllWatcher,
+        # which contains all information on the model.
+        # TODO this means that we can't do anything until
+        # we've received all the model data, which might be
+        # a whole load of unneeded data if all the client wants
+        # to do is make one RPC call.
         await self._watch_received.wait()
         await self._watch_received.wait()
+
         await self.get_info()
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
         """
         await self.get_info()
 
     async def disconnect(self):
         """Shut down the watcher task and close websockets.
 
         """
-        if self.connection and self.connection.is_open:
+        if not self._watch_stopped.is_set():
             log.debug('Stopping watcher task')
             self._watch_stopping.set()
             await self._watch_stopped.wait()
             log.debug('Stopping watcher task')
             self._watch_stopping.set()
             await self._watch_stopped.wait()
+            self._watch_stopping.clear()
+
+        if self.is_connected():
             log.debug('Closing model connection')
             log.debug('Closing model connection')
-            await self.connection.close()
-            self.connection = None
+            await self._connector.disconnect()
+            self._info = None
 
     async def add_local_charm_dir(self, charm_dir, series):
         """Upload a local charm to the model.
 
     async def add_local_charm_dir(self, charm_dir, series):
         """Upload a local charm to the model.
@@ -480,7 +607,7 @@ class Model(object):
         with fh:
             func = partial(
                 self.add_local_charm, fh, series, os.stat(fh.name).st_size)
         with fh:
             func = partial(
                 self.add_local_charm, fh, series, os.stat(fh.name).st_size)
-            charm_url = await self.loop.run_in_executor(None, func)
+            charm_url = await self._connector.loop.run_in_executor(None, func)
 
         log.debug('Uploaded local charm: %s -> %s', charm_dir, charm_url)
         return charm_url
 
         log.debug('Uploaded local charm: %s -> %s', charm_dir, charm_url)
         return charm_url
@@ -505,7 +632,7 @@ class Model(object):
            instead.
 
         """
            instead.
 
         """
-        conn, headers, path_prefix = self.connection.https_connection()
+        conn, headers, path_prefix = self.connection().https_connection()
         path = "%s/charms?series=%s" % (path_prefix, series)
         headers['Content-Type'] = 'application/zip'
         if size:
         path = "%s/charms?series=%s" % (path_prefix, series)
         headers['Content-Type'] = 'application/zip'
         if size:
@@ -549,13 +676,20 @@ class Model(object):
     async def block_until(self, *conditions, timeout=None, wait_period=0.5):
         """Return only after all conditions are true.
 
     async def block_until(self, *conditions, timeout=None, wait_period=0.5):
         """Return only after all conditions are true.
 
+        Raises `websockets.ConnectionClosed` if disconnected.
         """
         """
-        async def _block():
-            while not all(c() for c in conditions):
-                if not (self.connection and self.connection.is_open):
-                    raise websockets.ConnectionClosed(1006, 'no reason')
-                await asyncio.sleep(wait_period, loop=self.loop)
-        await asyncio.wait_for(_block(), timeout, loop=self.loop)
+        def _disconnected():
+            return not (self.is_connected() and self.connection().is_open)
+
+        def done():
+            return _disconnected() or all(c() for c in conditions)
+
+        await utils.block_until(done,
+                                timeout=timeout,
+                                wait_period=wait_period,
+                                loop=self.loop)
+        if _disconnected():
+            raise websockets.ConnectionClosed(1006, 'no reason')
 
     @property
     def applications(self):
 
     @property
     def applications(self):
@@ -581,6 +715,13 @@ class Model(object):
         """
         return self.state.units
 
         """
         return self.state.units
 
+    @property
+    def relations(self):
+        """Return a list of all Relations currently in the model.
+
+        """
+        return list(self.state.relations.values())
+
     async def get_info(self):
         """Return a client.ModelInfo object for this Model.
 
     async def get_info(self):
         """Return a client.ModelInfo object for this Model.
 
@@ -594,13 +735,21 @@ class Model(object):
         explicit call to this method.
 
         """
         explicit call to this method.
 
         """
-        facade = client.ClientFacade.from_connection(self.connection)
+        facade = client.ClientFacade.from_connection(self.connection())
 
 
-        self.info = await facade.ModelInfo()
+        self._info = await facade.ModelInfo()
         log.debug('Got ModelInfo: %s', vars(self.info))
 
         return self.info
 
         log.debug('Got ModelInfo: %s', vars(self.info))
 
         return self.info
 
+    @property
+    def info(self):
+        """Return the cached client.ModelInfo object for this Model.
+
+        If Model.get_info() has not been called, this will return None.
+        """
+        return self._info
+
     def add_observer(
             self, callable_, entity_type=None, action=None, entity_id=None,
             predicate=None):
     def add_observer(
             self, callable_, entity_type=None, action=None, entity_id=None,
             predicate=None):
@@ -639,7 +788,7 @@ class Model(object):
         """
         observer = _Observer(
             callable_, entity_type, action, entity_id, predicate)
         """
         observer = _Observer(
             callable_, entity_type, action, entity_id, predicate)
-        self.observers[observer] = callable_
+        self._observers[observer] = callable_
 
     def _watch(self):
         """Start an asynchronous watch against this model.
 
     def _watch(self):
         """Start an asynchronous watch against this model.
@@ -650,13 +799,13 @@ class Model(object):
         async def _all_watcher():
             try:
                 allwatcher = client.AllWatcherFacade.from_connection(
         async def _all_watcher():
             try:
                 allwatcher = client.AllWatcherFacade.from_connection(
-                    self.connection)
+                    self.connection())
                 while not self._watch_stopping.is_set():
                     try:
                         results = await utils.run_with_interrupt(
                             allwatcher.Next(),
                             self._watch_stopping,
                 while not self._watch_stopping.is_set():
                     try:
                         results = await utils.run_with_interrupt(
                             allwatcher.Next(),
                             self._watch_stopping,
-                            self.loop)
+                            self._connector.loop)
                     except JujuAPIError as e:
                         if 'watcher was stopped' not in str(e):
                             raise
                     except JujuAPIError as e:
                         if 'watcher was stopped' not in str(e):
                             raise
@@ -673,19 +822,27 @@ class Model(object):
                         del allwatcher.Id
                         continue
                     except websockets.ConnectionClosed:
                         del allwatcher.Id
                         continue
                     except websockets.ConnectionClosed:
-                        monitor = self.connection.monitor
+                        monitor = self.connection().monitor
                         if monitor.status == monitor.ERROR:
                             # closed unexpectedly, try to reopen
                             log.warning(
                                 'Watcher: connection closed, reopening')
                         if monitor.status == monitor.ERROR:
                             # closed unexpectedly, try to reopen
                             log.warning(
                                 'Watcher: connection closed, reopening')
-                            await self.connection.reconnect()
+                            await self.connection().reconnect()
+                            if monitor.status != monitor.CONNECTED:
+                                # reconnect failed; abort and shutdown
+                                log.error('Watcher: automatic reconnect '
+                                          'failed; stopping watcher')
+                                break
                             del allwatcher.Id
                             continue
                         else:
                             # closed on request, go ahead and shutdown
                             break
                     if self._watch_stopping.is_set():
                             del allwatcher.Id
                             continue
                         else:
                             # closed on request, go ahead and shutdown
                             break
                     if self._watch_stopping.is_set():
-                        await allwatcher.Stop()
+                        try:
+                            await allwatcher.Stop()
+                        except websockets.ConnectionClosed:
+                            pass  # can't stop on a closed conn
                         break
                     for delta in results.deltas:
                         delta = get_entity_delta(delta)
                         break
                     for delta in results.deltas:
                         delta = get_entity_delta(delta)
@@ -704,7 +861,7 @@ class Model(object):
         self._watch_received.clear()
         self._watch_stopping.clear()
         self._watch_stopped.clear()
         self._watch_received.clear()
         self._watch_stopping.clear()
         self._watch_stopped.clear()
-        self.loop.create_task(_all_watcher())
+        self._connector.loop.create_task(_all_watcher())
 
     async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
 
     async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
@@ -724,10 +881,10 @@ class Model(object):
             'Model changed: %s %s %s',
             delta.entity, delta.type, delta.get_id())
 
             'Model changed: %s %s %s',
             delta.entity, delta.type, delta.get_id())
 
-        for o in self.observers:
+        for o in self._observers:
             if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self),
             if o.cares_about(delta):
                 asyncio.ensure_future(o(delta, old_obj, new_obj, self),
-                                      loop=self.loop)
+                                      loop=self._connector.loop)
 
     async def _wait(self, entity_type, entity_id, action, predicate=None):
         """
 
     async def _wait(self, entity_type, entity_id, action, predicate=None):
         """
@@ -744,7 +901,7 @@ class Model(object):
             has a 'completed' status. See the _Observer class for details.
 
         """
             has a 'completed' status. See the _Observer class for details.
 
         """
-        q = asyncio.Queue(loop=self.loop)
+        q = asyncio.Queue(loop=self._connector.loop)
 
         async def callback(delta, old, new, model):
             await q.put(delta.get_id())
 
         async def callback(delta, old, new, model):
             await q.put(delta.get_id())
@@ -755,24 +912,19 @@ class Model(object):
         # 'remove' action
         return self.state._live_entity_map(entity_type).get(entity_id)
 
         # 'remove' action
         return self.state._live_entity_map(entity_type).get(entity_id)
 
-    async def _wait_for_new(self, entity_type, entity_id=None, predicate=None):
+    async def _wait_for_new(self, entity_type, entity_id):
         """Wait for a new object to appear in the Model and return it.
 
         """Wait for a new object to appear in the Model and return it.
 
-        Waits for an object of type ``entity_type`` with id ``entity_id``.
-        If ``entity_id`` is ``None``, it will wait for the first new entity
-        of the correct type.
-
-        This coroutine blocks until the new object appears in the model.
+        Waits for an object of type ``entity_type`` with id ``entity_id``
+        to appear in the model.  This is similar to watching for the
+        object using ``block_until``, but uses the watcher rather than
+        polling.
 
         """
         # if the entity is already in the model, just return it
         if entity_id in self.state._live_entity_map(entity_type):
             return self.state._live_entity_map(entity_type)[entity_id]
 
         """
         # if the entity is already in the model, just return it
         if entity_id in self.state._live_entity_map(entity_type):
             return self.state._live_entity_map(entity_type)[entity_id]
-        # if we know the entity_id, we can trigger on any action that puts
-        # the enitty into the model; otherwise, we have to watch for the
-        # next "add" action on that entity_type
-        action = 'add' if entity_id is None else None
-        return await self._wait(entity_type, entity_id, action, predicate)
+        return await self._wait(entity_type, entity_id, None)
 
     async def wait_for_action(self, action_id):
         """Given an action, wait for it to complete."""
 
     async def wait_for_action(self, action_id):
         """Given an action, wait for it to complete."""
@@ -785,7 +937,7 @@ class Model(object):
         def predicate(delta):
             return delta.data['status'] in ('completed', 'failed')
 
         def predicate(delta):
             return delta.data['status'] in ('completed', 'failed')
 
-        return await self._wait('action', action_id, 'change', predicate)
+        return await self._wait('action', action_id, None, predicate)
 
     async def add_machine(
             self, spec=None, constraints=None, disks=None, series=None):
 
     async def add_machine(
             self, spec=None, constraints=None, disks=None, series=None):
@@ -798,7 +950,8 @@ class Model(object):
                 (None) - starts a new machine
                 'lxd' - starts a new machine with one lxd container
                 'lxd:4' - starts a new lxd container on machine 4
                 (None) - starts a new machine
                 'lxd' - starts a new machine with one lxd container
                 'lxd:4' - starts a new lxd container on machine 4
-                'ssh:user@10.10.0.3' - manually provisions a machine with ssh
+                'ssh:user@10.10.0.3:/path/to/private/key' - manually provision
+                a machine with ssh and the private key used for authentication
                 'zone=us-east-1a' - starts a machine in zone us-east-1s on AWS
                 'maas2.name' - acquire machine maas2.name on MAAS
 
                 'zone=us-east-1a' - starts a machine in zone us-east-1s on AWS
                 'maas2.name' - acquire machine maas2.name on MAAS
 
@@ -847,12 +1000,25 @@ class Model(object):
 
         """
         params = client.AddMachineParams()
 
         """
         params = client.AddMachineParams()
-        params.jobs = ['JobHostUnits']
 
         if spec:
 
         if spec:
-            placement = parse_placement(spec)
-            if placement:
-                params.placement = placement[0]
+            if spec.startswith("ssh:"):
+                placement, target, private_key_path = spec.split(":")
+                user, host = target.split("@")
+
+                sshProvisioner = provisioner.SSHProvisioner(
+                    host=host,
+                    user=user,
+                    private_key_path=private_key_path,
+                )
+
+                params = sshProvisioner.provision_machine()
+            else:
+                placement = parse_placement(spec)
+                if placement:
+                    params.placement = placement[0]
+
+        params.jobs = ['JobHostUnits']
 
         if constraints:
             params.constraints = client.Value.from_json(constraints)
 
         if constraints:
             params.constraints = client.Value.from_json(constraints)
@@ -865,12 +1031,23 @@ class Model(object):
             params.series = series
 
         # Submit the request.
             params.series = series
 
         # Submit the request.
-        client_facade = client.ClientFacade.from_connection(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection())
         results = await client_facade.AddMachines([params])
         error = results.machines[0].error
         if error:
             raise ValueError("Error adding machine: %s" % error.message)
         machine_id = results.machines[0].machine
         results = await client_facade.AddMachines([params])
         error = results.machines[0].error
         if error:
             raise ValueError("Error adding machine: %s" % error.message)
         machine_id = results.machines[0].machine
+
+        if spec:
+            if spec.startswith("ssh:"):
+                # Need to run this after AddMachines has been called,
+                # as we need the machine_id
+                await sshProvisioner.install_agent(
+                    self.connection(),
+                    params.nonce,
+                    machine_id,
+                )
+
         log.debug('Added new machine %s', machine_id)
         return await self._wait_for_new('machine', machine_id)
 
         log.debug('Added new machine %s', machine_id)
         return await self._wait_for_new('machine', machine_id)
 
@@ -881,29 +1058,34 @@ class Model(object):
         :param str relation2: '<application>[:<relation_name>]'
 
         """
         :param str relation2: '<application>[:<relation_name>]'
 
         """
-        app_facade = client.ApplicationFacade.from_connection(self.connection)
+        connection = self.connection()
+        app_facade = client.ApplicationFacade.from_connection(connection)
 
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
 
 
         log.debug(
             'Adding relation %s <-> %s', relation1, relation2)
 
+        def _find_relation(*specs):
+            for rel in self.relations:
+                if rel.matches(*specs):
+                    return rel
+            return None
+
         try:
             result = await app_facade.AddRelation([relation1, relation2])
         except JujuAPIError as e:
             if 'relation already exists' not in e.message:
                 raise
         try:
             result = await app_facade.AddRelation([relation1, relation2])
         except JujuAPIError as e:
             if 'relation already exists' not in e.message:
                 raise
-            log.debug(
-                'Relation %s <-> %s already exists', relation1, relation2)
-            # TODO: if relation already exists we should return the
-            # Relation ModelEntity here
-            return None
+            rel = _find_relation(relation1, relation2)
+            if rel:
+                return rel
+            raise JujuError('Relation {} {} exists but not in model'.format(
+                relation1, relation2))
 
 
-        def predicate(delta):
-            endpoints = {}
-            for endpoint in delta.data['endpoints']:
-                endpoints[endpoint['application-name']] = endpoint['relation']
-            return endpoints == result.endpoints
+        specs = ['{}:{}'.format(app, data['name'])
+                 for app, data in result.endpoints.items()]
 
 
-        return await self._wait_for_new('relation', None, predicate)
+        await self.block_until(lambda: _find_relation(*specs) is not None)
+        return _find_relation(*specs)
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
 
     def add_space(self, name, *cidrs):
         """Add a new network space.
@@ -924,7 +1106,7 @@ class Model(object):
         :param str key: The public ssh key
 
         """
         :param str key: The public ssh key
 
         """
-        key_facade = client.KeyManagerFacade.from_connection(self.connection)
+        key_facade = client.KeyManagerFacade.from_connection(self.connection())
         return await key_facade.AddKeys([key], user)
     add_ssh_keys = add_ssh_key
 
         return await key_facade.AddKeys([key], user)
     add_ssh_keys = add_ssh_key
 
@@ -1072,9 +1254,14 @@ class Model(object):
                 for k, v in storage.items()
             }
 
                 for k, v in storage.items()
             }
 
+        entity_path = Path(entity_url.replace('local:', ''))
+        bundle_path = entity_path / 'bundle.yaml'
+        metadata_path = entity_path / 'metadata.yaml'
+
         is_local = (
             entity_url.startswith('local:') or
         is_local = (
             entity_url.startswith('local:') or
-            os.path.isdir(entity_url)
+            entity_path.is_dir() or
+            entity_path.is_file()
         )
         if is_local:
             entity_id = entity_url.replace('local:', '')
         )
         if is_local:
             entity_id = entity_url.replace('local:', '')
@@ -1082,10 +1269,11 @@ class Model(object):
             entity = await self.charmstore.entity(entity_url, channel=channel)
             entity_id = entity['Id']
 
             entity = await self.charmstore.entity(entity_url, channel=channel)
             entity_id = entity['Id']
 
-        client_facade = client.ClientFacade.from_connection(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection())
 
         is_bundle = ((is_local and
 
         is_bundle = ((is_local and
-                      (Path(entity_id) / 'bundle.yaml').exists()) or
+                      (entity_id.endswith('.yaml') and entity_path.exists()) or
+                      bundle_path.exists()) or
                      (not is_local and 'bundle/' in entity_id))
 
         if is_bundle:
                      (not is_local and 'bundle/' in entity_id))
 
         if is_bundle:
@@ -1100,9 +1288,9 @@ class Model(object):
                 await asyncio.gather(*[
                     asyncio.ensure_future(
                         self._wait_for_new('application', app_name),
                 await asyncio.gather(*[
                     asyncio.ensure_future(
                         self._wait_for_new('application', app_name),
-                        loop=self.loop)
+                        loop=self._connector.loop)
                     for app_name in pending_apps
                     for app_name in pending_apps
-                ], loop=self.loop)
+                ], loop=self._connector.loop)
             return [app for name, app in self.applications.items()
                     if name in handler.applications]
         else:
             return [app for name, app in self.applications.items()
                     if name in handler.applications]
         else:
@@ -1118,6 +1306,9 @@ class Model(object):
                                                             entity_id,
                                                             entity)
             else:
                                                             entity_id,
                                                             entity)
             else:
+                if not application_name:
+                    metadata = yaml.load(metadata_path.read_text())
+                    application_name = metadata['name']
                 # We have a local charm dir that needs to be uploaded
                 charm_dir = os.path.abspath(
                     os.path.expanduser(entity_id))
                 # We have a local charm dir that needs to be uploaded
                 charm_dir = os.path.abspath(
                     os.path.expanduser(entity_id))
@@ -1163,7 +1354,7 @@ class Model(object):
             return None
 
         resources_facade = client.ResourcesFacade.from_connection(
             return None
 
         resources_facade = client.ResourcesFacade.from_connection(
-            self.connection)
+            self.connection())
         response = await resources_facade.AddPendingResources(
             tag.application(application),
             entity_url,
         response = await resources_facade.AddPendingResources(
             tag.application(application),
             entity_url,
@@ -1186,7 +1377,7 @@ class Model(object):
                            default_flow_style=False)
 
         app_facade = client.ApplicationFacade.from_connection(
                            default_flow_style=False)
 
         app_facade = client.ApplicationFacade.from_connection(
-            self.connection)
+            self.connection())
 
         app = client.ApplicationDeploy(
             charm_url=charm_url,
 
         app = client.ApplicationDeploy(
             charm_url=charm_url,
@@ -1201,7 +1392,6 @@ class Model(object):
             storage=storage,
             placement=placement
         )
             storage=storage,
             placement=placement
         )
-
         result = await app_facade.Deploy([app])
         errors = [r.error.message for r in result.results if r.error]
         if errors:
         result = await app_facade.Deploy([app])
         errors = [r.error.message for r in result.results if r.error]
         if errors:
@@ -1218,7 +1408,8 @@ class Model(object):
         """Destroy units by name.
 
         """
         """Destroy units by name.
 
         """
-        app_facade = client.ApplicationFacade.from_connection(self.connection)
+        connection = self.connection()
+        app_facade = client.ApplicationFacade.from_connection(connection)
 
         log.debug(
             'Destroying unit%s %s',
 
         log.debug(
             'Destroying unit%s %s',
@@ -1263,7 +1454,7 @@ class Model(object):
             which have `source` and `value` attributes.
         """
         config_facade = client.ModelConfigFacade.from_connection(
             which have `source` and `value` attributes.
         """
         config_facade = client.ModelConfigFacade.from_connection(
-            self.connection
+            self.connection()
         )
         result = await config_facade.ModelGet()
         config = result.config
         )
         result = await config_facade.ModelGet()
         config = result.config
@@ -1271,27 +1462,28 @@ class Model(object):
             config[key] = ConfigValue.from_json(value)
         return config
 
             config[key] = ConfigValue.from_json(value)
         return config
 
-    def get_constraints(self):
+    async def get_constraints(self):
         """Return the machine constraints for this model.
 
         """Return the machine constraints for this model.
 
-        """
-        raise NotImplementedError()
-
-    async def grant(self, username, acl='read'):
-        """Grant a user access to this model.
-
-        :param str username: Username
-        :param str acl: Access control ('read' or 'write')
-
-        """
-        controller_conn = await self.connection.controller()
-        model_facade = client.ModelManagerFacade.from_connection(
-            controller_conn)
-        user = tag.user(username)
-        model = tag.model(self.info.uuid)
-        changes = client.ModifyModelAccess(acl, 'grant', model, user)
-        await self.revoke(username)
-        return await model_facade.ModifyModelAccess([changes])
+        :returns: A ``dict`` of constraints.
+        """
+        constraints = {}
+        client_facade = client.ClientFacade.from_connection(self.connection())
+        result = await client_facade.GetModelConstraints()
+
+        # GetModelConstraints returns GetConstraintsResults which has a 'constraints'
+        # attribute. If no constraints have been set GetConstraintsResults.constraints
+        # is None. Otherwise GetConstraintsResults.constraints has an attribute for each
+        # possible constraint, each of these in turn will be None if they have not been
+        # set.
+        if result.constraints:
+           constraint_types = [a for a in dir(result.constraints)
+                               if a in Value._toSchema.keys()]
+           for constraint in constraint_types:
+               value = getattr(result.constraints, constraint)
+               if value is not None:
+                   constraints[constraint] = getattr(result.constraints, constraint)
+        return constraints
 
     def import_ssh_key(self, identity):
         """Add a public SSH key from a trusted indentity source to this model.
 
     def import_ssh_key(self, identity):
         """Add a public SSH key from a trusted indentity source to this model.
@@ -1326,7 +1518,7 @@ class Model(object):
             else it's fingerprint
 
         """
             else it's fingerprint
 
         """
-        key_facade = client.KeyManagerFacade.from_connection(self.connection)
+        key_facade = client.KeyManagerFacade.from_connection(self.connection())
         entity = {'tag': tag.model(self.info.uuid)}
         entities = client.Entities([entity])
         return await key_facade.ListKeys(entities, raw_ssh)
         entity = {'tag': tag.model(self.info.uuid)}
         entities = client.Entities([entity])
         return await key_facade.ListKeys(entities, raw_ssh)
@@ -1399,10 +1591,10 @@ class Model(object):
         :param str user: Juju user to which the key is registered
 
         """
         :param str user: Juju user to which the key is registered
 
         """
-        key_facade = client.KeyManagerFacade.from_connection(self.connection)
+        key_facade = client.KeyManagerFacade.from_connection(self.connection())
         key = base64.b64decode(bytes(key.strip().split()[1].encode('ascii')))
         key = hashlib.md5(key).hexdigest()
         key = base64.b64decode(bytes(key.strip().split()[1].encode('ascii')))
         key = hashlib.md5(key).hexdigest()
-        key = ':'.join(a+b for a, b in zip(key[::2], key[1::2]))
+        key = ':'.join(a + b for a, b in zip(key[::2], key[1::2]))
         await key_facade.DeleteKeys([key], user)
     remove_ssh_keys = remove_ssh_key
 
         await key_facade.DeleteKeys([key], user)
     remove_ssh_keys = remove_ssh_key
 
@@ -1427,20 +1619,6 @@ class Model(object):
         """
         raise NotImplementedError()
 
         """
         raise NotImplementedError()
 
-    async def revoke(self, username):
-        """Revoke a user's access to this model.
-
-        :param str username: Username to revoke
-
-        """
-        controller_conn = await self.connection.controller()
-        model_facade = client.ModelManagerFacade.from_connection(
-            controller_conn)
-        user = tag.user(username)
-        model = tag.model(self.info.uuid)
-        changes = client.ModifyModelAccess('read', 'revoke', model, user)
-        return await model_facade.ModifyModelAccess([changes])
-
     def run(self, command, timeout=None):
         """Run command on all machines in this model.
 
     def run(self, command, timeout=None):
         """Run command on all machines in this model.
 
@@ -1457,38 +1635,86 @@ class Model(object):
             `ConfigValue` instances, as returned by `get_config`.
         """
         config_facade = client.ModelConfigFacade.from_connection(
             `ConfigValue` instances, as returned by `get_config`.
         """
         config_facade = client.ModelConfigFacade.from_connection(
-            self.connection
+            self.connection()
         )
         for key, value in config.items():
             if isinstance(value, ConfigValue):
                 config[key] = value.value
         await config_facade.ModelSet(config)
 
         )
         for key, value in config.items():
             if isinstance(value, ConfigValue):
                 config[key] = value.value
         await config_facade.ModelSet(config)
 
-    def set_constraints(self, constraints):
+    async def set_constraints(self, constraints):
         """Set machine constraints on this model.
 
         """Set machine constraints on this model.
 
-        :param :class:`juju.Constraints` constraints: Machine constraints
-
+        :param dict config: Mapping of model constraints
         """
         """
-        raise NotImplementedError()
+        client_facade = client.ClientFacade.from_connection(self.connection())
+        await client_facade.SetModelConstraints(
+            application='',
+            constraints=constraints)
 
 
-    def get_action_output(self, action_uuid, wait=-1):
+    async def get_action_output(self, action_uuid, wait=None):
         """Get the results of an action by ID.
 
         :param str action_uuid: Id of the action
         """Get the results of an action by ID.
 
         :param str action_uuid: Id of the action
-        :param int wait: Time in seconds to wait for action to complete
-
+        :param int wait: Time in seconds to wait for action to complete.
+        :return dict: Output from action
+        :raises: :class:`JujuError` if invalid action_uuid
         """
         """
-        raise NotImplementedError()
+        action_facade = client.ActionFacade.from_connection(
+            self.connection()
+        )
+        entity = [{'tag': tag.action(action_uuid)}]
+        # Cannot use self.wait_for_action as the action event has probably
+        # already happened and self.wait_for_action works by processing
+        # model deltas and checking if they match our type. If the action
+        # has already occured then the delta has gone.
+
+        async def _wait_for_action_status():
+            while True:
+                action_output = await action_facade.Actions(entity)
+                if action_output.results[0].status in ('completed', 'failed'):
+                    return
+                else:
+                    await asyncio.sleep(1)
+        await asyncio.wait_for(
+            _wait_for_action_status(),
+            timeout=wait)
+        action_output = await action_facade.Actions(entity)
+        # ActionResult.output is None if the action produced no output
+        if action_output.results[0].output is None:
+            output = {}
+        else:
+            output = action_output.results[0].output
+        return output
 
 
-    def get_action_status(self, uuid_or_prefix=None, name=None):
-        """Get the status of all actions, filtered by ID, ID prefix, or action name.
+    async def get_action_status(self, uuid_or_prefix=None, name=None):
+        """Get the status of all actions, filtered by ID, ID prefix, or name.
 
         :param str uuid_or_prefix: Filter by action uuid or prefix
         :param str name: Filter by action name
 
         """
 
         :param str uuid_or_prefix: Filter by action uuid or prefix
         :param str name: Filter by action name
 
         """
-        raise NotImplementedError()
+        results = {}
+        action_results = []
+        action_facade = client.ActionFacade.from_connection(
+            self.connection()
+        )
+        if name:
+            name_results = await action_facade.FindActionsByNames([name])
+            action_results.extend(name_results.actions[0].actions)
+        if uuid_or_prefix:
+            # Collect list of actions matching uuid or prefix
+            matching_actions = await action_facade.FindActionTagsByPrefix(
+                [uuid_or_prefix])
+            entities = []
+            for actions in matching_actions.matches.values():
+                entities = [{'tag': a.tag} for a in actions]
+            # Get action results matching action tags
+            uuid_results = await action_facade.Actions(entities)
+            action_results.extend(uuid_results.results)
+        for a in action_results:
+            results[tag.untag('action-', a.action.tag)] = a.status
+        return results
 
     def get_budget(self, budget_name):
         """Get budget usage info.
 
     def get_budget(self, budget_name):
         """Get budget usage info.
@@ -1506,7 +1732,7 @@ class Model(object):
         :param bool utc: Display time as UTC in RFC3339 format
 
         """
         :param bool utc: Display time as UTC in RFC3339 format
 
         """
-        client_facade = client.ClientFacade.from_connection(self.connection)
+        client_facade = client.ClientFacade.from_connection(self.connection())
         return await client_facade.FullStatus(filters)
 
     def sync_tools(
         return await client_facade.FullStatus(filters)
 
     def sync_tools(
@@ -1587,7 +1813,7 @@ class Model(object):
                   ', '.join(tags) if tags else "all units")
 
         metrics_facade = client.MetricsDebugFacade.from_connection(
                   ', '.join(tags) if tags else "all units")
 
         metrics_facade = client.MetricsDebugFacade.from_connection(
-            self.connection)
+            self.connection())
 
         entities = [client.Entity(tag) for tag in tags]
         metrics_result = await metrics_facade.GetMetrics(entities)
 
         entities = [client.Entity(tag) for tag in tags]
         metrics_result = await metrics_facade.GetMetrics(entities)
@@ -1623,7 +1849,7 @@ def get_charm_series(path):
     return series[0] if series else None
 
 
     return series[0] if series else None
 
 
-class BundleHandler(object):
+class BundleHandler:
     """
     Handle bundles by using the API to translate bundle YAML into a plan of
     steps and then dispatching each of those using the API.
     """
     Handle bundles by using the API to translate bundle YAML into a plan of
     steps and then dispatching each of those using the API.
@@ -1638,11 +1864,11 @@ class BundleHandler(object):
             app_units = self._units_by_app.setdefault(unit.application, [])
             app_units.append(unit_name)
         self.client_facade = client.ClientFacade.from_connection(
             app_units = self._units_by_app.setdefault(unit.application, [])
             app_units.append(unit_name)
         self.client_facade = client.ClientFacade.from_connection(
-            model.connection)
+            model.connection())
         self.app_facade = client.ApplicationFacade.from_connection(
         self.app_facade = client.ApplicationFacade.from_connection(
-            model.connection)
+            model.connection())
         self.ann_facade = client.AnnotationsFacade.from_connection(
         self.ann_facade = client.AnnotationsFacade.from_connection(
-            model.connection)
+            model.connection())
 
     async def _handle_local_charms(self, bundle):
         """Search for references to local charms (i.e. filesystem paths)
 
     async def _handle_local_charms(self, bundle):
         """Search for references to local charms (i.e. filesystem paths)
@@ -1658,8 +1884,9 @@ class BundleHandler(object):
         apps, args = [], []
 
         default_series = bundle.get('series')
         apps, args = [], []
 
         default_series = bundle.get('series')
+        apps_dict = bundle.get('applications', bundle.get('services', {}))
         for app_name in self.applications:
         for app_name in self.applications:
-            app_dict = bundle['services'][app_name]
+            app_dict = apps_dict[app_name]
             charm_dir = os.path.abspath(os.path.expanduser(app_dict['charm']))
             if not os.path.isdir(charm_dir):
                 continue
             charm_dir = os.path.abspath(os.path.expanduser(app_dict['charm']))
             if not os.path.isdir(charm_dir):
                 continue
@@ -1688,13 +1915,16 @@ class BundleHandler(object):
             ], loop=self.model.loop)
             # Update the 'charm:' entry for each app with the new 'local:' url.
             for app_name, charm_url in zip(apps, charm_urls):
             ], loop=self.model.loop)
             # Update the 'charm:' entry for each app with the new 'local:' url.
             for app_name, charm_url in zip(apps, charm_urls):
-                bundle['services'][app_name]['charm'] = charm_url
+                apps_dict[app_name]['charm'] = charm_url
 
         return bundle
 
     async def fetch_plan(self, entity_id):
 
         return bundle
 
     async def fetch_plan(self, entity_id):
-        is_local = not entity_id.startswith('cs:') and os.path.isdir(entity_id)
-        if is_local:
+        is_local = not entity_id.startswith('cs:')
+
+        if is_local and os.path.isfile(entity_id):
+            bundle_yaml = Path(entity_id).read_text()
+        elif is_local and os.path.isdir(entity_id):
             bundle_yaml = (Path(entity_id) / "bundle.yaml").read_text()
         else:
             bundle_yaml = await self.charmstore.files(entity_id,
             bundle_yaml = (Path(entity_id) / "bundle.yaml").read_text()
         else:
             bundle_yaml = await self.charmstore.files(entity_id,
@@ -1706,6 +1936,9 @@ class BundleHandler(object):
         self.plan = await self.client_facade.GetBundleChanges(
             yaml.dump(self.bundle))
 
         self.plan = await self.client_facade.GetBundleChanges(
             yaml.dump(self.bundle))
 
+        if self.plan.errors:
+            raise JujuError(self.plan.errors)
+
     async def execute_plan(self):
         for step in self.plan.changes:
             method = getattr(self, step.method)
     async def execute_plan(self):
         for step in self.plan.changes:
             method = getattr(self, step.method)
@@ -1714,7 +1947,9 @@ class BundleHandler(object):
 
     @property
     def applications(self):
 
     @property
     def applications(self):
-        return list(self.bundle['services'].keys())
+        apps_dict = self.bundle.get('applications',
+                                    self.bundle.get('services', {}))
+        return list(apps_dict.keys())
 
     def resolve(self, reference):
         if reference and reference.startswith('$'):
 
     def resolve(self, reference):
         if reference and reference.startswith('$'):
@@ -1769,7 +2004,11 @@ class BundleHandler(object):
 
         # Fix up values, as necessary.
         if 'parent_id' in params:
 
         # Fix up values, as necessary.
         if 'parent_id' in params:
-            params['parent_id'] = self.resolve(params['parent_id'])
+            if params['parent_id'].startswith('$addUnit'):
+                unit = self.resolve(params['parent_id'])[0]
+                params['parent_id'] = unit.machine.entity_id
+            else:
+                params['parent_id'] = self.resolve(params['parent_id'])
 
         params['constraints'] = parse_constraints(
             params.get('constraints'))
 
         params['constraints'] = parse_constraints(
             params.get('constraints'))
@@ -1917,7 +2156,7 @@ class BundleHandler(object):
         return await entity.set_annotations(annotations)
 
 
         return await entity.set_annotations(annotations)
 
 
-class CharmStore(object):
+class CharmStore:
     """
     Async wrapper around theblues.charmstore.CharmStore
     """
     """
     Async wrapper around theblues.charmstore.CharmStore
     """
@@ -1949,7 +2188,7 @@ class CharmStore(object):
         return wrapper
 
 
         return wrapper
 
 
-class CharmArchiveGenerator(object):
+class CharmArchiveGenerator:
     """
     Create a Zip archive of a local charm directory for upload to a controller.
 
     """
     Create a Zip archive of a local charm directory for upload to a controller.