Improve error handling
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from http.client import HTTPSConnection
13
14 import asyncio
15 import yaml
16
17 from juju import tag
18 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
19 from juju.utils import IdQueue
20
21 log = logging.getLogger("websocket")
22
23
class Connection:
    """
    Websocket connection to a Juju API server (controller or model).

    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    Note: Any connection method or constructor can accept an optional `loop`
    argument to override the default event loop from `asyncio.get_event_loop`.
    """
    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        """Store the connection parameters; no I/O happens here.

        :param str endpoint: ``host:port`` of the API server
        :param str uuid: Model UUID, or None to address the controller
        :param str username: User name used for login
        :param str password: Password used for login
        :param str cacert: CA certificate data passed to the SSL context
        :param list macaroons: Macaroons used for login instead of a password
        :param loop: Event loop; defaults to `asyncio.get_event_loop()`

        """
        self.endpoint = endpoint
        self.uuid = uuid
        self.username = username
        self.password = password
        self.macaroons = macaroons
        self.cacert = cacert
        self.loop = loop or asyncio.get_event_loop()

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        self.facades = {}
        # Populated by login(); initialised here so the attribute always
        # exists, even before a successful login.
        self.info = None
        self.messages = IdQueue(loop=self.loop)

    @property
    def is_open(self):
        """True when a websocket exists and reports itself as open."""
        if self.ws:
            return self.ws.open
        return False

    def _get_ssl(self, cert=None):
        """Build the SSL context used for websocket and https connections.

        :param str cert: CA certificate data, passed as ``cadata``

        """
        # NOTE(review): ssl.Purpose.CLIENT_AUTH is documented for server-side
        # sockets; a client would normally use SERVER_AUTH. Left unchanged
        # because switching also turns on hostname checking -- confirm
        # against the controller certificates before changing.
        return ssl.create_default_context(
            purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)

    async def open(self):
        """Open the websocket and start the background receiver task.

        :return: This Connection

        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        kw = dict()
        kw['ssl'] = self._get_ssl(self.cacert)
        kw['loop'] = self.loop
        self.addr = url
        self.ws = await websockets.connect(url, **kw)
        self.loop.create_task(self.receiver())
        log.info("Driver connected to juju %s", url)
        return self

    async def close(self):
        """Close the websocket; a no-op if it was never opened."""
        if self.ws is not None:
            await self.ws.close()

    async def recv(self, request_id):
        """Wait for the message queued under `request_id`.

        :raises websockets.exceptions.ConnectionClosed: if the websocket is
            not open

        """
        if not self.is_open:
            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
        return await self.messages.get(request_id)

    async def receiver(self):
        """Read messages off the websocket and queue them by request id.

        Runs until the websocket is no longer open. Any exception is pushed
        to every queue via `put_all` before being re-raised.

        """
        while self.is_open:
            try:
                result = await self.ws.recv()
                if result is not None:
                    result = json.loads(result)
                    await self.messages.put(result['request-id'], result)
            except Exception as e:
                await self.messages.put_all(e)
                raise
        await self.messages.put_all(websockets.exceptions.ConnectionClosed(
            0, 'websocket closed'))

    async def rpc(self, msg, encoder=None):
        """Send an RPC request and return the matching response.

        `msg` is modified in place: a fresh ``request-id`` is assigned, and
        ``params`` / ``version`` are filled in when missing (the version
        comes from the facades recorded at login).

        :param dict msg: Request with at least ``type`` and ``request``
        :param encoder: Optional `json.JSONEncoder` subclass for `msg`
        :return: The decoded response
        :raises JujuAPIError: if the response has a top-level ``error``
        :raises JujuError: if the response body (or any entry of its
            ``results`` list) carries an error message

        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        if 'params' not in msg:
            msg['params'] = {}
        if "version" not in msg:
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        await self.ws.send(outgoing)
        result = await self.recv(msg['request-id'])

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        response = result['response']
        if 'results' in response:
            # Check for errors in a result list. Both the list itself and an
            # entry's "error" may be null, so fall back to empty values
            # instead of failing on None.
            errors = [
                res['error']['message']
                for res in response['results'] or []
                if (res.get('error') or {}).get('message')
            ]
            if errors:
                raise JujuError(errors)

        elif (response.get('error') or {}).get('message'):
            raise JujuError(response['error']['message'])

        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        :return: Dictionary of headers

        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        """
        endpoint = self.endpoint
        host, remainder = endpoint.split(':', 1)
        port = remainder
        if '/' in remainder:
            # Drop any path component that follows the port.
            port, _ = remainder.split('/', 1)

        conn = HTTPSConnection(
            host, int(port),
            context=self._get_ssl(self.cacert),
        )

        path = (
            "/model/{}".format(self.uuid)
            if self.uuid else ""
        )
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.

        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint

        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        """Connect to the websocket.

        If uuid is None, the connection will be to the controller. Otherwise it
        will be to the model.

        If the server answers with redirect information, each public server
        it lists is tried in turn until a login succeeds.

        :return: A logged-in Connection
        :raises Exception: if redirected and no listed server accepts the
            login

        """
        client = cls(endpoint, uuid, username, password, cacert, macaroons,
                     loop)
        await client.open()

        try:
            redirect_info = await client.redirect_info()
            if not redirect_info:
                await client.login(username, password, macaroons)
                return client
        except Exception:
            # The caller never receives this client, so close the websocket
            # here rather than leaking it when the handshake fails.
            await client.close()
            raise

        await client.close()
        servers = [
            s for group in redirect_info['servers']
            for s in group if s["scope"] == 'public'
        ]
        for server in servers:
            # Pass `loop` through so redirected clients honour the caller's
            # event loop as well.
            client = cls(
                "{value}:{port}".format(**server), uuid, username,
                password, redirect_info['ca-cert'], macaroons, loop)
            await client.open()
            try:
                result = await client.login(username, password, macaroons)
            except Exception as e:
                await client.close()
                log.exception(e)
                continue
            if 'discharge-required-error' in result:
                # Close before moving on so the socket isn't leaked.
                await client.close()
                continue
            return client

        raise Exception(
            "Couldn't authenticate to %s" % endpoint)

    @classmethod
    async def connect_current(cls, loop=None):
        """Connect to the currently active model.

        :raises JujuConnectionError: if there is no current controller

        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')
        # Hand the controller over so current_model() doesn't have to look
        # it up a second time.
        model_name = jujudata.current_model(controller_name)

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name), loop)

    @classmethod
    async def connect_current_controller(cls, loop=None):
        """Connect to the currently active controller.

        :raises JujuConnectionError: if there is no current controller

        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name, loop)

    @classmethod
    async def connect_controller(cls, controller_name, loop=None):
        """Connect to a controller by name.

        Endpoint, CA cert and credentials are read from the local Juju data
        files; macaroons are loaded only when no password is stored.

        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons, loop)

    @classmethod
    async def connect_model(cls, model, loop=None):
        """Connect to a model by name.

        :param str model: [<controller>:]<model>

        """
        jujudata = JujuData()

        if ':' in model:
            # explicit controller given
            controller_name, model_name = model.split(':')
        else:
            # use the current controller if one isn't explicitly given
            controller_name = jujudata.current_controller()
            model_name = model

        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        # model name must include a user prefix, so add it if it doesn't
        if '/' not in model_name:
            model_name = '{}/{}'.format(username, model_name)

        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons, loop)

    def build_facades(self, info):
        """Rebuild `self.facades` from a list of facade descriptions.

        Each facade name is mapped to the last entry of its ``versions``
        list.

        :param info: Iterable of dicts with ``name`` and ``versions`` keys

        """
        self.facades.clear()
        for facade in info:
            self.facades[facade['name']] = facade['versions'][-1]

    async def login(self, username, password, macaroons=None):
        """Send the Admin Login request and record the server's facades.

        When macaroons are supplied they replace the username and password,
        which are sent empty.

        :return: The ``response`` part of the login reply; a copy is kept on
            `self.info`

        """
        if macaroons:
            username = ''
            password = ''

        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": password,
                "nonce": "".join(random.sample(string.printable, 12)),
                "macaroons": macaroons or []
            }})
        response = result['response']
        self.build_facades(response.get('facades', {}))
        self.info = response.copy()
        return response

    async def redirect_info(self):
        """Ask the server for redirect information.

        :return: The ``response`` part of the reply, or None when the server
            answers with a 'not redirected' API error

        """
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
371
372
class JujuData:
    """Accessor for the local Juju client's data.

    Reads ``controllers.yaml``, ``models.yaml`` and ``accounts.yaml`` from
    the directory named by the ``JUJU_DATA`` environment variable, falling
    back to ``~/.local/share/juju``, and asks the ``juju`` command line
    client for the current controller.
    """
    def __init__(self):
        # An unset or empty JUJU_DATA falls back to the default location.
        self.path = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(self.path))

    def current_controller(self):
        """Return the name of the current controller, or '' if there is none.

        Runs ``juju list-controllers --format yaml`` and reads its
        ``current-controller`` key.

        """
        cmd = shlex.split('juju list-controllers --format yaml')
        output = subprocess.check_output(cmd)
        # safe_load yields None for an empty document; treat that (and a
        # null value) the same as "no current controller".
        data = yaml.safe_load(output) or {}
        return data.get('current-controller') or ''

    def current_model(self, controller_name=None):
        """Return the current model name recorded for a controller.

        :param str controller_name: Controller to look up; defaults to the
            current controller
        :raises JujuError: if the controller has no current model

        """
        if not controller_name:
            controller_name = self.current_controller()
        models = self.models()[controller_name]
        if 'current-model' not in models:
            raise JujuError('No current model')
        return models['current-model']

    def controllers(self):
        """Return the ``controllers`` mapping from controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the ``controllers`` mapping from models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the ``controllers`` mapping from accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Load `filename` from the data directory and return its `key`.

        :param str filename: File name relative to `self.path`
        :param str key: Top-level key to return from the parsed document

        """
        filepath = os.path.join(self.path, filename)
        with io.open(filepath, 'rt') as f:
            return yaml.safe_load(f)[key]
405
406
def get_macaroons():
    """Decode and return macaroons from default ~/.go-cookies

    Cookies whose name starts with ``macaroon-`` and whose value is
    non-empty are base64-decoded and parsed as JSON.

    :return: List of decoded macaroons; empty if the cookie file cannot be
        read or parsed, or holds no matching cookies

    """
    # Resolve the path outside the try so only the file access is guarded.
    cookie_file = os.path.expanduser('~/.go-cookies')
    try:
        with open(cookie_file, 'r') as f:
            cookies = json.load(f)
    except (OSError, ValueError):
        # Logger.warn is a deprecated alias of Logger.warning.
        log.warning("Couldn't load macaroons from %s", cookie_file)
        return []

    # A cookie file containing JSON null parses to None; treat it as empty.
    base64_macaroons = [
        c['Value'] for c in cookies or []
        if c['Name'].startswith('macaroon-') and c['Value']
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]