6647a03646d56f7c87485e8dfae09c35f427f6f7
[osm/N2VC.git] / tests / integration / test_connection.py
1 import asyncio
2 import http
3 import logging
4 import socket
5 import ssl
6 from contextlib import closing
7 from pathlib import Path
8
9 from juju.client import client
10 from juju.client.connection import Connection
11 from juju.controller import Controller
12 from juju.utils import run_with_interrupt
13
14 import pytest
15 import websockets
16
17 from .. import base
18
19
# Module-level logger; the redirect server and test_redirect emit their
# progress messages through it at DEBUG level.
logger = logging.getLogger(name=__name__)
21
22
@base.bootstrapped
@pytest.mark.asyncio
async def test_monitor(event_loop):
    """The connection monitor tracks the open/closed state of a connection.

    The model's connection must report 'connected' while open and
    'disconnected' once it has been closed explicitly.
    """
    async with base.CleanModel() as model:
        connection = model.connection()
        assert connection.monitor.status == 'connected'

        await connection.close()
        assert connection.monitor.status == 'disconnected'
32
33
@base.bootstrapped
@pytest.mark.asyncio
async def test_monitor_catches_error(event_loop):
    """A websocket dropped under the monitor is reported as 'error'.

    The raw websocket is closed while ``monitor.reconnecting`` is held
    (presumably to keep automatic reconnection from masking the failure),
    and the monitor must then report the 'error' state.  The connection is
    always closed afterwards.
    """
    async with base.CleanModel() as model:
        connection = model.connection()
        assert connection.monitor.status == 'connected'

        try:
            async with connection.monitor.reconnecting:
                await connection.ws.close()
                # Give the monitor a moment to observe the dropped socket.
                await asyncio.sleep(1)
                assert connection.monitor.status == 'error'
        finally:
            await connection.close()
49
50
@base.bootstrapped
@pytest.mark.asyncio
async def test_full_status(event_loop):
    """FullStatus can be requested on a model with a deployed application.

    A single 'ubuntu' application is deployed, and the ClientFacade bound to
    the model's connection must answer FullStatus without raising.
    """
    async with base.CleanModel() as model:
        charm = 'ubuntu-0'
        await model.deploy(charm,
                           application_name='ubuntu',
                           series='trusty',
                           channel='stable')

        facade = client.ClientFacade.from_connection(model.connection())
        await facade.FullStatus(None)
65
66
@base.bootstrapped
@pytest.mark.asyncio
async def test_reconnect(event_loop):
    """A Connection whose websocket is closed underneath it comes back.

    A second Connection is opened with the model connection's parameters.
    Its websocket is closed directly, and the connection must report itself
    open again within 3 seconds.
    """
    async with base.CleanModel() as model:
        params = model.connection().connect_params()
        connection = await Connection.connect(**params)
        try:
            await asyncio.sleep(0.1)
            assert connection.is_open

            await connection.ws.close()
            assert not connection.is_open

            await model.block_until(lambda: connection.is_open, timeout=3)
        finally:
            await connection.close()
81
82
@base.bootstrapped
@pytest.mark.asyncio
async def test_redirect(event_loop):
    """Connection.connect follows each supported HTTP redirect status.

    For every status in ``redirect_statuses`` a local RedirectServer is
    started whose only answer is that status with a ``Location`` header
    pointing at the real controller API endpoint.  Connecting through the
    local server must succeed, which requires the client to follow the
    redirect.
    """
    controller = Controller()
    await controller.connect()
    try:
        kwargs = controller.connection().connect_params()
    finally:
        # Only the connection parameters are needed; release the controller
        # connection even if fetching them fails.
        await controller.disconnect()

    destination = 'wss://{}/api'.format(kwargs['endpoint'])
    redirect_statuses = [
        http.HTTPStatus.MOVED_PERMANENTLY,
        http.HTTPStatus.FOUND,
        http.HTTPStatus.SEE_OTHER,
        http.HTTPStatus.TEMPORARY_REDIRECT,
        http.HTTPStatus.PERMANENT_REDIRECT,
    ]
    # The client must also trust the certificate presented by the local
    # redirect server (the same cert.pem that RedirectServer loads).
    test_server_cert = Path(__file__).with_name('cert.pem')
    kwargs['cacert'] += '\n' + test_server_cert.read_text()
    server = RedirectServer(destination, event_loop)
    try:
        for status in redirect_statuses:
            logger.debug('test: starting {}'.format(status))
            server.start(status)
            # Wait for the listener to come up, but stop waiting if the
            # server task finishes first (e.g. because it raised).
            await run_with_interrupt(server.running.wait(),
                                     server.terminated)
            if server.exception:
                raise server.exception
            assert not server.terminated.is_set()
            logger.debug('test: started')
            kwargs_copy = dict(kwargs,
                               endpoint='localhost:{}'.format(server.port))
            logger.debug('test: connecting')
            conn = await Connection.connect(**kwargs_copy)
            logger.debug('test: connected')
            await conn.close()
            logger.debug('test: stopping')
            server.stop()
            await server.stopped.wait()
            logger.debug('test: stopped')
    finally:
        server.terminate()
        await server.terminated.wait()
