Pass through event loop
authorCory Johns <johnsca@gmail.com>
Wed, 1 Mar 2017 17:02:24 +0000 (12:02 -0500)
committerCory Johns <johnsca@gmail.com>
Fri, 3 Mar 2017 16:33:39 +0000 (11:33 -0500)
There were several places where the default event loop was used instead
of the given event loop.

juju/client/connection.py
juju/controller.py
juju/model.py
juju/utils.py

index 3e304c1..3011a8a 100644 (file)
@@ -11,6 +11,7 @@ import subprocess
 import websockets
 from http.client import HTTPSConnection
 
 import websockets
 from http.client import HTTPSConnection
 
+import asyncio
 import yaml
 
 from juju import tag
 import yaml
 
 from juju import tag
@@ -36,13 +37,14 @@ class Connection:
     """
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
     """
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
-            macaroons=None):
+            macaroons=None, loop=None):
         self.endpoint = endpoint
         self.uuid = uuid
         self.username = username
         self.password = password
         self.macaroons = macaroons
         self.cacert = cacert
         self.endpoint = endpoint
         self.uuid = uuid
         self.username = username
         self.password = password
         self.macaroons = macaroons
         self.cacert = cacert
+        self.loop = loop or asyncio.get_event_loop()
 
         self.__request_id__ = 0
         self.addr = None
 
         self.__request_id__ = 0
         self.addr = None
@@ -67,6 +69,7 @@ class Connection:
 
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
 
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
+        kw['loop'] = self.loop
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
         log.info("Driver connected to juju %s", url)
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
         log.info("Driver connected to juju %s", url)
@@ -153,6 +156,7 @@ class Connection:
             self.password,
             self.cacert,
             self.macaroons,
             self.password,
             self.cacert,
             self.macaroons,
+            self.loop,
         )
 
     async def controller(self):
         )
 
     async def controller(self):
@@ -166,19 +170,21 @@ class Connection:
             self.password,
             self.cacert,
             self.macaroons,
             self.password,
             self.cacert,
             self.macaroons,
