Pass through event loop
authorCory Johns <johnsca@gmail.com>
Wed, 1 Mar 2017 17:02:24 +0000 (12:02 -0500)
committerCory Johns <johnsca@gmail.com>
Fri, 3 Mar 2017 16:33:39 +0000 (11:33 -0500)
There were several places where the default event loop was used instead
of the given event loop.

juju/client/connection.py
juju/controller.py
juju/model.py
juju/utils.py

index 3e304c1..3011a8a 100644 (file)
@@ -11,6 +11,7 @@ import subprocess
 import websockets
 from http.client import HTTPSConnection
 
+import asyncio
 import yaml
 
 from juju import tag
@@ -36,13 +37,14 @@ class Connection:
     """
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
-            macaroons=None):
+            macaroons=None, loop=None):
         self.endpoint = endpoint
         self.uuid = uuid
         self.username = username
         self.password = password
         self.macaroons = macaroons
         self.cacert = cacert
+        self.loop = loop or asyncio.get_event_loop()
 
         self.__request_id__ = 0
         self.addr = None
@@ -67,6 +69,7 @@ class Connection:
 
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
+        kw['loop'] = self.loop
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
         log.info("Driver connected to juju %s", url)
@@ -153,6 +156,7 @@ class Connection:
             self.password,
             self.cacert,
             self.macaroons,
+            self.loop,
         )
 
     async def controller(self):
@@ -166,19 +170,21 @@ class Connection:
             self.password,
             self.cacert,
             self.macaroons,
+            self.loop,
         )
 
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
-            macaroons=None):
+            macaroons=None, loop=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
-        client = cls(endpoint, uuid, username, password, cacert, macaroons)
+        client = cls(endpoint, uuid, username, password, cacert, macaroons,
+                     loop)
         await client.open()
 
         redirect_info = await client.redirect_info()
@@ -209,7 +215,7 @@ class Connection:
             "Couldn't authenticate to %s", endpoint)
 
     @classmethod
-    async def connect_current(cls):
+    async def connect_current(cls, loop=None):
         """Connect to the currently active model.
 
         """
@@ -219,10 +225,10 @@ class Connection:
         model_name = models['current-model']
 
         return await cls.connect_model(
-            '{}:{}'.format(controller_name, model_name))
+            '{}:{}'.format(controller_name, model_name), loop)
 
     @classmethod
-    async def connect_current_controller(cls):
+    async def connect_current_controller(cls, loop=None):
         """Connect to the currently active controller.
 
         """
@@ -231,10 +237,10 @@ class Connection:
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
-        return await cls.connect_controller(controller_name)
+        return await cls.connect_controller(controller_name, loop)
 
     @classmethod
-    async def connect_controller(cls, controller_name):
+    async def connect_controller(cls, controller_name, loop=None):
         """Connect to a controller by name.
 
         """
@@ -248,10 +254,10 @@ class Connection:
         macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, None, username, password, cacert, macaroons)
+            endpoint, None, username, password, cacert, macaroons, loop)
 
     @classmethod
-    async def connect_model(cls, model):
+    async def connect_model(cls, model, loop=None):
         """Connect to a model by name.
 
         :param str model: <controller>:<model>
@@ -272,7 +278,7 @@ class Connection:
         macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert, macaroons)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop)
 
     def build_facades(self, info):
         self.facades.clear()
index 2bcb2e7..b3d2ac1 100644 (file)
@@ -103,8 +103,9 @@ class Controller(object):
         try:
             ssh_key = await utils.read_ssh_key(loop=self.loop)
             await utils.execute_process(
-                'juju', 'add-ssh-key', '-m', model_name, ssh_key, log=log)
-        except Exception as e:
+                'juju', 'add-ssh-key', '-m', model_name, ssh_key, log=log,
+                loop=self.loop)
+        except Exception:
             log.exception(
                 "Could not add ssh key to model. You will not be able "
                 "to ssh into machines in this model. "
@@ -119,6 +120,7 @@ class Controller(object):
             self.connection.password,
             self.connection.cacert,
             self.connection.macaroons,
+            loop=self.loop,
         )
 
         return model
