# blob: 6647a03646d56f7c87485e8dfae09c35f427f6f7 [file] [log] [blame]
import asyncio
import http
import logging
import socket
import ssl
from contextlib import closing
from pathlib import Path
from juju.client import client
from juju.client.connection import Connection
from juju.controller import Controller
from juju.utils import run_with_interrupt
import pytest
import websockets
from .. import base
logger = logging.getLogger(__name__)
@base.bootstrapped
@pytest.mark.asyncio
async def test_monitor(event_loop):
    """Monitor status goes from 'connected' to 'disconnected' on close().

    ``event_loop`` is the pytest-asyncio fixture; it is requested so the
    test runs on that loop.
    """
    async with base.CleanModel() as model:
        connection = model.connection()
        assert connection.monitor.status == 'connected'

        await connection.close()
        assert connection.monitor.status == 'disconnected'
@base.bootstrapped
@pytest.mark.asyncio
async def test_monitor_catches_error(event_loop):
    """An unexpected websocket close is reported by the monitor as 'error'.

    The socket is closed underneath the connection while the monitor's
    ``reconnecting`` lock is held. Presumably this keeps an automatic
    reconnect from restoring the state before it is checked. The
    connection is always closed afterwards.
    """
    async with base.CleanModel() as model:
        connection = model.connection()
        assert connection.monitor.status == 'connected'

        try:
            async with connection.monitor.reconnecting:
                await connection.ws.close()
                # Allow time for the drop to be noticed before checking.
                await asyncio.sleep(1)
                assert connection.monitor.status == 'error'
        finally:
            await connection.close()
@base.bootstrapped
@pytest.mark.asyncio
async def test_full_status(event_loop):
    """FullStatus can be called through a ClientFacade after a deploy.

    An ``ubuntu`` application is deployed into a clean model. Then
    ``FullStatus(None)`` is invoked on a facade bound to the model's
    connection. The call only needs to complete without raising.
    """
    deploy_options = dict(
        application_name='ubuntu',
        series='trusty',
        channel='stable',
    )
    async with base.CleanModel() as model:
        await model.deploy('ubuntu-0', **deploy_options)

        facade = client.ClientFacade.from_connection(model.connection())
        await facade.FullStatus(None)
@base.bootstrapped
@pytest.mark.asyncio
async def test_reconnect(event_loop):
    """A connection whose websocket is dropped becomes open again.

    A second connection is opened with the model connection's parameters
    and its websocket is closed from under it. The test then waits, via
    ``model.block_until`` with a 3 second timeout, for ``is_open`` to turn
    true again. The extra connection is always closed at the end.
    """
    async with base.CleanModel() as model:
        params = model.connection().connect_params()
        connection = await Connection.connect(**params)

        def reopened():
            return connection.is_open

        try:
            # Short pause after connecting before the first state check.
            await asyncio.sleep(0.1)
            assert connection.is_open

            await connection.ws.close()
            assert not connection.is_open

            await model.block_until(reopened, timeout=3)
        finally:
            await connection.close()
@base.bootstrapped
@pytest.mark.asyncio
async def test_redirect(event_loop):
    """Connection.connect succeeds through each HTTP redirect status.

    The bootstrapped controller's connection parameters are captured first.
    For every redirect status, a local ``RedirectServer`` is then started.
    It answers the websocket handshake with that status and a ``Location``
    pointing at the real controller API, and connecting to the local
    endpoint must succeed.

    The local server's certificate (``cert.pem`` next to this file) is
    appended to ``cacert``. Presumably this lets TLS to the local endpoint
    verify.
    """
    controller = Controller()
    await controller.connect()
    try:
        kwargs = controller.connection().connect_params()
    finally:
        # Always release the controller connection, even if reading its
        # parameters fails, so no connection is leaked into later tests.
        await controller.disconnect()

    # Debugging hint: raise websockets.server.logger,
    # websockets.client.logger and this module's logger to logging.DEBUG
    # to trace the redirect handshakes.
    destination = 'wss://{}/api'.format(kwargs['endpoint'])
    redirect_statuses = [
        http.HTTPStatus.MOVED_PERMANENTLY,
        http.HTTPStatus.FOUND,
        http.HTTPStatus.SEE_OTHER,
        http.HTTPStatus.TEMPORARY_REDIRECT,
        http.HTTPStatus.PERMANENT_REDIRECT,
    ]
    test_server_cert = Path(__file__).with_name('cert.pem')
    kwargs['cacert'] += '\n' + test_server_cert.read_text()
    server = RedirectServer(destination, event_loop)
    try:
        for status in redirect_statuses:
            logger.debug('test: starting {}'.format(status))
            server.start(status)
            # Wake on either the server coming up or its task ending
            # (e.g. failing to bind), whichever happens first.
            await run_with_interrupt(server.running.wait(),
                                     server.terminated)
            if server.exception:
                raise server.exception
            assert not server.terminated.is_set()
            logger.debug('test: started')
            kwargs_copy = dict(kwargs,
                               endpoint='localhost:{}'.format(server.port))
            logger.debug('test: connecting')
            conn = await Connection.connect(**kwargs_copy)
            logger.debug('test: connected')
            await conn.close()
            logger.debug('test: stopping')
            server.stop()
            await server.stopped.wait()
            logger.debug('test: stopped')
    finally:
        server.terminate()
        await server.terminated.wait()
