From f3e0df690919bb6413aee809bd9d6d295daa7cc8 Mon Sep 17 00:00:00 2001 From: Cory Johns Date: Tue, 7 Mar 2017 15:29:26 -0600 Subject: [PATCH] Refactor to use IdQueue pattern --- juju/client/connection.py | 15 +++++---------- juju/utils.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index 486a57f..625c609 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -16,6 +16,7 @@ import yaml from juju import tag from juju.errors import JujuError, JujuAPIError, JujuConnectionError +from juju.utils import IdQueue log = logging.getLogger("websocket") @@ -52,7 +53,7 @@ class Connection: self.addr = None self.ws = None self.facades = {} - self.messages = {} + self.messages = IdQueue(loop=self.loop) @property def is_open(self): @@ -82,20 +83,14 @@ class Connection: await self.ws.close() async def recv(self, request_id): - while not self.messages.get(request_id): - await asyncio.sleep(0) - - result = self.messages[request_id] - - del self.messages[request_id] - return result + return await self.messages.get(request_id) async def receiver(self): while self.is_open: result = await self.ws.recv() if result is not None: result = json.loads(result) - self.messages[result['request-id']] = result + await self.messages.put(result['request-id'], result) async def rpc(self, msg, encoder=None): self.__request_id__ += 1 @@ -221,7 +216,7 @@ class Connection: client = cls(endpoint, uuid, username, password, cacert, macaroons, loop) await client.open() - self.loop.create_task(self.receiver) + client.loop.create_task(client.receiver) redirect_info = await client.redirect_info() if not redirect_info: diff --git a/juju/utils.py b/juju/utils.py index c0a500c..9f5d63d 100644 --- a/juju/utils.py +++ b/juju/utils.py @@ -1,5 +1,7 @@ import asyncio import os +from collections import defaultdict +from functools import partial from pathlib import Path @@ -45,3 +47,19 @@ async def read_ssh_key(loop): ''' return await loop.run_in_executor(None, _read_ssh_key) + + +class IdQueue: + """ + Wrapper around asyncio.Queue that maintains a separate queue for each ID. + """ + def __init__(self, maxsize=0, *, loop=None): + self._queues = defaultdict(partial(asyncio.Queue, maxsize, loop=loop)) + + async def get(self, id): + value = await self._queues[id].get() + del self._queues[id] + return value + + async def put(self, id, value): + await self._queues[id].put(value) -- 2.17.1