import asyncio
-import pytest
+import http
+import logging
+import socket
+import ssl
+from contextlib import closing
+from pathlib import Path
-from juju.client.connection import Connection
from juju.client import client
-from .. import base
+from juju.client.connection import Connection
+from juju.controller import Controller
+from juju.utils import run_with_interrupt
+import pytest
+import websockets
-@base.bootstrapped
-@pytest.mark.asyncio
-async def test_connect_current(event_loop):
- async with base.CleanModel():
- conn = await Connection.connect_current()
+from .. import base
- assert isinstance(conn, Connection)
- await conn.close()
+
+logger = logging.getLogger(__name__)
@base.bootstrapped
@pytest.mark.asyncio
async def test_monitor(event_loop):
-
- async with base.CleanModel():
- conn = await Connection.connect_current()
-
+ async with base.CleanModel() as model:
+ conn = model.connection()
assert conn.monitor.status == 'connected'
await conn.close()
@pytest.mark.asyncio
async def test_monitor_catches_error(event_loop):
- async with base.CleanModel():
- conn = await Connection.connect_current()
+ async with base.CleanModel() as model:
+ conn = model.connection()
assert conn.monitor.status == 'connected'
- await conn.ws.close()
-
- assert conn.monitor.status == 'error'
-
- await conn.close()
+ try:
+ async with conn.monitor.reconnecting:
+ await conn.ws.close()
+ await asyncio.sleep(1)
+ assert conn.monitor.status == 'error'
+ finally:
+ await conn.close()
@base.bootstrapped
channel='stable',
)
- c = client.ClientFacade.from_connection(model.connection)
+ c = client.ClientFacade.from_connection(model.connection())
await c.FullStatus(None)
@pytest.mark.asyncio
async def test_reconnect(event_loop):
async with base.CleanModel() as model:
- conn = await Connection.connect(
- model.connection.endpoint,
- model.connection.uuid,
- model.connection.username,
- model.connection.password,
- model.connection.cacert,
- model.connection.macaroons,
- model.connection.loop,
- model.connection.max_frame_size)
+ kwargs = model.connection().connect_params()
+ conn = await Connection.connect(**kwargs)
try:
await asyncio.sleep(0.1)
assert conn.is_open
await model.block_until(lambda: conn.is_open, timeout=3)
finally:
await conn.close()
+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_redirect(event_loop):
+ controller = Controller()
+ await controller.connect()
+ kwargs = controller.connection().connect_params()
+ await controller.disconnect()
+
+ # websockets.server.logger.setLevel(logging.DEBUG)
+ # websockets.client.logger.setLevel(logging.DEBUG)
+ # # websockets.protocol.logger.setLevel(logging.DEBUG)
+ # logger.setLevel(logging.DEBUG)
+
+ destination = 'wss://{}/api'.format(kwargs['endpoint'])
+ redirect_statuses = [
+ http.HTTPStatus.MOVED_PERMANENTLY,
+ http.HTTPStatus.FOUND,
+ http.HTTPStatus.SEE_OTHER,
+ http.HTTPStatus.TEMPORARY_REDIRECT,
+ http.HTTPStatus.PERMANENT_REDIRECT,
+ ]
+ test_server_cert = Path(__file__).with_name('cert.pem')
+ kwargs['cacert'] += '\n' + test_server_cert.read_text()
+ server = RedirectServer(destination, event_loop)
+ try:
+ for status in redirect_statuses:
+ logger.debug('test: starting {}'.format(status))
+ server.start(status)
+ await run_with_interrupt(server.running.wait(),
+ server.terminated)
+ if server.exception:
+ raise server.exception
+ assert not server.terminated.is_set()
+ logger.debug('test: started')
+ kwargs_copy = dict(kwargs,
+ endpoint='localhost:{}'.format(server.port))
+ logger.debug('test: connecting')
+ conn = await Connection.connect(**kwargs_copy)
+ logger.debug('test: connected')
+ await conn.close()
+ logger.debug('test: stopping')
+ server.stop()
+ await server.stopped.wait()
+ logger.debug('test: stopped')
+ finally:
+ server.terminate()
+ await server.terminated.wait()
+
+
+class RedirectServer:
+ def __init__(self, destination, loop):
+ self.destination = destination
+ self.loop = loop
+ self._start = asyncio.Event()
+ self._stop = asyncio.Event()
+ self._terminate = asyncio.Event()
+ self.running = asyncio.Event()
+ self.stopped = asyncio.Event()
+ self.terminated = asyncio.Event()
+ if hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
+ # python 3.6+
+ protocol = ssl.PROTOCOL_TLS_SERVER
+ elif hasattr(ssl, 'PROTOCOL_TLS'):
+ # python 3.5.3+
+ protocol = ssl.PROTOCOL_TLS
+ else:
+ # python 3.5.2
+ protocol = ssl.PROTOCOL_TLSv1_2
+ self.ssl_context = ssl.SSLContext(protocol)
+ crt_file = Path(__file__).with_name('cert.pem')
+ key_file = Path(__file__).with_name('key.pem')
+ self.ssl_context.load_cert_chain(str(crt_file), str(key_file))
+ self.status = None
+ self.port = None
+ self._task = self.loop.create_task(self.run())
+
+ def start(self, status):
+ self.status = status
+ self.port = self._find_free_port()
+ self._start.set()
+
+ def stop(self):
+ self._stop.set()
+
+ def terminate(self):
+ self._terminate.set()
+ self.stop()
+
+ @property
+ def exception(self):
+ try:
+ return self._task.exception()
+ except (asyncio.CancelledError, asyncio.InvalidStateError):
+ return None
+
+ async def run(self):
+ logger.debug('server: active')
+
+ async def hello(websocket, path):
+ await websocket.send('hello')
+
+ async def redirect(path, request_headers):
+ return self.status, {'Location': self.destination}, b""
+
+ try:
+ while not self._terminate.is_set():
+ await run_with_interrupt(self._start.wait(),
+ self._terminate,
+ loop=self.loop)
+ if self._terminate.is_set():
+ break
+ self._start.clear()
+ logger.debug('server: starting {}'.format(self.status))
+ try:
+ async with websockets.serve(ws_handler=hello,
+ process_request=redirect,
+ host='localhost',
+ port=self.port,
+ ssl=self.ssl_context,
+ loop=self.loop):
+ self.stopped.clear()
+ self.running.set()
+ logger.debug('server: started')
+ while not self._stop.is_set():
+ await run_with_interrupt(
+ asyncio.sleep(1, loop=self.loop),
+ self._stop,
+ loop=self.loop)
+ logger.debug('server: tick')
+ logger.debug('server: stopping')
+ except asyncio.CancelledError:
+ break
+ finally:
+ self.stopped.set()
+ self._stop.clear()
+ self.running.clear()
+ logger.debug('server: stopped')
+ logger.debug('server: terminating')
+ except asyncio.CancelledError:
+ pass
+ finally:
+ self._start.clear()
+ self._stop.clear()
+ self._terminate.clear()
+ self.stopped.set()
+ self.running.clear()
+ self.terminated.set()
+ logger.debug('server: terminated')
+
+ def _find_free_port(self):
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ s.bind(('', 0))
+ return s.getsockname()[1]