Merge pull request #76 from juju/bug/match-websocket-responses
authorCory Johns <johnsca@gmail.com>
Tue, 7 Mar 2017 23:00:49 +0000 (17:00 -0600)
committerGitHub <noreply@github.com>
Tue, 7 Mar 2017 23:00:49 +0000 (17:00 -0600)
Fix out-of-order websocket responses

juju/client/connection.py
juju/utils.py
tests/base.py
tests/unit/__init__.py [new file with mode: 0644]
tests/unit/test_connection.py [new file with mode: 0644]

index b508a1a..b9eb3bc 100644 (file)
@@ -16,6 +16,7 @@ import yaml
 
 from juju import tag
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
 
 log = logging.getLogger("websocket")
 
@@ -52,6 +53,7 @@ class Connection:
         self.addr = None
         self.ws = None
         self.facades = {}
+        self.messages = IdQueue(loop=self.loop)
 
     @property
     def is_open(self):
@@ -74,17 +76,30 @@ class Connection:
         kw['loop'] = self.loop
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
+        self.loop.create_task(self.receiver())
         log.info("Driver connected to juju %s", url)
         return self
 
     async def close(self):
         await self.ws.close()
 
-    async def recv(self):
-        result = await self.ws.recv()
-        if result is not None:
-            result = json.loads(result)
-        return result
+    async def recv(self, request_id):
+        if not self.is_open:
+            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+        return await self.messages.get(request_id)
+
+    async def receiver(self):
+        while self.is_open:
+            try:
+                result = await self.ws.recv()
+                if result is not None:
+                    result = json.loads(result)
+                    await self.messages.put(result['request-id'], result)
+            except Exception as e:
+                await self.messages.put_all(e)
+                raise
+        await self.messages.put_all(websockets.exceptions.ConnectionClosed(
+            0, 'websocket closed'))
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -95,7 +110,7 @@ class Connection:
             msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
         await self.ws.send(outgoing)
-        result = await self.recv()
+        result = await self.recv(msg['request-id'])
 
         if not result:
             return result
@@ -104,7 +119,7 @@ class Connection:
             # API Error Response
             raise JujuAPIError(result)
 
-        if not 'response' in result:
+        if 'response' not in result:
             # This may never happen
             return result
 
index c0a500c..f4db66e 100644 (file)
@@ -1,5 +1,7 @@
 import asyncio
 import os
+from collections import defaultdict
+from functools import partial
 from pathlib import Path
 
 
@@ -45,3 +47,25 @@ async def read_ssh_key(loop):
 
     '''
     return await loop.run_in_executor(None, _read_ssh_key)
+
+
+class IdQueue:
+    """
+    Wrapper around asyncio.Queue that maintains a separate queue for each ID.
+    """
+    def __init__(self, maxsize=0, *, loop=None):
+        self._queues = defaultdict(partial(asyncio.Queue, maxsize, loop=loop))
+
+    async def get(self, id):
+        value = await self._queues[id].get()
+        del self._queues[id]
+        if isinstance(value, Exception):
+            raise value
+        return value
+
+    async def put(self, id, value):
+        await self._queues[id].put(value)
+
+    async def put_all(self, value):
+        for queue in self._queues.values():
+            await queue.put(value)
index af386ea..292d04a 100644 (file)
@@ -44,3 +44,8 @@ class CleanModel():
         await self.model.disconnect()
         await self.controller.destroy_model(self.model.info.uuid)
         await self.controller.disconnect()
+
+
+class AsyncMock(mock.MagicMock):
+    async def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs)
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py
new file mode 100644 (file)
index 0000000..5371fdb
--- /dev/null
@@ -0,0 +1,49 @@
+import json
+import mock
+import pytest
+from collections import deque
+
+from websockets.exceptions import ConnectionClosed
+
+from .. import base
+from juju.client.connection import Connection
+
+
+class WebsocketMock:
+    def __init__(self, responses):
+        super().__init__()
+        self.responses = deque(responses)
+        self.open = True
+
+    async def send(self, message):
+        pass
+
+    async def recv(self):
+        if not self.responses:
+            raise ConnectionClosed(0, 'no reason')
+        return json.dumps(self.responses.popleft())
+
+    async def close(self):
+        self.open = False
+
+
+@pytest.mark.asyncio
+async def test_out_of_order(event_loop):
+    con = Connection(*[None]*4)
+    ws = WebsocketMock([
+        {'request-id': 1},
+        {'request-id': 3},
+        {'request-id': 2},
+    ])
+    expected_responses = [
+        {'request-id': 1},
+        {'request-id': 2},
+        {'request-id': 3},
+    ]
+    con._get_sll = mock.MagicMock()
+    with mock.patch('websockets.connect', base.AsyncMock(return_value=ws)):
+        await con.open()
+    actual_responses = []
+    for i in range(3):
+        actual_responses.append(await con.rpc({'version': 1}))
+    assert actual_responses == expected_responses