+            self.loop,
         )
 
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
         )
 
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
-            macaroons=None):
+            macaroons=None, loop=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         will be to the model.
 
         """
-        client = cls(endpoint, uuid, username, password, cacert, macaroons)
+        client = cls(endpoint, uuid, username, password, cacert, macaroons,
+                     loop)
         await client.open()
 
         redirect_info = await client.redirect_info()
         await client.open()
 
         redirect_info = await client.redirect_info()
@@ -209,7 +215,7 @@ class Connection:
             "Couldn't authenticate to %s", endpoint)
 
     @classmethod
             "Couldn't authenticate to %s", endpoint)
 
     @classmethod
-    async def connect_current(cls):
+    async def connect_current(cls, loop=None):
         """Connect to the currently active model.
 
         """
         """Connect to the currently active model.
 
         """
@@ -219,10 +225,10 @@ class Connection:
         model_name = models['current-model']
 
         return await cls.connect_model(
         model_name = models['current-model']
 
         return await cls.connect_model(
-            '{}:{}'.format(controller_name, model_name))
+            '{}:{}'.format(controller_name, model_name), loop)
 
     @classmethod
 
     @classmethod
-    async def connect_current_controller(cls):
+    async def connect_current_controller(cls, loop=None):
         """Connect to the currently active controller.
 
         """
         """Connect to the currently active controller.
 
         """
@@ -231,10 +237,10 @@ class Connection:
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
-        return await cls.connect_controller(controller_name)
+        return await cls.connect_controller(controller_name, loop)
 
     @classmethod
 
     @classmethod
-    async def connect_controller(cls, controller_name):
+    async def connect_controller(cls, controller_name, loop=None):
         """Connect to a controller by name.
 
         """
         """Connect to a controller by name.
 
         """
@@ -248,10 +254,10 @@ class Connection:
         macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
         macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, None, username, password, cacert, macaroons)
+            endpoint, None, username, password, cacert, macaroons, loop)
 
     @classmethod
 
     @classmethod
-    async def connect_model(cls, model):
+    async def connect_model(cls, model, loop=None):
         """Connect to a model by name.
 
         :param str model: <controller>:<model>
         """Connect to a model by name.
 
         :param str model: <controller>:<model>
@@ -272,7 +278,7 @@ class Connection:
         macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
         macaroons = get_macaroons() if not password else None
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert, macaroons)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop)
 
     def build_facades(self, info):
         self.facades.clear()
 
     def build_facades(self, info):
         self.facades.clear()
index 2bcb2e7..b3d2ac1 100644 (file)
@@ -103,8 +103,9 @@ class Controller(object):
         try:
             ssh_key = await utils.read_ssh_key(loop=self.loop)
             await utils.execute_process(
         try:
             ssh_key = await utils.read_ssh_key(loop=self.loop)
             await utils.execute_process(
-                'juju', 'add-ssh-key', '-m', model_name, ssh_key, log=log)
-        except Exception as e:
+                'juju', 'add-ssh-key', '-m', model_name, ssh_key, log=log,
+                loop=self.loop)
+        except Exception:
             log.exception(
                 "Could not add ssh key to model. You will not be able "
                 "to ssh into machines in this model. "
             log.exception(
                 "Could not add ssh key to model. You will not be able "
                 "to ssh into machines in this model. "
@@ -119,6 +120,7 @@ class Controller(object):
             self.connection.password,
             self.connection.cacert,
             self.connection.macaroons,
             self.connection.password,
             self.connection.cacert,
             self.connection.macaroons,
+            loop=self.loop,
         )
 
         return model
         )
 
         return model
index 63c306b..634c3b6 100644 (file)
@@ -376,8 +376,8 @@ class Model(object):
         self.state = ModelState(self)
         self.info = None
         self._watcher_task = None
         self.state = ModelState(self)
         self.info = None
         self._watcher_task = None
-        self._watch_shutdown = asyncio.Event(loop=loop)
-        self._watch_received = asyncio.Event(loop=loop)
+        self._watch_shutdown = asyncio.Event(loop=self.loop)
+        self._watch_received = asyncio.Event(loop=self.loop)
         self._charmstore = CharmStore(self.loop)
 
     async def connect(self, *args, **kw):
         self._charmstore = CharmStore(self.loop)
 
     async def connect(self, *args, **kw):
@@ -386,6 +386,8 @@ class Model(object):
         args and kw are passed through to Connection.connect()
 
         """
         args and kw are passed through to Connection.connect()
 
         """
+        if 'loop' not in kw:
+            kw['loop'] = self.loop
         self.connection = await connection.Connection.connect(*args, **kw)
         await self._after_connect()
 
         self.connection = await connection.Connection.connect(*args, **kw)
         await self._after_connect()
 
@@ -393,7 +395,8 @@ class Model(object):
         """Connect to the current Juju model.
 
         """
         """Connect to the current Juju model.
 
         """
-        self.connection = await connection.Connection.connect_current()
+        self.connection = await connection.Connection.connect_current(
+            self.loop)
         await self._after_connect()
 
     async def connect_model(self, model_name):
         await self._after_connect()
 
     async def connect_model(self, model_name):
@@ -402,7 +405,8 @@ class Model(object):
         :param model_name:  Format [controller:][user/]model
 
         """
         :param model_name:  Format [controller:][user/]model
 
         """
-        self.connection = await connection.Connection.connect_model(model_name)
+        self.connection = await connection.Connection.connect_model(model_name,
+                                                                    self.loop)
         await self._after_connect()
 
     async def _after_connect(self):
         await self._after_connect()
 
     async def _after_connect(self):
@@ -511,8 +515,8 @@ class Model(object):
         """
         async def _block():
             while not all(c() for c in conditions):
         """
         async def _block():
             while not all(c() for c in conditions):
-                await asyncio.sleep(wait_period)
-        await asyncio.wait_for(_block(), timeout)
+                await asyncio.sleep(wait_period, loop=self.loop)
+        await asyncio.wait_for(_block(), timeout, loop=self.loop)
 
     @property
     def applications(self):
 
     @property
     def applications(self):
@@ -623,7 +627,8 @@ class Model(object):
                         # canceled with it. So we shield them. But this means
                         # they can *never* be canceled.
                         await asyncio.shield(
                         # canceled with it. So we shield them. But this means
                         # they can *never* be canceled.
                         await asyncio.shield(
-                            self._notify_observers(delta, old_obj, new_obj))
+                            self._notify_observers(delta, old_obj, new_obj),
+                            loop=self.loop)
                     self._watch_received.set()
             except CancelledError:
                 log.debug('Closing watcher connection')
                     self._watch_received.set()
             except CancelledError:
                 log.debug('Closing watcher connection')
@@ -662,7 +667,8 @@ class Model(object):
 
         for o in self.observers:
             if o.cares_about(delta):
 
         for o in self.observers:
             if o.cares_about(delta):
-                asyncio.ensure_future(o(delta, old_obj, new_obj, self))
+                asyncio.ensure_future(o(delta, old_obj, new_obj, self),
+                                      loop=self.loop)
 
     async def _wait(self, entity_type, entity_id, action, predicate=None):
         """
 
     async def _wait(self, entity_type, entity_id, action, predicate=None):
         """
@@ -1005,9 +1011,10 @@ class Model(object):
                 # haven't made it yet we'll need to wait on them to be added
                 await asyncio.gather(*[
                     asyncio.ensure_future(
                 # haven't made it yet we'll need to wait on them to be added
                 await asyncio.gather(*[
                     asyncio.ensure_future(
-                        self._wait_for_new('application', app_name))
+                        self._wait_for_new('application', app_name),
+                        loop=self.loop)
                     for app_name in pending_apps
                     for app_name in pending_apps
-                ])
+                ], loop=self.loop)
             return [app for name, app in self.applications.items()
                     if name in handler.applications]
         else:
             return [app for name, app in self.applications.items()
                     if name in handler.applications]
         else:
@@ -1493,7 +1500,7 @@ class BundleHandler(object):
             charm_urls = await asyncio.gather(*[
                 self.model.add_local_charm_dir(*params)
                 for params in args
             charm_urls = await asyncio.gather(*[
                 self.model.add_local_charm_dir(*params)
                 for params in args
-            ])
+            ], loop=self.loop)
             # Update the 'charm:' entry for each app with the new 'local:' url.
             for app_name, charm_url in zip(apps, charm_urls):
                 bundle['services'][app_name]['charm'] = charm_url
             # Update the 'charm:' entry for each app with the new 'local:' url.
             for app_name, charm_url in zip(apps, charm_urls):
                 bundle['services'][app_name]['charm'] = charm_url
index c3d7f23..c0a500c 100644 (file)
@@ -3,7 +3,7 @@ import os
 from pathlib import Path
 
 
 from pathlib import Path
 
 
-async def execute_process(*cmd, log=None):
+async def execute_process(*cmd, log=None, loop=None):
     '''
     Wrapper around asyncio.create_subprocess_exec.
 
     '''
     Wrapper around asyncio.create_subprocess_exec.
 
@@ -13,7 +13,7 @@ async def execute_process(*cmd, log=None):
             stdin=asyncio.subprocess.PIPE,
             stdout=asyncio.subprocess.PIPE,
             stderr=asyncio.subprocess.PIPE,
             stdin=asyncio.subprocess.PIPE,
             stdout=asyncio.subprocess.PIPE,
             stderr=asyncio.subprocess.PIPE,
-            )
+            loop=loop)
     stdout, stderr = await p.communicate()
     if log:
         log.debug("Exec %s -> %d", cmd, p.returncode)
     stdout, stderr = await p.communicate()
     if log:
         log.debug("Exec %s -> %d", cmd, p.returncode)