Fix issue where we do not check to make sure that we are receiving the correct response.
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from http.client import HTTPSConnection
13
14 import asyncio
15 import yaml
16
17 from juju import tag
18 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
19
20 log = logging.getLogger("websocket")
21
22
class Connection:
    """Websocket connection to a Juju API server (controller or model).

    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    Note: Any connection method or constructor can accept an optional `loop`
    argument to override the default event loop from `asyncio.get_event_loop`.
    """
    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        self.endpoint = endpoint
        self.uuid = uuid
        self.username = username
        self.password = password
        self.macaroons = macaroons
        self.cacert = cacert
        self.loop = loop or asyncio.get_event_loop()

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        self.facades = {}
        # Responses received by `receiver`, keyed by their 'request-id'.
        self.messages = {}
        # Strong reference to the background receiver task so it cannot be
        # garbage collected while it is still running.
        self._receiver_task = None

    @property
    def is_open(self):
        """True when a websocket exists and reports itself as open."""
        if self.ws:
            return self.ws.open
        return False

    def _get_ssl(self, cert=None):
        """Build the SSL context used for websocket and https connections.

        :param cert: Optional CA certificate data passed as `cadata`.

        """
        # NOTE(review): the ssl docs describe Purpose.CLIENT_AUTH as the
        # option for server-side sockets; a client would normally use
        # Purpose.SERVER_AUTH. Left unchanged because switching alters
        # certificate verification behaviour -- confirm before changing.
        return ssl.create_default_context(
            purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)

    async def open(self):
        """Open the websocket to the model (if `uuid` is set) or controller.

        This does not start the receiver task; `connect` does that.

        :return: This Connection.

        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        kw = dict()
        kw['ssl'] = self._get_ssl(self.cacert)
        kw['loop'] = self.loop
        self.addr = url
        self.ws = await websockets.connect(url, **kw)
        log.info("Driver connected to juju %s", url)
        return self

    async def close(self):
        """Close the websocket, if one was ever opened."""
        if self.ws is not None:
            await self.ws.close()

    def _start_receiver(self):
        """Schedule `receiver` on this connection's loop and keep the task."""
        self._receiver_task = self.loop.create_task(self.receiver())

    async def recv(self, request_id):
        """Wait for and return the response matching `request_id`.

        The response is removed from `self.messages` once returned.

        :raises JujuConnectionError: if the websocket is no longer open and
            no matching response has been received.

        """
        while not self.messages.get(request_id):
            if not self.is_open:
                # Without this check a dropped connection would leave the
                # caller polling forever for a reply that cannot arrive.
                raise JujuConnectionError(
                    'Connection closed before a response to request {} '
                    'was received'.format(request_id))
            # Yield to the event loop so the receiver task can run.
            await asyncio.sleep(0)

        return self.messages.pop(request_id)

    async def receiver(self):
        """Read messages off the websocket and file them by request id.

        Runs until the websocket is no longer open. Each decoded message is
        stored in `self.messages` under its 'request-id' so that `recv` can
        hand it to the caller that issued the matching request.

        """
        # NOTE(review): depending on the websockets version, `ws.recv()` may
        # raise on close rather than return None -- confirm and handle.
        while self.is_open:
            result = await self.ws.recv()
            if result is not None:
                result = json.loads(result)
                self.messages[result['request-id']] = result

    async def rpc(self, msg, encoder=None):
        """Send an API request and return the matching response.

        `msg` is modified in place: it is given a fresh 'request-id', an
        empty 'params' dict if it has none, and, if 'version' is missing, the
        facade version recorded for `msg['type']` at login.

        :param dict msg: The request to send.
        :param encoder: Optional `json.JSONEncoder` subclass for `msg`.
        :return: The decoded response message.
        :raises JujuAPIError: if the response has a top-level 'error'.
        :raises JujuError: if the response body, or any entry of its
            'results' list, carries an error message.

        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        if 'params' not in msg:
            msg['params'] = {}
        if "version" not in msg:
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        await self.ws.send(outgoing)
        result = await self.recv(msg['request-id'])

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        if 'results' in result['response']:
            # Check for errors in a result list. `or {}` tolerates entries
            # whose 'error' key is present but null.
            errors = []
            for res in result['response']['results']:
                if (res.get('error') or {}).get('message'):
                    errors.append(res['error']['message'])
            if errors:
                raise JujuError(errors)

        elif (result['response'].get('error') or {}).get('message'):
            raise JujuError(result['response']['error']['message'])

        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        :return: Dictionary of headers

        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        """
        endpoint = self.endpoint
        host, remainder = endpoint.split(':', 1)
        port = remainder
        if '/' in remainder:
            # Drop any path that follows the port.
            port, _ = remainder.split('/', 1)

        conn = HTTPSConnection(
            host, int(port),
            context=self._get_ssl(self.cacert),
        )

        path = (
            "/model/{}".format(self.uuid)
            if self.uuid else ""
        )
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.

        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint

        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        """Connect to the websocket.

        If uuid is None, the connection will be to the controller. Otherwise it
        will be to the model.

        If the server answers with redirect information, the first connection
        is closed and each public server from the redirect is tried in turn
        until one accepts the login.

        :return: A logged-in Connection with its receiver task running.
        :raises Exception: if a redirect was given and no redirect server
            accepted the login.

        """
        client = cls(endpoint, uuid, username, password, cacert, macaroons,
                     loop)
        await client.open()
        # The receiver must be running before any rpc call, otherwise no
        # response is ever delivered to `recv`.
        client._start_receiver()

        redirect_info = await client.redirect_info()
        if not redirect_info:
            await client.login(username, password, macaroons)
            return client

        await client.close()
        servers = [
            s for servers in redirect_info['servers']
            for s in servers if s["scope"] == 'public'
        ]
        for server in servers:
            client = cls(
                "{value}:{port}".format(**server), uuid, username,
                password, redirect_info['ca-cert'], macaroons, loop)
            await client.open()
            try:
                client._start_receiver()
                result = await client.login(username, password, macaroons)
                if 'discharge-required-error' not in result:
                    return client
                # Rejected: release this socket before trying the next one.
                await client.close()
            except Exception as e:
                await client.close()
                log.exception(e)

        raise Exception(
            "Couldn't authenticate to %s" % endpoint)

    @classmethod
    async def connect_current(cls, loop=None):
        """Connect to the currently active model.

        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        model_name = jujudata.current_model()

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name), loop)

    @classmethod
    async def connect_current_controller(cls, loop=None):
        """Connect to the currently active controller.

        :raises JujuConnectionError: if there is no current controller.

        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name, loop)

    @classmethod
    async def connect_controller(cls, controller_name, loop=None):
        """Connect to a controller by name.

        Endpoint, CA cert and account are read from the local Juju data
        files; macaroons are loaded only when no password is stored.

        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons, loop)

    @classmethod
    async def connect_model(cls, model, loop=None):
        """Connect to a model by name.

        :param str model: [<controller>:]<model>

        """
        jujudata = JujuData()

        if ':' in model:
            # explicit controller given
            controller_name, model_name = model.split(':')
        else:
            # use the current controller if one isn't explicitly given
            controller_name = jujudata.current_controller()
            model_name = model

        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        # model name must include a user prefix, so add it if it doesn't
        if '/' not in model_name:
            model_name = '{}/{}'.format(username, model_name)

        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons, loop)

    def build_facades(self, info):
        """Record, for each facade in `info`, the last listed version."""
        self.facades.clear()
        for facade in info:
            self.facades[facade['name']] = facade['versions'][-1]

    async def login(self, username, password, macaroons=None):
        """Log in via the Admin facade and record the advertised facades.

        When macaroons are supplied the username and password are sent empty.
        A 'user-' prefix is added to the username if it lacks one.

        :return: The 'response' part of the login reply (also copied to
            `self.info`).

        """
        if macaroons:
            username = ''
            password = ''

        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": password,
                "nonce": "".join(random.sample(string.printable, 12)),
                "macaroons": macaroons or []
            }})
        response = result['response']
        self.build_facades(response.get('facades', {}))
        self.info = response.copy()
        return response

    async def redirect_info(self):
        """Ask the server for redirect information.

        :return: The 'response' part of the reply, or None when the server
            reports 'not redirected'.
        :raises JujuAPIError: for any other API error.

        """
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
368
369
class JujuData:
    """Accessor for the Juju client's local data directory.

    The directory is taken from the ``JUJU_DATA`` environment variable, or
    ``~/.local/share/juju`` when that is unset or empty, and is stored as an
    absolute path in ``self.path``.
    """

    def __init__(self):
        configured = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(configured))

    def current_controller(self):
        """Return the current controller name reported by the juju CLI.

        Runs ``juju list-controllers --format yaml`` and returns its
        'current-controller' entry, or '' when that entry is absent.
        """
        argv = shlex.split('juju list-controllers --format yaml')
        listing = yaml.safe_load(subprocess.check_output(argv))
        return listing.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the current model recorded for a controller.

        :param controller_name: Controller to look up; the current
            controller is used when this is not given.
        :raises JujuError: if the controller has no 'current-model' entry.
        """
        name = controller_name or self.current_controller()
        controller_models = self.models()[name]
        if 'current-model' in controller_models:
            return controller_models['current-model']
        raise JujuError('No current model')

    def controllers(self):
        """Return the 'controllers' mapping from controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the 'controllers' mapping from models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the 'controllers' mapping from accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Parse `filename` inside the data directory and return `key`."""
        with io.open(os.path.join(self.path, filename), 'rt') as stream:
            document = yaml.safe_load(stream)
        return document[key]
402
403
def get_macaroons():
    """Decode and return macaroons from default ~/.go-cookies

    The cookie file is read as JSON. Every entry whose 'Name' starts with
    'macaroon-' and whose 'Value' is non-empty is base64-decoded and parsed
    as JSON. Entries that are not mappings, or that lack either key, are
    skipped rather than aborting the whole load.

    :return: List of decoded macaroons; an empty list when the cookie file
        cannot be read or parsed.

    """
    cookie_file = os.path.expanduser('~/.go-cookies')
    try:
        with open(cookie_file, 'r') as f:
            cookies = json.load(f)
    except (OSError, ValueError):
        log.warning("Couldn't load macaroons from %s", cookie_file)
        return []

    base64_macaroons = [
        c['Value'] for c in cookies or []
        if isinstance(c, dict)
        and (c.get('Name') or '').startswith('macaroon-')
        and c.get('Value')
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]