index 63c306b..634c3b6 100644 (file)
@@ -376,8 +376,8 @@ class Model(object):
         self.state = ModelState(self)
         self.info = None
         self._watcher_task = None
-        self._watch_shutdown = asyncio.Event(loop=loop)
-        self._watch_received = asyncio.Event(loop=loop)
+        self._watch_shutdown = asyncio.Event(loop=self.loop)
+        self._watch_received = asyncio.Event(loop=self.loop)
         self._charmstore = CharmStore(self.loop)
 
     async def connect(self, *args, **kw):
@@ -386,6 +386,8 @@ class Model(object):
         args and kw are passed through to Connection.connect()
 
         """
+        if 'loop' not in kw:
+            kw['loop'] = self.loop
         self.connection = await connection.Connection.connect(*args, **kw)
         await self._after_connect()
 
@@ -393,7 +395,8 @@ class Model(object):
         """Connect to the current Juju model.
 
         """
-        self.connection = await connection.Connection.connect_current()
+        self.connection = await connection.Connection.connect_current(
+            self.loop)
         await self._after_connect()
 
     async def connect_model(self, model_name):
@@ -402,7 +405,8 @@ class Model(object):
         :param model_name:  Format [controller:][user/]model
 
         """
-        self.connection = await connection.Connection.connect_model(model_name)
+        self.connection = await connection.Connection.connect_model(model_name,
+                                                                    self.loop)
         await self._after_connect()
 
     async def _after_connect(self):
@@ -511,8 +515,8 @@ class Model(object):
         """
         async def _block():
             while not all(c() for c in conditions):
-                await asyncio.sleep(wait_period)
-        await asyncio.wait_for(_block(), timeout)
+                await asyncio.sleep(wait_period, loop=self.loop)
+        await asyncio.wait_for(_block(), timeout, loop=self.loop)
 
     @property
     def applications(self):
@@ -623,7 +627,8 @@ class Model(object):
                         # canceled with it. So we shield them. But this means
                         # they can *never* be canceled.
                         await asyncio.shield(
-                            self._notify_observers(delta, old_obj, new_obj))
+                            self._notify_observers(delta, old_obj, new_obj),
+                            loop=self.loop)
                     self._watch_received.set()
             except CancelledError:
                 log.debug('Closing watcher connection')
@@ -662,7 +667,8 @@ class Model(object):
 
         for o in self.observers:
             if o.cares_about(delta):
-                asyncio.ensure_future(o(delta, old_obj, new_obj, self))
+                asyncio.ensure_future(o(delta, old_obj, new_obj, self),
+                                      loop=self.loop)
 
     async def _wait(self, entity_type, entity_id, action, predicate=None):
         """
@@ -1005,9 +1011,10 @@ class Model(object):
                 # haven't made it yet we'll need to wait on them to be added
                 await asyncio.gather(*[
                     asyncio.ensure_future(
-                        self._wait_for_new('application', app_name))
+                        self._wait_for_new('application', app_name),
+                        loop=self.loop)
                     for app_name in pending_apps
-                ])
+                ], loop=self.loop)
             return [app for name, app in self.applications.items()
                     if name in handler.applications]
         else:
@@ -1493,7 +1500,7 @@ class BundleHandler(object):
             charm_urls = await asyncio.gather(*[
                 self.model.add_local_charm_dir(*params)
                 for params in args
-            ])
+            ], loop=self.loop)
             # Update the 'charm:' entry for each app with the new 'local:' url.
             for app_name, charm_url in zip(apps, charm_urls):
                 bundle['services'][app_name]['charm'] = charm_url
index c3d7f23..c0a500c 100644 (file)
@@ -3,7 +3,7 @@ import os
 from pathlib import Path
 
 
-async def execute_process(*cmd, log=None):
+async def execute_process(*cmd, log=None, loop=None):
     '''
     Wrapper around asyncio.create_subprocess_exec.
 
@@ -13,7 +13,7 @@ async def execute_process(*cmd, log=None):
             stdin=asyncio.subprocess.PIPE,
             stdout=asyncio.subprocess.PIPE,
             stderr=asyncio.subprocess.PIPE,
-            )
+            loop=loop)
     stdout, stderr = await p.communicate()
     if log:
         log.debug("Exec %s -> %d", cmd, p.returncode)