class RedirectServer:
    """Local TLS websocket server that answers handshakes with a redirect.

    Construction schedules :meth:`run` as a background task on ``loop``.
    The task is driven through events:

    * :meth:`start` picks a free port and brings up a listener. The listener
      replies to every handshake with ``status`` and a ``Location`` header
      pointing at ``destination``.
    * :meth:`stop` tears that listener down so another status can be started.
    * :meth:`terminate` ends the background task.

    Callers observe progress through the public ``running``, ``stopped``
    and ``terminated`` events, and any failure through :attr:`exception`.

    The TLS certificate and key are read from ``cert.pem`` and ``key.pem``
    in the directory of this file.
    """

    def __init__(self, destination, loop):
        self.destination = destination
        self.loop = loop
        # Control events: set by start()/stop()/terminate(), consumed by run().
        # NOTE(review): on Python < 3.10 asyncio.Event() binds to the current
        # event loop, so this assumes ``loop`` is the current loop.
        self._start = asyncio.Event()
        self._stop = asyncio.Event()
        self._terminate = asyncio.Event()
        # State events that callers wait on.
        self.running = asyncio.Event()
        self.stopped = asyncio.Event()
        self.terminated = asyncio.Event()
        if hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
            # python 3.6+
            protocol = ssl.PROTOCOL_TLS_SERVER
        elif hasattr(ssl, 'PROTOCOL_TLS'):
            # python 3.5.3+
            protocol = ssl.PROTOCOL_TLS
        else:
            # python 3.5.2
            protocol = ssl.PROTOCOL_TLSv1_2
        self.ssl_context = ssl.SSLContext(protocol)
        crt_file = Path(__file__).with_name('cert.pem')
        key_file = Path(__file__).with_name('key.pem')
        self.ssl_context.load_cert_chain(str(crt_file), str(key_file))
        # Configured by start() for each listener.
        self.status = None
        self.port = None
        self._task = self.loop.create_task(self.run())

    def start(self, status):
        """Request a listener that redirects with HTTP ``status``.

        A free port is chosen now and stored in ``self.port``. The
        background task then picks up the request.
        """
        self.status = status
        self.port = self._find_free_port()
        self._start.set()

    def stop(self):
        """Request that the current listener be shut down."""
        self._stop.set()

    def terminate(self):
        """Request that the background task finish (stops any listener too)."""
        self._terminate.set()
        self.stop()

    @property
    def exception(self):
        """Exception raised by the background task, if any.

        Returns ``None`` in three cases: the task is still running, the
        task was cancelled, or the task finished without raising.
        """
        try:
            return self._task.exception()
        except (asyncio.CancelledError, asyncio.InvalidStateError):
            return None

    async def run(self):
        """Serve redirect listeners on request until terminate() is called.

        On exit, all control events are cleared. ``stopped`` and
        ``terminated`` are set and ``running`` is cleared, whether the task
        ends normally, is cancelled, or raises.
        """
        logger.debug('server: active')

        async def hello(websocket, path):
            # Only reached for a handshake that is not redirected.
            await websocket.send('hello')

        async def redirect(path, request_headers):
            # A (status, headers, body) result from process_request is sent
            # as the HTTP response instead of completing the handshake.
            return self.status, {'Location': self.destination}, b""

        try:
            while not self._terminate.is_set():
                await run_with_interrupt(self._start.wait(),
                                         self._terminate,
                                         loop=self.loop)
                if self._terminate.is_set():
                    break
                self._start.clear()
                logger.debug('server: starting {}'.format(self.status))
                try:
                    async with websockets.serve(ws_handler=hello,
                                                process_request=redirect,
                                                host='localhost',
                                                port=self.port,
                                                ssl=self.ssl_context,
                                                loop=self.loop):
                        self.stopped.clear()
                        self.running.set()
                        logger.debug('server: started')
                        while not self._stop.is_set():
                            # asyncio.sleep() no longer accepts ``loop``
                            # (removed in Python 3.10). This coroutine already
                            # runs on self.loop, so the argument is redundant.
                            await run_with_interrupt(
                                asyncio.sleep(1),
                                self._stop,
                                loop=self.loop)
                            logger.debug('server: tick')
                        logger.debug('server: stopping')
                except asyncio.CancelledError:
                    break
                finally:
                    self.stopped.set()
                    self._stop.clear()
                    self.running.clear()
                    logger.debug('server: stopped')
            logger.debug('server: terminating')
        except asyncio.CancelledError:
            pass
        finally:
            self._start.clear()
            self._stop.clear()
            self._terminate.clear()
            self.stopped.set()
            self.running.clear()
            self.terminated.set()
            logger.debug('server: terminated')

    def _find_free_port(self):
        """Return a TCP port the OS reports as currently free.

        The probe socket is closed before the port is used, so another
        process could grab the port in between. That race is inherent to
        this approach.
        """
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(('', 0))
            return s.getsockname()[1]