Squashed 'modules/libjuju/' changes from c50c361..c127833
[osm/N2VC.git] / tests / integration / test_connection.py
index 290203d..6647a03 100644 (file)
@@ -1,28 +1,30 @@
 import asyncio
-import pytest
+import http
+import logging
+import socket
+import ssl
+from contextlib import closing
+from pathlib import Path
 
-from juju.client.connection import Connection
 from juju.client import client
-from .. import base
+from juju.client.connection import Connection
+from juju.controller import Controller
+from juju.utils import run_with_interrupt
 
+import pytest
+import websockets
 
-@base.bootstrapped
-@pytest.mark.asyncio
-async def test_connect_current(event_loop):
-    async with base.CleanModel():
-        conn = await Connection.connect_current()
+from .. import base
 
-        assert isinstance(conn, Connection)
-        await conn.close()
+
+logger = logging.getLogger(__name__)
 
 
 @base.bootstrapped
 @pytest.mark.asyncio
 async def test_monitor(event_loop):
-
-    async with base.CleanModel():
-        conn = await Connection.connect_current()
-
+    async with base.CleanModel() as model:
+        conn = model.connection()
         assert conn.monitor.status == 'connected'
         await conn.close()
 
@@ -33,15 +35,17 @@ async def test_monitor(event_loop):
 @pytest.mark.asyncio
 async def test_monitor_catches_error(event_loop):
 
-    async with base.CleanModel():
-        conn = await Connection.connect_current()
+    async with base.CleanModel() as model:
+        conn = model.connection()
 
         assert conn.monitor.status == 'connected'
-        await conn.ws.close()
-
-        assert conn.monitor.status == 'error'
-
-        await conn.close()
+        try:
+            async with conn.monitor.reconnecting:
+                await conn.ws.close()
+                await asyncio.sleep(1)
+                assert conn.monitor.status == 'error'
+        finally:
+            await conn.close()
 
 
 @base.bootstrapped
@@ -55,7 +59,7 @@ async def test_full_status(event_loop):
             channel='stable',
         )
 
-        c = client.ClientFacade.from_connection(model.connection)
+        c = client.ClientFacade.from_connection(model.connection())
 
         await c.FullStatus(None)
 
@@ -64,15 +68,8 @@ async def test_full_status(event_loop):
 @pytest.mark.asyncio
 async def test_reconnect(event_loop):
     async with base.CleanModel() as model:
-        conn = await Connection.connect(
-            model.connection.endpoint,
-            model.connection.uuid,
-            model.connection.username,
-            model.connection.password,
-            model.connection.cacert,
-            model.connection.macaroons,
-            model.connection.loop,
-            model.connection.max_frame_size)
+        kwargs = model.connection().connect_params()
+        conn = await Connection.connect(**kwargs)
         try:
             await asyncio.sleep(0.1)
             assert conn.is_open
@@ -81,3 +78,159 @@ async def test_reconnect(event_loop):
             await model.block_until(lambda: conn.is_open, timeout=3)
         finally:
             await conn.close()
