+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_redirect(event_loop):
+ controller = Controller()
+ await controller.connect()
+ kwargs = controller.connection().connect_params()
+ await controller.disconnect()
+
+ # websockets.server.logger.setLevel(logging.DEBUG)
+ # websockets.client.logger.setLevel(logging.DEBUG)
+ # # websockets.protocol.logger.setLevel(logging.DEBUG)
+ # logger.setLevel(logging.DEBUG)
+
+ destination = 'wss://{}/api'.format(kwargs['endpoint'])
+ redirect_statuses = [
+ http.HTTPStatus.MOVED_PERMANENTLY,
+ http.HTTPStatus.FOUND,
+ http.HTTPStatus.SEE_OTHER,
+ http.HTTPStatus.TEMPORARY_REDIRECT,
+ http.HTTPStatus.PERMANENT_REDIRECT,
+ ]
+ test_server_cert = Path(__file__).with_name('cert.pem')
+ kwargs['cacert'] += '\n' + test_server_cert.read_text()
+ server = RedirectServer(destination, event_loop)
+ try:
+ for status in redirect_statuses:
+ logger.debug('test: starting {}'.format(status))
+ server.start(status)
+ await run_with_interrupt(server.running.wait(),
+ server.terminated)
+ if server.exception:
+ raise server.exception
+ assert not server.terminated.is_set()
+ logger.debug('test: started')
+ kwargs_copy = dict(kwargs,
+ endpoint='localhost:{}'.format(server.port))
+ logger.debug('test: connecting')
+ conn = await Connection.connect(**kwargs_copy)
+ logger.debug('test: connected')
+ await conn.close()
+ logger.debug('test: stopping')
+ server.stop()
+ await server.stopped.wait()
+ logger.debug('test: stopped')
+ finally:
+ server.terminate()
+ await server.terminated.wait()
+
+
+class RedirectServer:
+ def __init__(self, destination, loop):
+ self.destination = destination
+ self.loop = loop
+ self._start = asyncio.Event()
+ self._stop = asyncio.Event()
+ self._terminate = asyncio.Event()
+ self.running = asyncio.Event()
+ self.stopped = asyncio.Event()
+ self.terminated = asyncio.Event()
+ if hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
+ # python 3.6+
+ protocol = ssl.PROTOCOL_TLS_SERVER
+ elif hasattr(ssl, 'PROTOCOL_TLS'):
+ # python 3.5.3+
+ protocol = ssl.PROTOCOL_TLS
+ else:
+ # python 3.5.2
+ protocol = ssl.PROTOCOL_TLSv1_2
+ self.ssl_context = ssl.SSLContext(protocol)
+ crt_file = Path(__file__).with_name('cert.pem')
+ key_file = Path(__file__).with_name('key.pem')
+ self.ssl_context.load_cert_chain(str(crt_file), str(key_file))
+ self.status = None
+ self.port = None
+ self._task = self.loop.create_task(self.run())
+
+ def start(self, status):
+ self.status = status
+ self.port = self._find_free_port()
+ self._start.set()
+
+ def stop(self):
+ self._stop.set()
+
+ def terminate(self):
+ self._terminate.set()
+ self.stop()
+
+ @property
+ def exception(self):
+ try:
+ return self._task.exception()
+ except (asyncio.CancelledError, asyncio.InvalidStateError):
+ return None
+
+ async def run(self):
+ logger.debug('server: active')
+
+ async def hello(websocket, path):
+ await websocket.send('hello')
+
+ async def redirect(path, request_headers):
+ return self.status, {'Location': self.destination}, b""
+
+ try:
+ while not self._terminate.is_set():
+ await run_with_interrupt(self._start.wait(),
+ self._terminate,
+ loop=self.loop)
+ if self._terminate.is_set():
+ break
+ self._start.clear()
+ logger.debug('server: starting {}'.format(self.status))
+ try:
+ async with websockets.serve(ws_handler=hello,
+ process_request=redirect,
+ host='localhost',
+ port=self.port,
+ ssl=self.ssl_context,
+ loop=self.loop):
+ self.stopped.clear()
+ self.running.set()
+ logger.debug('server: started')
+ while not self._stop.is_set():
+ await run_with_interrupt(
+ asyncio.sleep(1, loop=self.loop),
+ self._stop,
+ loop=self.loop)
+ logger.debug('server: tick')
+ logger.debug('server: stopping')
+ except asyncio.CancelledError:
+ break
+ finally:
+ self.stopped.set()
+ self._stop.clear()
+ self.running.clear()
+ logger.debug('server: stopped')
+ logger.debug('server: terminating')
+ except asyncio.CancelledError:
+ pass
+ finally:
+ self._start.clear()
+ self._stop.clear()
+ self._terminate.clear()
+ self.stopped.set()
+ self.running.clear()
+ self.terminated.set()
+ logger.debug('server: terminated')
+
+ def _find_free_port(self):
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ s.bind(('', 0))
+ return s.getsockname()[1]