130
131
class RedirectServer:
    """TLS websocket server on localhost whose only reply is a redirect.

    On construction a background task running :meth:`run` is created on
    ``loop``.  Each :meth:`start` call brings up a listener on
    ``localhost:<port>`` whose ``process_request`` hook answers with the
    requested HTTP ``status`` and a ``Location`` header set to
    ``destination``.  :meth:`stop` tears that listener down, and
    :meth:`terminate` ends the background task altogether.

    The public events let callers wait for transitions:
    ``running`` is set while a listener is up, ``stopped`` once it has been
    torn down, and ``terminated`` once the background task has exited.
    """

    def __init__(self, destination, loop):
        self.destination = destination
        self.loop = loop
        # Requests sent from the controlling code to the background task.
        self._start = asyncio.Event()
        self._stop = asyncio.Event()
        self._terminate = asyncio.Event()
        # State changes reported by the background task.
        self.running = asyncio.Event()
        self.stopped = asyncio.Event()
        self.terminated = asyncio.Event()
        self.ssl_context = ssl.SSLContext(self._pick_tls_protocol())
        here = Path(__file__)
        self.ssl_context.load_cert_chain(str(here.with_name('cert.pem')),
                                         str(here.with_name('key.pem')))
        self.status = None
        self.port = None
        self._task = self.loop.create_task(self.run())

    @staticmethod
    def _pick_tls_protocol():
        """Return the most specific server-side TLS protocol constant."""
        for candidate in ('PROTOCOL_TLS_SERVER',  # python 3.6+
                          'PROTOCOL_TLS'):        # python 3.5.3+
            if hasattr(ssl, candidate):
                return getattr(ssl, candidate)
        return ssl.PROTOCOL_TLSv1_2  # python 3.5.2

    def start(self, status):
        """Ask the background task to serve redirects with HTTP ``status``.

        A free port is chosen right away and exposed as ``self.port``.
        """
        self.status = status
        self.port = self._find_free_port()
        self._start.set()

    def stop(self):
        """Ask the background task to tear down the current listener."""
        self._stop.set()

    def terminate(self):
        """Ask the background task to exit, stopping any active listener."""
        self._terminate.set()
        self.stop()

    @property
    def exception(self):
        """Exception raised by the background task, if it raised one.

        ``None`` while the task is still running, if it was cancelled, or
        if it finished normally.
        """
        try:
            failure = self._task.exception()
        except (asyncio.CancelledError, asyncio.InvalidStateError):
            failure = None
        return failure

    async def run(self):
        """Background loop: one listener per start()/stop() cycle.

        Returns once terminate() has been called or the task is cancelled.
        On the way out all control events are cleared, ``running`` is
        cleared, and ``stopped`` and ``terminated`` are set.
        """
        logger.debug('server: active')

        async def greet(websocket, path):
            await websocket.send('hello')

        async def send_redirect(path, request_headers):
            # Returning a response tuple answers the HTTP request directly,
            # so clients receive the redirect instead of a websocket.
            return self.status, {'Location': self.destination}, b""

        try:
            while not self._terminate.is_set():
                # Idle until either start() or terminate() is called.
                await run_with_interrupt(self._start.wait(),
                                         self._terminate,
                                         loop=self.loop)
                if self._terminate.is_set():
                    break
                self._start.clear()
                logger.debug('server: starting {}'.format(self.status))
                try:
                    async with websockets.serve(host='localhost',
                                                port=self.port,
                                                ssl=self.ssl_context,
                                                ws_handler=greet,
                                                process_request=send_redirect,
                                                loop=self.loop):
                        self.stopped.clear()
                        self.running.set()
                        logger.debug('server: started')
                        while not self._stop.is_set():
                            # Wake at least once a second, or as soon as
                            # stop() is called.
                            nap = asyncio.sleep(1, loop=self.loop)
                            await run_with_interrupt(nap,
                                                     self._stop,
                                                     loop=self.loop)
                            logger.debug('server: tick')
                        logger.debug('server: stopping')
                except asyncio.CancelledError:
                    break
                finally:
                    self.stopped.set()
                    self._stop.clear()
                    self.running.clear()
                    logger.debug('server: stopped')
            logger.debug('server: terminating')
        except asyncio.CancelledError:
            pass
        finally:
            self._start.clear()
            self._stop.clear()
            self._terminate.clear()
            self.stopped.set()
            self.running.clear()
            self.terminated.set()
            logger.debug('server: terminated')

    def _find_free_port(self):
        """Ask the OS for a currently unused TCP port and return it."""
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with closing(probe):
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(('', 0))
            _host, port = probe.getsockname()
        return port