Merge branch 'master' into bug/fix-invalid-annotations
authorPete Vander Giessen <petevg@gmail.com>
Tue, 7 Mar 2017 21:20:39 +0000 (15:20 -0600)
committerGitHub <noreply@github.com>
Tue, 7 Mar 2017 21:20:39 +0000 (15:20 -0600)
juju/client/connection.py
juju/client/facade.py
juju/constraints.py
juju/controller.py
juju/model.py
juju/utils.py
tests/base.py
tests/integration/test_model.py
tests/unit/test_model.py
tox.ini

index f408135..b508a1a 100644 (file)
@@ -11,10 +11,11 @@ import subprocess
 import websockets
 from http.client import HTTPSConnection
 
+import asyncio
 import yaml
 
 from juju import tag
-from juju.errors import JujuAPIError, JujuConnectionError, JujuError
+from juju.errors import JujuError, JujuAPIError, JujuConnectionError
 
 log = logging.getLogger("websocket")
 
@@ -33,16 +34,19 @@ class Connection:
         # Connect to the currently active model
         client = await Connection.connect_current()
 
+    Note: Any connection method or constructor can accept an optional `loop`
+    argument to override the default event loop from `asyncio.get_event_loop`.
     """
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
-            macaroons=None):
+            macaroons=None, loop=None):
         self.endpoint = endpoint
         self.uuid = uuid
         self.username = username
         self.password = password
         self.macaroons = macaroons
         self.cacert = cacert
+        self.loop = loop or asyncio.get_event_loop()
 
         self.__request_id__ = 0
         self.addr = None
@@ -67,6 +71,7 @@ class Connection:
 
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
+        kw['loop'] = self.loop
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
         log.info("Driver connected to juju %s", url)
@@ -175,6 +180,7 @@ class Connection:
             self.password,
             self.cacert,
             self.macaroons,
+            self.loop,
         )
 
     async def controller(self):
@@ -188,19 +194,21 @@ class Connection:
             self.password,
             self.cacert,
             self.macaroons,
+            self.loop,
         )
 
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
-            macaroons=None):
+            macaroons=None, loop=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
-        client = cls(endpoint, uuid, username, password, cacert, macaroons)
+        client = cls(endpoint, uuid, username, password, cacert, macaroons,
+                     loop)
         await client.open()
 
         redirect_info = await client.redirect_info()
@@ -231,20 +239,19 @@ class Connection:
             "Couldn't authenticate to %s", endpoint)
 
     @classmethod
-    async def connect_current(cls):
+    async def connect_current(cls, loop=None):
         """Connect to the currently active model.
 
         """
         jujudata = JujuData()
         controller_name = jujudata.current_controller()
-        models = jujudata.models()[controller_name]
-        model_name = models['current-model']
+        model_name = jujudata.current_model()
 
         return await cls.connect_model(
-            '{}:{}'.format(controller_name, model_name))
+            '{}:{}'.format(controller_name, model_name), loop)
 
     @classmethod
-    async def connect_current_controller(cls):
+    async def connect_current_controller(cls, loop=None):
         """Connect to the currently active controller.
 
         """
@@ -253,10 +260,10 @@ class Connection:
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
-        return await cls.connect_controller(controller_name)
+        return await cls.connect_controller(controller_name, loop)
 
     @classmethod
-    async def connect_controller(cls, controller_name):
+    async def connect_controller(cls, controller_name, loop=None):
         """Connect to a controller by name.
 
         """
@@ -270,30 +277,41 @@ class Connection:
         macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, None, username, password, cacert, macaroons)
+            endpoint, None, username, password, cacert, macaroons, loop)
 
     @classmethod
-    async def connect_model(cls, model):
+    async def connect_model(cls, model, loop=None):
         """Connect to a model by name.
 