+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_redirect(event_loop):
+    controller = Controller()
+    await controller.connect()
+    kwargs = controller.connection().connect_params()
+    await controller.disconnect()
+
+    # websockets.server.logger.setLevel(logging.DEBUG)
+    # websockets.client.logger.setLevel(logging.DEBUG)
+    # # websockets.protocol.logger.setLevel(logging.DEBUG)
+    # logger.setLevel(logging.DEBUG)
+
+    destination = 'wss://{}/api'.format(kwargs['endpoint'])
+    redirect_statuses = [
+        http.HTTPStatus.MOVED_PERMANENTLY,
+        http.HTTPStatus.FOUND,
+        http.HTTPStatus.SEE_OTHER,
+        http.HTTPStatus.TEMPORARY_REDIRECT,
+        http.HTTPStatus.PERMANENT_REDIRECT,
+    ]
+    test_server_cert = Path(__file__).with_name('cert.pem')
+    kwargs['cacert'] += '\n' + test_server_cert.read_text()
+    server = RedirectServer(destination, event_loop)
+    try:
+        for status in redirect_statuses:
+            logger.debug('test: starting {}'.format(status))
+            server.start(status)
+            await run_with_interrupt(server.running.wait(),
+                                     server.terminated)
+            if server.exception:
+                raise server.exception
+            assert not server.terminated.is_set()
+            logger.debug('test: started')
+            kwargs_copy = dict(kwargs,
+                               endpoint='localhost:{}'.format(server.port))
+            logger.debug('test: connecting')
+            conn = await Connection.connect(**kwargs_copy)
+            logger.debug('test: connected')
+            await conn.close()
+            logger.debug('test: stopping')
+            server.stop()
+            await server.stopped.wait()
+            logger.debug('test: stopped')
+    finally:
+        server.terminate()
+        await server.terminated.wait()
+
+
+class RedirectServer:
+    def __init__(self, destination, loop):
+        self.destination = destination
+        self.loop = loop
+        self._start = asyncio.Event()
+        self._stop = asyncio.Event()
+        self._terminate = asyncio.Event()
+        self.running = asyncio.Event()
+        self.stopped = asyncio.Event()
+        self.terminated = asyncio.Event()
+        if hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
+            # python 3.6+
+            protocol = ssl.PROTOCOL_TLS_SERVER
+        elif hasattr(ssl, 'PROTOCOL_TLS'):
+            # python 3.5.3+
+            protocol = ssl.PROTOCOL_TLS
+        else:
+            # python 3.5.2
+            protocol = ssl.PROTOCOL_TLSv1_2
+        self.ssl_context = ssl.SSLContext(protocol)
+        crt_file = Path(__file__).with_name('cert.pem')
+        key_file = Path(__file__).with_name('key.pem')
+        self.ssl_context.load_cert_chain(str(crt_file), str(key_file))
+        self.status = None
+        self.port = None
+        self._task = self.loop.create_task(self.run())
+
+    def start(self, status):
+        self.status = status
+        self.port = self._find_free_port()
+        self._start.set()
+
+    def stop(self):
+        self._stop.set()
+
+    def terminate(self):
+        self._terminate.set()
+        self.stop()
+
+    @property
+    def exception(self):
+        try:
+            return self._task.exception()
+        except (asyncio.CancelledError, asyncio.InvalidStateError):
+            return None
+
+    async def run(self):
+        logger.debug('server: active')
+
+        async def hello(websocket, path):
+            await websocket.send('hello')
+
+        async def redirect(path, request_headers):
+            return self.status, {'Location': self.destination}, b""
+
+        try:
+            while not self._terminate.is_set():
+                await run_with_interrupt(self._start.wait(),
+                                         self._terminate,
+                                         loop=self.loop)
+                if self._terminate.is_set():
+                    break
+                self._start.clear()
+                logger.debug('server: starting {}'.format(self.status))
+                try:
+                    async with websockets.serve(ws_handler=hello,
+                                                process_request=redirect,
+                                                host='localhost',
+                                                port=self.port,
+                                                ssl=self.ssl_context,
+                                                loop=self.loop):
+                        self.stopped.clear()
+                        self.running.set()
+                        logger.debug('server: started')
+                        while not self._stop.is_set():
+                            await run_with_interrupt(
+                                asyncio.sleep(1, loop=self.loop),
+                                self._stop,
+                                loop=self.loop)
+                            logger.debug('server: tick')
+                        logger.debug('server: stopping')
+                except asyncio.CancelledError:
+                    break
+                finally:
+                    self.stopped.set()
+                    self._stop.clear()
+                    self.running.clear()
+                    logger.debug('server: stopped')
+            logger.debug('server: terminating')
+        except asyncio.CancelledError:
+            pass
+        finally:
+            self._start.clear()
+            self._stop.clear()
+            self._terminate.clear()
+            self.stopped.set()
+            self.running.clear()
+            self.terminated.set()
+            logger.debug('server: terminated')
+
+    def _find_free_port(self):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            s.bind(('', 0))
+            return s.getsockname()[1]