from juju import tag
from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
log = logging.getLogger("websocket")
self.addr = None
self.ws = None
self.facades = {}
+ self.messages = IdQueue(loop=self.loop)
@property
def is_open(self):
kw['loop'] = self.loop
self.addr = url
self.ws = await websockets.connect(url, **kw)
+ self.loop.create_task(self.receiver())
log.info("Driver connected to juju %s", url)
return self
async def close(self):
await self.ws.close()
- async def recv(self):
- result = await self.ws.recv()
- if result is not None:
- result = json.loads(result)
- return result
+ async def recv(self, request_id):
+ if not self.is_open:
+ raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+ return await self.messages.get(request_id)
+
+ async def receiver(self):
+ while self.is_open:
+ try:
+ result = await self.ws.recv()
+ if result is not None:
+ result = json.loads(result)
+ await self.messages.put(result['request-id'], result)
+ except Exception as e:
+ await self.messages.put_all(e)
+ raise
+ await self.messages.put_all(websockets.exceptions.ConnectionClosed(
+ 0, 'websocket closed'))
async def rpc(self, msg, encoder=None):
self.__request_id__ += 1
msg['version'] = self.facades[msg['type']]
outgoing = json.dumps(msg, indent=2, cls=encoder)
await self.ws.send(outgoing)
- result = await self.recv()
+ result = await self.recv(msg['request-id'])
if not result:
return result
# API Error Response
raise JujuAPIError(result)
- if not 'response' in result:
+ if 'response' not in result:
# This may never happen
return result
import asyncio
import os
+from collections import defaultdict
+from functools import partial
from pathlib import Path
'''
return await loop.run_in_executor(None, _read_ssh_key)
+
+
+class IdQueue:
+ """
+ Wrapper around asyncio.Queue that maintains a separate queue for each ID.
+ """
+ def __init__(self, maxsize=0, *, loop=None):
+ self._queues = defaultdict(partial(asyncio.Queue, maxsize, loop=loop))
+
+ async def get(self, id):
+ value = await self._queues[id].get()
+ del self._queues[id]
+ if isinstance(value, Exception):
+ raise value
+ return value
+
+ async def put(self, id, value):
+ await self._queues[id].put(value)
+
+ async def put_all(self, value):
+ for queue in self._queues.values():
+ await queue.put(value)
await self.model.disconnect()
await self.controller.destroy_model(self.model.info.uuid)
await self.controller.disconnect()
+
+
+class AsyncMock(mock.MagicMock):
+ async def __call__(self, *args, **kwargs):
+ return super().__call__(*args, **kwargs)
--- /dev/null
+import json
+import mock
+import pytest
+from collections import deque
+
+from websockets.exceptions import ConnectionClosed
+
+from .. import base
+from juju.client.connection import Connection
+
+
+class WebsocketMock:
+ def __init__(self, responses):
+ super().__init__()
+ self.responses = deque(responses)
+ self.open = True
+
+ async def send(self, message):
+ pass
+
+ async def recv(self):
+ if not self.responses:
+ raise ConnectionClosed(0, 'no reason')
+ return json.dumps(self.responses.popleft())
+
+ async def close(self):
+ self.open = False
+
+
+@pytest.mark.asyncio
+async def test_out_of_order(event_loop):
+ con = Connection(*[None]*4)
+ ws = WebsocketMock([
+ {'request-id': 1},
+ {'request-id': 3},
+ {'request-id': 2},
+ ])
+ expected_responses = [
+ {'request-id': 1},
+ {'request-id': 2},
+ {'request-id': 3},
+ ]
+ con._get_sll = mock.MagicMock()
+ with mock.patch('websockets.connect', base.AsyncMock(return_value=ws)):
+ await con.open()
+ actual_responses = []
+ for i in range(3):
+ actual_responses.append(await con.rpc({'version': 1}))
+ assert actual_responses == expected_responses