-        :param str model: <controller>:<model>
+        :param str model: [<controller>:]<model>
 
         """
-        controller_name, model_name = model.split(':')
-
         jujudata = JujuData()
+
+        if ':' in model:
+            # explicit controller given
+            controller_name, model_name = model.split(':')
+        else:
+            # use the current controller if one isn't explicitly given
+            controller_name = jujudata.current_controller()
+            model_name = model
+
+        accounts = jujudata.accounts()[controller_name]
+        username = accounts['user']
+        # model name must include a user prefix, so add it if it doesn't
+        if '/' not in model_name:
+            model_name = '{}/{}'.format(username, model_name)
+
         controller = jujudata.controllers()[controller_name]
         endpoint = controller['api-endpoints'][0]
         cacert = controller.get('ca-cert')
-        accounts = jujudata.accounts()[controller_name]
-        username = accounts['user']
         password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
         macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert, macaroons)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop)
 
     def build_facades(self, info):
         self.facades.clear()
@@ -348,6 +366,14 @@ class JujuData:
         output = yaml.safe_load(output)
         return output.get('current-controller', '')
 
+    def current_model(self, controller_name=None):
+        if not controller_name:
+            controller_name = self.current_controller()
+        models = self.models()[controller_name]
+        if 'current-model' not in models:
+            raise JujuError('No current model')
+        return models['current-model']
+
     def controllers(self):
         return self._load_yaml('controllers.yaml', 'controllers')
 
@@ -371,7 +397,7 @@ def get_macaroons():
         cookie_file = os.path.expanduser('~/.go-cookies')
         with open(cookie_file, 'r') as f:
             cookies = json.load(f)
-    except (OSError, ValueError) as e:
+    except (OSError, ValueError):
         log.warn("Couldn't load macaroons from %s", cookie_file)
         return []
 
index 5ed5320..817b37b 100644 (file)
@@ -438,6 +438,8 @@ class Type:
 
     @classmethod
     def from_json(cls, data):
+        if isinstance(data, cls):
+            return data
         if isinstance(data, str):
             data = json.loads(data)
         d = {}
index c551883..998862d 100644 (file)
@@ -39,10 +39,7 @@ def parse(constraints):
     and key value pairs joined by an '='.
 
     """
-    if constraints is None:
-        return None
-
-    if constraints == "":
+    if not constraints:
         return None
 
     if type(constraints) is dict:
index 2bcb2e7..b3d2ac1 100644 (file)
@@ -103,8 +103,9 @@ class Controller(object):
         try:
             ssh_key = await utils.read_ssh_key(loop=self.loop)
             await utils.execute_process(
-                'juju', 'add-ssh-key', '-m', model_name, ssh_key, log=log)
-        except Exception as e:
+                'juju', 'add-ssh-key', '-m', model_name, ssh_key, log=log,
+                loop=self.loop)
+        except Exception:
             log.exception(
                 "Could not add ssh key to model. You will not be able "
                 "to ssh into machines in this model. "
@@ -119,6 +120,7 @@ class Controller(object):
             self.connection.password,
             self.connection.cacert,
             self.connection.macaroons,
+            loop=self.loop,
         )
 
         return model
index 548fc03..55ad086 100644 (file)
@@ -13,7 +13,8 @@ from functools import partial
 from pathlib import Path
 
 import yaml
-from theblues import charmstore
+import theblues.charmstore
+import theblues.errors
 
 from .client import client
 from .client import watcher
@@ -376,8 +377,8 @@ class Model(object):
         self.state = ModelState(self)
         self.info = None
         self._watcher_task = None
-        self._watch_shutdown = asyncio.Event(loop=loop)
-        self._watch_received = asyncio.Event(loop=loop)
+        self._watch_shutdown = asyncio.Event(loop=self.loop)
+        self._watch_received = asyncio.Event(loop=self.loop)
         self._charmstore = CharmStore(self.loop)
 
     async def connect(self, *args, **kw):
@@ -386,6 +387,8 @@ class Model(object):
         args and kw are passed through to Connection.connect()
 
         """
+        if 'loop' not in kw:
+            kw['loop'] = self.loop
         self.connection = await connection.Connection.connect(*args, **kw)
         await self._after_connect()
 
@@ -393,7 +396,8 @@ class Model(object):
         """Connect to the current Juju model.
 
         """
-        self.connection = await connection.Connection.connect_current()
+        self.connection = await connection.Connection.connect_current(
+            self.loop)
         await self._after_connect()
 
     async def connect_model(self, model_name):
@@ -402,7 +406,8 @@ class Model(object):
         :param model_name:  Format [controller:][user/]model
 
         """
-        self.connection = await connection.Connection.connect_model(model_name)
+        self.connection = await connection.Connection.connect_model(model_name,
+                                                                    self.loop)
         await self._after_connect()
 
     async def _after_connect(self):
