Add test for out-of-order receive and fix bug when creating task
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from http.client import HTTPSConnection
13
14 import asyncio
15 import yaml
16
17 from juju import tag
18 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
19 from juju.utils import IdQueue
20
# Logger used throughout this module (connection lifecycle and errors).
log = logging.getLogger(name="websocket")
22
23
24 class Connection:
25 """
26 Usage::
27
28 # Connect to an arbitrary api server
29 client = await Connection.connect(
30 api_endpoint, model_uuid, username, password, cacert)
31
32 # Connect using a controller/model name
33 client = await Connection.connect_model('local.local:default')
34
35 # Connect to the currently active model
36 client = await Connection.connect_current()
37
38 Note: Any connection method or constructor can accept an optional `loop`
39 argument to override the default event loop from `asyncio.get_event_loop`.
40 """
41 def __init__(
42 self, endpoint, uuid, username, password, cacert=None,
43 macaroons=None, loop=None):
44 self.endpoint = endpoint
45 self.uuid = uuid
46 self.username = username
47 self.password = password
48 self.macaroons = macaroons
49 self.cacert = cacert
50 self.loop = loop or asyncio.get_event_loop()
51
52 self.__request_id__ = 0
53 self.addr = None
54 self.ws = None
55 self.facades = {}
56 self.messages = IdQueue(loop=self.loop)
57
58 @property
59 def is_open(self):
60 if self.ws:
61 return self.ws.open
62 return False
63
64 def _get_ssl(self, cert=None):
65 return ssl.create_default_context(
66 purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
67
68 async def open(self):
69 if self.uuid:
70 url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
71 else:
72 url = "wss://{}/api".format(self.endpoint)
73
74 kw = dict()
75 kw['ssl'] = self._get_ssl(self.cacert)
76 kw['loop'] = self.loop
77 self.addr = url
78 self.ws = await websockets.connect(url, **kw)
79 self.loop.create_task(self.receiver())
80 log.info("Driver connected to juju %s", url)
81 return self
82
83 async def close(self):
84 await self.ws.close()
85
86 async def recv(self, request_id):
87 return await self.messages.get(request_id)
88
89 async def receiver(self):
90 while self.is_open:
91 result = await self.ws.recv()
92 if result is not None:
93 result = json.loads(result)
94 await self.messages.put(result['request-id'], result)
95
96 async def rpc(self, msg, encoder=None):
97 self.__request_id__ += 1
98 msg['request-id'] = self.__request_id__
99 if'params' not in msg:
100 msg['params'] = {}
101 if "version" not in msg:
102 msg['version'] = self.facades[msg['type']]
103 outgoing = json.dumps(msg, indent=2, cls=encoder)
104 await self.ws.send(outgoing)
105 result = await self.recv(msg['request-id'])
106
107 if not result:
108 return result
109
110 if 'error' in result:
111 # API Error Response
112 raise JujuAPIError(result)
113
114 if not 'response' in result:
115 # This may never happen
116 return result
117
118 if 'results' in result['response']:
119 # Check for errors in a result list.
120 errors = []
121 for res in result['response']['results']:
122 if res.get('error', {}).get('message'):
123 errors.append(res['error']['message'])
124 if errors:
125 raise JujuError(errors)
126
127 elif result['response'].get('error', {}).get('message'):
128 raise JujuError(result['response']['error']['message'])
129
130 return result
131
132 def http_headers(self):
133 """Return dictionary of http headers necessary for making an http
134 connection to the endpoint of this Connection.
135
136 :return: Dictionary of headers
137
138 """
139 if not self.username:
140 return {}
141
142 creds = u'{}:{}'.format(
143 tag.user(self.username),
144 self.password or ''
145 )
146 token = base64.b64encode(creds.encode())
147 return {
148 'Authorization': 'Basic {}'.format(token.decode())
149 }
150
151 def https_connection(self):
152 """Return an https connection to this Connection's endpoint.
153
154 Returns a 3-tuple containing::
155
156 1. The :class:`HTTPSConnection` instance
157 2. Dictionary of auth headers to be used with the connection
158 3. The root url path (str) to be used for requests.
159
160 """
161 endpoint = self.endpoint
162 host, remainder = endpoint.split(':', 1)
163 port = remainder
164 if '/' in remainder:
165 port, _ = remainder.split('/', 1)
166
167 conn = HTTPSConnection(
168 host, int(port),
169 context=self._get_ssl(self.cacert),
170 )
171
172 path = (
173 "/model/{}".format(self.uuid)
174 if self.uuid else ""
175 )
176 return conn, self.http_headers(), path
177
178 async def clone(self):
179 """Return a new Connection, connected to the same websocket endpoint
180 as this one.
181
182 """
183 return await Connection.connect(
184 self.endpoint,
185 self.uuid,
186 self.username,
187 self.password,
188 self.cacert,
189 self.macaroons,
190 self.loop,
191 )
192
193 async def controller(self):
194 """Return a Connection to the controller at self.endpoint
195
196 """
197 return await Connection.connect(
198 self.endpoint,
199 None,
200 self.username,
201 self.password,
202 self.cacert,
203 self.macaroons,
204 self.loop,
205 )
206
207 @classmethod
208 async def connect(
209 cls, endpoint, uuid, username, password, cacert=None,
210 macaroons=None, loop=None):
211 """Connect to the websocket.
212
213 If uuid is None, the connection will be to the controller. Otherwise it
214 will be to the model.
215
216 """
217 client = cls(endpoint, uuid, username, password, cacert, macaroons,
218 loop)
219 await client.open()
220
221 redirect_info = await client.redirect_info()
222 if not redirect_info:
223 await client.login(username, password, macaroons)
224 return client
225
226 await client.close()
227 servers = [
228 s for servers in redirect_info['servers']
229 for s in servers if s["scope"] == 'public'
230 ]
231 for server in servers:
232 client = cls(
233 "{value}:{port}".format(**server), uuid, username,
234 password, redirect_info['ca-cert'], macaroons)
235 await client.open()
236 try:
237 result = await client.login(username, password, macaroons)
238 if 'discharge-required-error' in result:
239 continue
240 return client
241 except Exception as e:
242 await client.close()
243 log.exception(e)
244
245 raise Exception(
246 "Couldn't authenticate to %s", endpoint)
247
248 @classmethod
249 async def connect_current(cls, loop=None):
250 """Connect to the currently active model.
251
252 """
253 jujudata = JujuData()
254 controller_name = jujudata.current_controller()
255 model_name = jujudata.current_model()
256
257 return await cls.connect_model(
258 '{}:{}'.format(controller_name, model_name), loop)
259
260 @classmethod
261 async def connect_current_controller(cls, loop=None):
262 """Connect to the currently active controller.
263
264 """
265 jujudata = JujuData()
266 controller_name = jujudata.current_controller()
267 if not controller_name:
268 raise JujuConnectionError('No current controller')
269
270 return await cls.connect_controller(controller_name, loop)
271
272 @classmethod
273 async def connect_controller(cls, controller_name, loop=None):
274 """Connect to a controller by name.
275
276 """
277 jujudata = JujuData()
278 controller = jujudata.controllers()[controller_name]
279 endpoint = controller['api-endpoints'][0]
280 cacert = controller.get('ca-cert')
281 accounts = jujudata.accounts()[controller_name]
282 username = accounts['user']
283 password = accounts.get('password')
284 macaroons = get_macaroons() if not password else None
285
286 return await cls.connect(
287 endpoint, None, username, password, cacert, macaroons, loop)
288
289 @classmethod
290 async def connect_model(cls, model, loop=None):
291 """Connect to a model by name.
292
293 :param str model: [<controller>:]<model>
294
295 """
296 jujudata = JujuData()
297
298 if ':' in model:
299 # explicit controller given
300 controller_name, model_name = model.split(':')
301 else:
302 # use the current controller if one isn't explicitly given
303 controller_name = jujudata.current_controller()
304 model_name = model
305
306 accounts = jujudata.accounts()[controller_name]
307 username = accounts['user']
308 # model name must include a user prefix, so add it if it doesn't
309 if '/' not in model_name:
310 model_name = '{}/{}'.format(username, model_name)
311
312 controller = jujudata.controllers()[controller_name]
313 endpoint = controller['api-endpoints'][0]
314 cacert = controller.get('ca-cert')
315 password = accounts.get('password')
316 models = jujudata.models()[controller_name]
317 model_uuid = models['models'][model_name]['uuid']
318 macaroons = get_macaroons() if not password else None
319
320 return await cls.connect(
321 endpoint, model_uuid, username, password, cacert, macaroons, loop)
322
323 def build_facades(self, info):
324 self.facades.clear()
325 for facade in info:
326 self.facades[facade['name']] = facade['versions'][-1]
327
328 async def login(self, username, password, macaroons=None):
329 if macaroons:
330 username = ''
331 password = ''
332
333 if username and not username.startswith('user-'):
334 username = 'user-{}'.format(username)
335
336 result = await self.rpc({
337 "type": "Admin",
338 "request": "Login",
339 "version": 3,
340 "params": {
341 "auth-tag": username,
342 "credentials": password,
343 "nonce": "".join(random.sample(string.printable, 12)),
344 "macaroons": macaroons or []
345 }})
346 response = result['response']
347 self.build_facades(response.get('facades', {}))
348 self.info = response.copy()
349 return response
350
351 async def redirect_info(self):
352 try:
353 result = await self.rpc({
354 "type": "Admin",
355 "request": "RedirectInfo",
356 "version": 3,
357 })
358 except JujuAPIError as e:
359 if e.message == 'not redirected':
360 return None
361 raise
362 return result['response']
363
364
class JujuData:
    """Access to the local Juju client data files.

    Files are read from ``$JUJU_DATA`` when it is set and non-empty,
    otherwise from ``~/.local/share/juju``.
    """
    def __init__(self):
        self.path = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(self.path))

    def current_controller(self):
        """Return the current controller's name, or '' if none is set.

        Runs ``juju list-controllers --format yaml``. ``check_output``
        raises ``subprocess.CalledProcessError`` if it exits non-zero.
        """
        cmd = shlex.split('juju list-controllers --format yaml')
        output = subprocess.check_output(cmd)
        # An empty YAML document loads as None; treat it as "no controllers".
        output = yaml.safe_load(output) or {}
        return output.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the current model name for ``controller_name``.

        Defaults to the current controller.

        :raises JujuError: if the controller has no current model.
        """
        if not controller_name:
            controller_name = self.current_controller()
        models = self.models()[controller_name]
        if 'current-model' not in models:
            raise JujuError('No current model')
        return models['current-model']

    def controllers(self):
        """Return the 'controllers' mapping from controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the 'controllers' mapping from models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the 'controllers' mapping from accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Load ``filename`` from the data dir and return its top-level ``key``.

        Read as UTF-8 explicitly (YAML's default encoding) so the result
        does not depend on the process locale.
        """
        filepath = os.path.join(self.path, filename)
        with io.open(filepath, 'rt', encoding='utf-8') as f:
            return yaml.safe_load(f)[key]
397
398
def get_macaroons():
    """Decode and return macaroons from default ~/.go-cookies

    Only cookies whose 'Name' starts with ``macaroon-`` and whose 'Value'
    is non-empty are used. Each value is decoded as base64, then as UTF-8
    JSON. Cookie entries missing either key are skipped.

    :return: list of decoded macaroons, or ``[]`` if the cookie file cannot
        be read or is not valid JSON.
    """
    cookie_file = os.path.expanduser('~/.go-cookies')
    try:
        with open(cookie_file, 'r', encoding='utf-8') as f:
            cookies = json.load(f)
    except (OSError, ValueError):
        log.warning("Couldn't load macaroons from %s", cookie_file)
        return []

    base64_macaroons = [
        c['Value'] for c in cookies
        if (c.get('Name') or '').startswith('macaroon-') and c.get('Value')
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]