Refactor to use IdQueue pattern
authorCory Johns <johnsca@gmail.com>
Tue, 7 Mar 2017 21:29:26 +0000 (15:29 -0600)
committerCory Johns <johnsca@gmail.com>
Tue, 7 Mar 2017 22:06:57 +0000 (16:06 -0600)
juju/client/connection.py
juju/utils.py

index 486a57f..625c609 100644 (file)
@@ -16,6 +16,7 @@ import yaml
 
 from juju import tag
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
 
 log = logging.getLogger("websocket")
 
@@ -52,7 +53,7 @@ class Connection:
         self.addr = None
         self.ws = None
         self.facades = {}
-        self.messages = {}
+        self.messages = IdQueue(loop=self.loop)
 
     @property
     def is_open(self):
@@ -82,20 +83,14 @@ class Connection:
         await self.ws.close()
 
     async def recv(self, request_id):
-        while not self.messages.get(request_id):
-            await asyncio.sleep(0)
-
-        result = self.messages[request_id]
-
-        del self.messages[request_id]
-        return result
+        return await self.messages.get(request_id)
 
     async def receiver(self):
         while self.is_open:
             result = await self.ws.recv()
             if result is not None:
                 result = json.loads(result)
-                self.messages[result['request-id']] = result
+                await self.messages.put(result['request-id'], result)
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -221,7 +216,7 @@ class Connection:
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
                      loop)
         await client.open()
-        self.loop.create_task(self.receiver)
+        client.loop.create_task(client.receiver)
 
         redirect_info = await client.redirect_info()
         if not redirect_info:
index c0a500c..9f5d63d 100644 (file)
@@ -1,5 +1,7 @@
 import asyncio
 import os
+from collections import defaultdict
+from functools import partial
 from pathlib import Path
 
 
@@ -45,3 +47,19 @@ async def read_ssh_key(loop):
 
     '''
     return await loop.run_in_executor(None, _read_ssh_key)
+
+
+class IdQueue:
+    """
+    Wrapper around asyncio.Queue that maintains a separate queue for each ID.
+    """
+    def __init__(self, maxsize=0, *, loop=None):
+        self._queues = defaultdict(partial(asyncio.Queue, maxsize, loop=loop))
+
+    async def get(self, id):
+        value = await self._queues[id].get()
+        del self._queues[id]
+        return value
+
+    async def put(self, id, value):
+        await self._queues[id].put(value)