@@ -511,8 +516,8 @@ class Model(object):
         """
         async def _block():
             while not all(c() for c in conditions):
-                await asyncio.sleep(wait_period)
-        await asyncio.wait_for(_block(), timeout)
+                await asyncio.sleep(wait_period, loop=self.loop)
+        await asyncio.wait_for(_block(), timeout, loop=self.loop)
 
     @property
     def applications(self):
@@ -623,7 +628,8 @@ class Model(object):
                         # canceled with it. So we shield them. But this means
                         # they can *never* be canceled.
                         await asyncio.shield(
-                            self._notify_observers(delta, old_obj, new_obj))
+                            self._notify_observers(delta, old_obj, new_obj),
+                            loop=self.loop)
                     self._watch_received.set()
             except CancelledError:
                 log.debug('Closing watcher connection')
@@ -662,7 +668,8 @@ class Model(object):
 
         for o in self.observers:
             if o.cares_about(delta):
-                asyncio.ensure_future(o(delta, old_obj, new_obj, self))
+                asyncio.ensure_future(o(delta, old_obj, new_obj, self),
+                                      loop=self.loop)
 
     async def _wait(self, entity_type, entity_id, action, predicate=None):
         """
@@ -928,6 +935,22 @@ class Model(object):
         """
         raise NotImplementedError()
 
+    def _get_series(self, entity_url, entity):
+        # try to get the series from the provided charm URL
+        if entity_url.startswith('cs:'):
+            parts = entity_url[3:].split('/')
+        else:
+            parts = entity_url.split('/')
+        if parts[0].startswith('~'):
+            parts.pop(0)
+        if len(parts) > 1:
+            # series was specified in the URL
+            return parts[0]
+        # series was not supplied at all, so use the newest
+        # supported series according to the charm store
+        ss = entity['Meta']['supported-series']
+        return ss['SupportedSeries'][0]
+
     async def deploy(
             self, entity_url, application_name=None, bind=None, budget=None,
             channel=None, config=None, constraints=None, force=False,
@@ -967,11 +990,6 @@ class Model(object):
             - series is required; how do we pick a default?
 
         """
-        if to:
-            placement = parse_placement(to)
-        else:
-            placement = []
-
         if storage:
             storage = {
                 k: client.Constraints(**v)
@@ -982,12 +1000,13 @@ class Model(object):
             entity_url.startswith('local:') or
             os.path.isdir(entity_url)
         )
-        entity_id = await self.charmstore.entityId(entity_url) \
-            if not is_local else entity_url
+        if is_local:
+            entity_id = entity_url
+        else:
+            entity = await self.charmstore.entity(entity_url)
+            entity_id = entity['Id']
 
-        app_facade = client.ApplicationFacade()
         client_facade = client.ClientFacade()
-        app_facade.connect(self.connection)
         client_facade.connect(self.connection)
 
         is_bundle = ((is_local and
@@ -1005,18 +1024,22 @@ class Model(object):
                 # haven't made it yet we'll need to wait on them to be added
                 await asyncio.gather(*[
                     asyncio.ensure_future(
-                        self._wait_for_new('application', app_name))
+                        self._wait_for_new('application', app_name),
+                        loop=self.loop)
                     for app_name in pending_apps
-                ])
+                ], loop=self.loop)
             return [app for name, app in self.applications.items()
                     if name in handler.applications]
         else:
-            log.debug(
-                'Deploying %s', entity_id)
-
             if not is_local:
+                if not application_name:
+                    application_name = entity['Meta']['charm-metadata']['Name']
+                if not series:
+                    series = self._get_series(entity_url, entity)
+                if not channel:
+                    channel = 'stable'
                 await client_facade.AddCharm(channel, entity_id)
-            elif not entity_id.startswith('local:'):
+            else:
                 # We have a local charm dir that needs to be uploaded
                 charm_dir = os.path.abspath(
                     os.path.expanduser(entity_id))
@@ -1027,23 +1050,54 @@ class Model(object):
                         "Pass a 'series' kwarg to Model.deploy().".format(
                             charm_dir))
                 entity_id = await self.add_local_charm_dir(charm_dir, series)
-
-            app = client.ApplicationDeploy(
-                application=application_name,
-                channel=channel,
+            return await self._deploy(
                 charm_url=entity_id,
-                config=config,
-                constraints=parse_constraints(constraints),
+                application=application_name,
+                series=series,
+                config=config or {},
+                constraints=constraints,
                 endpoint_bindings=bind,
-                num_units=num_units,
                 resources=resources,
-                series=series,
                 storage=storage,
+                channel=channel,
+                num_units=num_units,
+                placement=parse_placement(to),
             )
-            app.placement = placement
 
-            await app_facade.Deploy([app])
-            return await self._wait_for_new('application', application_name)
+    async def _deploy(self, charm_url, application, series, config,
+                      constraints, endpoint_bindings, resources, storage,
+                      channel=None, num_units=None, placement=None):
+        """Logic shared between `Model.deploy` and `BundleHandler.deploy`.
+        """
+        log.info('Deploying %s', charm_url)
+
+        # stringify all config values for API, and convert to YAML
+        config = {k: str(v) for k, v in config.items()}
+        config = yaml.dump({application: config},
+                           default_flow_style=False)
+
+        app_facade = client.ApplicationFacade()
+        app_facade.connect(self.connection)
+
+        app = client.ApplicationDeploy(
+            charm_url=charm_url,
+            application=application,
+            series=series,
+            channel=channel,
+            config_yaml=config,
+            constraints=parse_constraints(constraints),
+            endpoint_bindings=endpoint_bindings,
+            num_units=num_units,
+            resources=resources,
+            storage=storage,
+            placement=placement,
+        )
+
+        result = await app_facade.Deploy([app])
+        errors = [r.error.message for r in result.results if r.error]
+        if errors:
+            raise JujuError('\n'.join(errors))
+        return await self._wait_for_new('application', application)
 
     def destroy(self):
         """Terminate all machines and resources for this model.
@@ -1302,15 +1356,17 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def get_status(self, filter_=None, utc=False):
+    async def get_status(self, filters=None, utc=False):
         """Return the status of the model.
 
-        :param str filter_: Service or unit name or wildcard ('*')
+        :param str filters: Optional list of applications, units, or machines
+            to include, which can use wildcards ('*').
         :param bool utc: Display time as UTC in RFC3339 format
 
         """
-        raise NotImplementedError()
-    status = get_status
+        client_facade = client.ClientFacade()
+        client_facade.connect(self.connection)
+        return await client_facade.FullStatus(filters)
 
     def sync_tools(
             self, all_=False, destination=None, dry_run=False, public=False,
@@ -1488,7 +1544,7 @@ class BundleHandler(object):
             charm_urls = await asyncio.gather(*[
                 self.model.add_local_charm_dir(*params)
                 for params in args
-            ])
+            ], loop=self.model.loop)
             # Update the 'charm:' entry for each app with the new 'local:' url.
             for app_name, charm_url in zip(apps, charm_urls):
                 bundle['services'][app_name]['charm'] = charm_url
@@ -1644,28 +1700,16 @@ class BundleHandler(object):
         """
         # resolve indirect references
         charm = self.resolve(charm)
-        # stringify all config values for API, and convert to YAML
-        options = {k: str(v) for k, v in options.items()}
-        options = yaml.dump({application: options}, default_flow_style=False)
-        # build param object
-        app = client.ApplicationDeploy(
+        await self.model._deploy(
             charm_url=charm,
-            series=series,
             application=application,
-            # Pass options to config-yaml rather than config, as
-            # config-yaml invokes a newer codepath that better handles
-            # empty strings in the options values.
-            config_yaml=options,
-            constraints=parse_constraints(constraints),
-            storage=storage,
+            series=series,
+            config=options,
+            constraints=constraints,
             endpoint_bindings=endpoint_bindings,
             resources=resources,
+            storage=storage,
         )
-        # do the do
-        log.info('Deploying %s', charm)
-        await self.app_facade.Deploy([app])
-        # ensure the app is in the model for future operations
-        await self.model._wait_for_new('application', application)
         return application
 
     async def addUnit(self, application, to):
@@ -1733,7 +1777,7 @@ class CharmStore(object):
     """
     def __init__(self, loop):
         self.loop = loop
-        self._cs = charmstore.CharmStore()
+        self._cs = theblues.charmstore.CharmStore(timeout=5)
 
     def __getattr__(self, name):
         """
@@ -1747,7 +1791,13 @@ class CharmStore(object):
         else:
             async def coro(*args, **kwargs):
                 method = partial(attr, *args, **kwargs)
-                return await self.loop.run_in_executor(None, method)
+                for attempt in range(1, 4):
+                    try:
+                        return await self.loop.run_in_executor(None, method)
+                    except theblues.errors.ServerError:
+                        if attempt == 3:
+                            raise
+                        await asyncio.sleep(1, loop=self.loop)
             setattr(self, name, coro)
             wrapper = coro
         return wrapper
index c3d7f23..c0a500c 100644 (file)
@@ -3,7 +3,7 @@ import os
 from pathlib import Path
 
 
-async def execute_process(*cmd, log=None):
+async def execute_process(*cmd, log=None, loop=None):
     '''
     Wrapper around asyncio.create_subprocess_exec.
 
@@ -13,7 +13,7 @@ async def execute_process(*cmd, log=None):
             stdin=asyncio.subprocess.PIPE,
             stdout=asyncio.subprocess.PIPE,
             stderr=asyncio.subprocess.PIPE,
-            )
+            loop=loop)
     stdout, stderr = await p.communicate()
     if log:
         log.debug("Exec %s -> %d", cmd, p.returncode)
index 382da43..af386ea 100644 (file)
@@ -1,9 +1,11 @@
-import uuid
+import mock
 import subprocess
+import uuid
 
 import pytest
 
 from juju.controller import Controller
+from juju.client.connection import JujuData
 
 
 def is_bootstrapped():
@@ -29,9 +31,16 @@ class CleanModel():
         model_name = 'model-{}'.format(uuid.uuid4())
         self.model = await self.controller.add_model(model_name)
 
+        # Ensure that we connect to the new model by default.  This also
+        # prevents failures if test was started with no current model.
+        self._patch_cm = mock.patch.object(JujuData, 'current_model',
+                                           return_value=model_name)
+        self._patch_cm.start()
+
         return self.model
 
     async def __aexit__(self, exc_type, exc, tb):
+        self._patch_cm.stop()
         await self.model.disconnect()
         await self.controller.destroy_model(self.model.info.uuid)
         await self.controller.disconnect()
index 2fe97d0..4d45a1c 100644 (file)
@@ -1,6 +1,9 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 import pytest
 
 from .. import base
+from juju.model import Model
 
 MB = 1
 GB = 1024
@@ -95,3 +98,40 @@ async def test_relate(event_loop):
         )
 
         assert isinstance(my_relation, Relation)
+
+
+async def _deploy_in_loop(new_loop, model_name):
+    new_model = Model(new_loop)
+    await new_model.connect_model(model_name)
+    try:
+        await new_model.deploy('cs:xenial/ubuntu')
+        assert 'ubuntu' in new_model.applications
+    finally:
+        await new_model.disconnect()
+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_explicit_loop(event_loop):
+    async with base.CleanModel() as model:
+        model_name = model.info.name
+        new_loop = asyncio.new_event_loop()
+        new_loop.run_until_complete(
+            _deploy_in_loop(new_loop, model_name))
+        await model._wait_for_new('application', 'ubuntu')
+        assert 'ubuntu' in model.applications
+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_explicit_loop_threaded(event_loop):
+    async with base.CleanModel() as model:
+        model_name = model.info.name
+        new_loop = asyncio.new_event_loop()
+        with ThreadPoolExecutor(1) as executor:
+            f = executor.submit(
+                new_loop.run_until_complete,
+                _deploy_in_loop(new_loop, model_name))
+            f.result()
+        await model._wait_for_new('application', 'ubuntu')
+        assert 'ubuntu' in model.applications
index f8cced3..67db5ae 100644 (file)
@@ -92,3 +92,24 @@ class TestModelState(unittest.TestCase):
         self.assertFalse(new)
         self.assertIsInstance(prev, Application)
         self.assertTrue(prev)
+
+
+def test_get_series():
+    from juju.model import Model
+    model = Model()
+    entity = {
+        'Meta': {
+            'supported-series': {
+                'SupportedSeries': [
+                    'xenial',
+                    'trusty',
+                ],
+            },
+        },
+    }
+    assert model._get_series('cs:trusty/ubuntu', entity) == 'trusty'
+    assert model._get_series('xenial/ubuntu', entity) == 'xenial'
+    assert model._get_series('~foo/xenial/ubuntu', entity) == 'xenial'
+    assert model._get_series('~foo/ubuntu', entity) == 'xenial'
+    assert model._get_series('ubuntu', entity) == 'xenial'
+    assert model._get_series('cs:ubuntu', entity) == 'xenial'
diff --git a/tox.ini b/tox.ini
index d4f3a73..1ac6356 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -23,4 +23,4 @@ commands = py.test -ra -s -x -n auto -k 'not integration'
 
 [testenv:integration]
 basepython=python3
-commands = py.test -ra -s -x -n auto
+commands = py.test -ra -s -x -n auto {posargs}