Updated examples to use juju.loop
[osm/N2VC.git] / juju / client / connection.py
index 486a57f..3ee8f16 100644 (file)
@@ -16,6 +16,7 @@ import yaml
 
 from juju import tag
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
 
 log = logging.getLogger("websocket")
 
@@ -52,7 +53,7 @@ class Connection:
         self.addr = None
         self.ws = None
         self.facades = {}
-        self.messages = {}
+        self.messages = IdQueue(loop=self.loop)
 
     @property
     def is_open(self):
@@ -75,6 +76,7 @@ class Connection:
         kw['loop'] = self.loop
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
+        self.loop.create_task(self.receiver())
         log.info("Driver connected to juju %s", url)
         return self
 
@@ -82,20 +84,24 @@ class Connection:
         await self.ws.close()
 
     async def recv(self, request_id):
-        while not self.messages.get(request_id):
-            await asyncio.sleep(0)
-
-        result = self.messages[request_id]
-
-        del self.messages[request_id]
-        return result
+        if not self.is_open:
+            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+        return await self.messages.get(request_id)
 
     async def receiver(self):
         while self.is_open:
-            result = await self.ws.recv()
-            if result is not None:
-                result = json.loads(result)
-                self.messages[result['request-id']] = result
+            try:
+                result = await self.ws.recv()
+                if result is not None:
+                    result = json.loads(result)
+                    await self.messages.put(result['request-id'], result)
+            except Exception as e:
+                await self.messages.put_all(e)
+                if isinstance(e, websockets.ConnectionClosed):
+                    # ConnectionClosed is not really exceptional for us,
+                    # but it may be for any pending message listeners
+                    return
+                raise
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -115,7 +121,7 @@ class Connection:
             # API Error Response
             raise JujuAPIError(result)
 
-        if not 'response' in result:
+        if 'response' not in result:
             # This may never happen
             return result
 
@@ -221,7 +227,6 @@ class Connection:
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
                      loop)
         await client.open()
-        self.loop.create_task(self.receiver)
 
         redirect_info = await client.redirect_info()
         if not redirect_info: