Revert "Remove vendored libjuju"
[osm/N2VC.git] / modules / libjuju / tests / integration / test_connection.py
index 79ad9d0..6647a03 100644 (file)
@@ -1,17 +1,28 @@
 import asyncio
+import http
+import logging
+import socket
+import ssl
+from contextlib import closing
+from pathlib import Path
 
 from juju.client import client
 from juju.client.connection import Connection
+from juju.controller import Controller
+from juju.utils import run_with_interrupt
 
 import pytest
+import websockets
 
 from .. import base
 
 
+logger = logging.getLogger(__name__)
+
+
 @base.bootstrapped
 @pytest.mark.asyncio
 async def test_monitor(event_loop):
-
     async with base.CleanModel() as model:
         conn = model.connection()
         assert conn.monitor.status == 'connected'
@@ -67,3 +78,159 @@ async def test_reconnect(event_loop):
             await model.block_until(lambda: conn.is_open, timeout=3)
         finally:
             await conn.close()
+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_redirect(event_loop):
+    controller = Controller()
+    await controller.connect()
+    kwargs = controller.connection().connect_params()
+    await controller.disconnect()
+
+    # websockets.server.logger.setLevel(logging.DEBUG)
+    # websockets.client.logger.setLevel(logging.DEBUG)
+    # # websockets.protocol.logger.setLevel(logging.DEBUG)
+    # logger.setLevel(logging.DEBUG)
+
+    destination = 'wss://{}/api'.format(kwargs['endpoint'])
+    redirect_statuses = [
+        http.HTTPStatus.MOVED_PERMANENTLY,
+        http.HTTPStatus.FOUND,
+        http.HTTPStatus.SEE_OTHER,
+        http.HTTPStatus.TEMPORARY_REDIRECT,
+        http.HTTPStatus.PERMANENT_REDIRECT,
+    ]
+    test_server_cert = Path(__file__).with_name('cert.pem')
+    kwargs['cacert'] += '\n' + test_server_cert.read_text()
+    server = RedirectServer(destination, event_loop)
+    try:
+        for status in redirect_statuses:
+            logger.debug('test: starting {}'.format(status))
+            server.start(status)
+            await run_with_interrupt(server.running.wait(),
+                                     server.terminated)
+            if server.exception:
+                raise server.exception
+            assert not server.terminated.is_set()
+            logger.debug('test: started')
+            kwargs_copy = dict(kwargs,
+                               endpoint='localhost:{}'.format(server.port))
+            logger.debug('test: connecting')
+            conn = await Connection.connect(**kwargs_copy)
+            logger.debug('test: connected')
+            await conn.close()
+            logger.debug('test: stopping')
+            server.stop()
+            await server.stopped.wait()
+            logger.debug('test: stopped')
+    finally:
+        server.terminate()
+        await server.terminated.wait()
+
+
+class RedirectServer:
+    def __init__(self, destination, loop):
+        self.destination = destination
+        self.loop = loop
+        self._start = asyncio.Event()
+        self._stop = asyncio.Event()
+        self._terminate = asyncio.Event()
+        self.running = asyncio.Event()
+        self.stopped = asyncio.Event()
+        self.terminated = asyncio.Event()
+        if hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
+            # python 3.6+
+            protocol = ssl.PROTOCOL_TLS_SERVER
+        elif hasattr(ssl, 'PROTOCOL_TLS'):
+            # python 3.5.3+
+            protocol = ssl.PROTOCOL_TLS
+        else:
+            # python 3.5.2
+            protocol = ssl.PROTOCOL_TLSv1_2
+        self.ssl_context = ssl.SSLContext(protocol)
+        crt_file = Path(__file__).with_name('cert.pem')
+        key_file = Path(__file__).with_name('key.pem')
+        self.ssl_context.load_cert_chain(str(crt_file), str(key_file))
+        self.status = None
+        self.port = None
+        self._task = self.loop.create_task(self.run())
+
+    def start(self, status):
+        self.status = status
+        self.port = self._find_free_port()
+        self._start.set()
+
+    def stop(self):
+        self._stop.set()
+
+    def terminate(self):
+        self._terminate.set()
+        self.stop()
+
+    @property
+    def exception(self):
+        try:
+            return self._task.exception()
+        except (asyncio.CancelledError, asyncio.InvalidStateError):
+            return None
+
+    async def run(self):
+        logger.debug('server: active')
+
+        async def hello(websocket, path):
+            await websocket.send('hello')
+
+        async def redirect(path, request_headers):
+            return self.status, {'Location': self.destination}, b""
+
+        try:
+            while not self._terminate.is_set():
+                await run_with_interrupt(self._start.wait(),
+                                         self._terminate,
+                                         loop=self.loop)
+                if self._terminate.is_set():
+                    break
+                self._start.clear()
+                logger.debug('server: starting {}'.format(self.status))
+                try:
+                    async with websockets.serve(ws_handler=hello,
+                                                process_request=redirect,
+                                                host='localhost',
+                                                port=self.port,
+                                                ssl=self.ssl_context,
+                                                loop=self.loop):
+                        self.stopped.clear()
+                        self.running.set()
+                        logger.debug('server: started')
+                        while not self._stop.is_set():
+                            await run_with_interrupt(
+                                asyncio.sleep(1, loop=self.loop),
+                                self._stop,
+                                loop=self.loop)
+                            logger.debug('server: tick')
+                        logger.debug('server: stopping')
+                except asyncio.CancelledError:
+                    break
+                finally:
+                    self.stopped.set()
+                    self._stop.clear()
+                    self.running.clear()
+                    logger.debug('server: stopped')
+            logger.debug('server: terminating')
+        except asyncio.CancelledError:
+            pass
+        finally:
+            self._start.clear()
+            self._stop.clear()
+            self._terminate.clear()
+            self.stopped.set()
+            self.running.clear()
+            self.terminated.set()
+            logger.debug('server: terminated')
+
+    def _find_free_port(self):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            s.bind(('', 0))
+            return s.getsockname()[1]