from juju import tag
from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
log = logging.getLogger("websocket")
self.addr = None
self.ws = None
self.facades = {}
- self.messages = {}
+ self.messages = IdQueue(loop=self.loop)
@property
def is_open(self):
await self.ws.close()
async def recv(self, request_id):
- while not self.messages.get(request_id):
- await asyncio.sleep(0)
-
- result = self.messages[request_id]
-
- del self.messages[request_id]
- return result
+ return await self.messages.get(request_id)
async def receiver(self):
while self.is_open:
result = await self.ws.recv()
if result is not None:
result = json.loads(result)
- self.messages[result['request-id']] = result
+ await self.messages.put(result['request-id'], result)
async def rpc(self, msg, encoder=None):
self.__request_id__ += 1
client = cls(endpoint, uuid, username, password, cacert, macaroons,
loop)
await client.open()
- self.loop.create_task(self.receiver)
+ client.loop.create_task(client.receiver)
redirect_info = await client.redirect_info()
if not redirect_info:
import asyncio
import os
+from collections import defaultdict
+from functools import partial
from pathlib import Path
'''
return await loop.run_in_executor(None, _read_ssh_key)
+
+
+class IdQueue:
+ """
+ Wrapper around asyncio.Queue that maintains a separate queue for each ID.
+ """
+ def __init__(self, maxsize=0, *, loop=None):
+ self._queues = defaultdict(partial(asyncio.Queue, maxsize, loop=loop))
+
+ async def get(self, id):
+ value = await self._queues[id].get()
+ del self._queues[id]
+ return value
+
+ async def put(self, id, value):
+ await self._queues[id].put(value)