Raise errors in a centralized fashion.
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from http.client import HTTPSConnection
13
14 import yaml
15
16 from juju import tag
17 from juju.errors import JujuAPIError, JujuConnectionError, JujuError
18
# Logger shared by everything in this module; deliberately named "websocket".
log = logging.getLogger(name="websocket")
20
21
class Connection:
    """A websocket RPC connection to a Juju controller or model API server.

    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    """
    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None):
        """Store connection parameters; no network I/O happens here.

        :param str endpoint: ``host:port`` of the API server.
        :param uuid: Model UUID, or None to address the controller itself.
        :param username: User name, with or without the ``user-`` prefix.
        :param password: Password, or None when using macaroons.
        :param cacert: CA certificate data used to verify the server
            (passed to ``ssl.create_default_context`` as ``cadata``).
        :param macaroons: Macaroons to authenticate with, if any.
        """
        self.endpoint = endpoint
        self.uuid = uuid
        self.username = username
        self.password = password
        self.macaroons = macaroons
        self.cacert = cacert

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        self.facades = {}

    @property
    def is_open(self):
        """True while the websocket exists and reports itself open."""
        if self.ws:
            return self.ws.open
        return False

    def _get_ssl(self, cert=None):
        """Return an SSL context for connecting *to* the API server.

        ``Purpose.SERVER_AUTH`` is the client-side purpose: the server's
        certificate chain is verified (``CERT_REQUIRED``). The former
        ``Purpose.CLIENT_AUTH`` produced a server-side context that skipped
        certificate verification and cannot open client sockets on newer
        Python versions.

        :param cert: CA certificate data for ``cadata``. When None, the
            system trust store is used and hostname checking stays enabled.
        """
        context = ssl.create_default_context(
            purpose=ssl.Purpose.SERVER_AUTH, cadata=cert)
        if cert:
            # The chain is still verified against the supplied CA, but we
            # dial the controller by address, which presumably need not be
            # named in its certificate -- so skip hostname matching only.
            context.check_hostname = False
        return context

    async def open(self):
        """Open the websocket and return ``self``.

        The model API path is used when ``uuid`` is set, otherwise the
        controller API path.
        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        self.addr = url
        self.ws = await websockets.connect(
            url, ssl=self._get_ssl(self.cacert))
        log.info("Driver connected to juju %s", url)
        return self

    async def close(self):
        """Close the websocket; a no-op if :meth:`open` never succeeded."""
        if self.ws is not None:
            await self.ws.close()

    async def recv(self):
        """Receive one message and JSON-decode it (None passes through)."""
        result = await self.ws.recv()
        if result is not None:
            result = json.loads(result)
        return result

    async def rpc(self, msg, encoder=None):
        """Send one request and return the decoded reply, raising on errors.

        *msg* is mutated in place. ``request-id`` is assigned, and
        ``params`` and ``version`` are filled in when absent. The version
        comes from the facade versions recorded at login.

        :param dict msg: Request with at least ``type`` and ``request``.
        :param encoder: Optional :class:`json.JSONEncoder` subclass.
        :return: The decoded reply, or the falsy value received.
        :raises JujuAPIError: If the reply has a top-level ``error`` key.
        :raises JujuError: If the response, or any entry of its ``results``
            list, carries an error message.
        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        if 'params' not in msg:
            msg['params'] = {}
        if 'version' not in msg:
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        await self.ws.send(outgoing)
        result = await self.recv()

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        # ``response``, ``results`` and ``error`` may each be JSON null, so
        # fall back to empty containers instead of crashing on None.
        response = result['response'] or {}
        if 'results' in response:
            # Check for errors in a result list.
            errors = [
                res['error']['message']
                for res in response['results'] or []
                if (res.get('error') or {}).get('message')
            ]
            if errors:
                raise JujuError(errors)

        elif (response.get('error') or {}).get('message'):
            raise JujuError(response['error']['message'])

        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        :return: Dictionary of headers (empty when there is no username)

        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        """
        # Drop any path suffix, then split host from port on the *last*
        # colon so bracketed IPv6 endpoints ("[::1]:17070") also work.
        hostport = self.endpoint.split('/', 1)[0]
        host, _, port = hostport.rpartition(':')
        host = host.strip('[]')

        conn = HTTPSConnection(
            host, int(port),
            context=self._get_ssl(self.cacert),
        )

        path = "/model/{}".format(self.uuid) if self.uuid else ""
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.

        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint

        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
        )

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None):
        """Connect to the websocket, log in, and return the connection.

        If uuid is None, the connection will be to the controller. Otherwise
        it will be to the model.

        If the server returns redirect info, the first connection is closed.
        Each public-scope server it lists is then tried in turn, using the
        redirect's CA certificate. The first one whose login reply has no
        ``discharge-required-error`` is returned. Every connection that is
        not returned is closed.

        :raises JujuConnectionError: If no redirect target accepts the login.
        """
        client = cls(endpoint, uuid, username, password, cacert, macaroons)
        await client.open()

        try:
            redirect_info = await client.redirect_info()
            if not redirect_info:
                await client.login(username, password, macaroons)
                return client
        except Exception:
            # Don't leak the socket when redirect lookup or login fails.
            await client.close()
            raise

        await client.close()
        servers = [
            server
            for server_group in redirect_info['servers']
            for server in server_group
            if server["scope"] == 'public'
        ]
        for server in servers:
            client = cls(
                "{value}:{port}".format(**server), uuid, username,
                password, redirect_info['ca-cert'], macaroons)
            try:
                # open() is inside the try so an unreachable server falls
                # through to the next candidate instead of aborting.
                await client.open()
                result = await client.login(username, password, macaroons)
            except Exception as e:
                await client.close()
                log.exception(e)
                continue
            if 'discharge-required-error' in result:
                await client.close()
                continue
            return client

        raise JujuConnectionError(
            "Couldn't authenticate to {}".format(endpoint))

    @classmethod
    async def connect_current(cls):
        """Connect to the currently active model.

        :raises JujuConnectionError: If no current controller is set.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')
        models = jujudata.models()[controller_name]
        model_name = models['current-model']

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name))

    @classmethod
    async def connect_current_controller(cls):
        """Connect to the currently active controller.

        :raises JujuConnectionError: If no current controller is set.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name)

    @classmethod
    async def connect_controller(cls, controller_name):
        """Connect to a controller by name.

        The endpoint, CA certificate and credentials are read from the local
        Juju data. Macaroons are loaded only when no password is stored.
        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons)

    @classmethod
    async def connect_model(cls, model):
        """Connect to a model by name.

        :param str model: <controller>:<model>

        """
        controller_name, model_name = model.split(':')

        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons)

    def build_facades(self, info):
        """Record the version to use for each facade advertised at login.

        Takes the last entry of each ``versions`` list. This assumes the
        server lists versions in ascending order (TODO confirm).
        """
        self.facades.clear()
        for facade in info:
            self.facades[facade['name']] = facade['versions'][-1]

    async def login(self, username, password, macaroons=None):
        """Authenticate with the Admin facade (version 3).

        When macaroons are given they replace the username and password,
        which are then sent empty. Advertised facade versions are recorded,
        and a copy of the response is stored in ``self.info``.

        :return: The ``response`` part of the login reply.
        """
        if macaroons:
            username = ''
            password = ''

        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        # 12 distinct printable characters drawn from the OS entropy source.
        nonce = "".join(random.SystemRandom().sample(string.printable, 12))
        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": password,
                "nonce": nonce,
                "macaroons": macaroons or []
            }})
        response = result['response']
        self.build_facades(response.get('facades', {}))
        self.info = response.copy()
        return response

    async def redirect_info(self):
        """Ask the server whether this client should connect elsewhere.

        :return: The RedirectInfo response, or None when the server answers
            with the 'not redirected' API error.
        """
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
337 return result['response']
338
339
class JujuData:
    """Read-only access to the local Juju client data directory.

    The directory is ``$JUJU_DATA`` when that is set and non-empty,
    otherwise ``~/.local/share/juju``, expanded to an absolute path.
    """
    def __init__(self):
        self.path = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(self.path))

    def current_controller(self):
        """Return the current controller name as reported by the juju CLI.

        Runs ``juju list-controllers --format yaml``. Returns '' when the
        output names no current controller.
        """
        cmd = shlex.split('juju list-controllers --format yaml')
        output = yaml.safe_load(subprocess.check_output(cmd))
        # Empty CLI output parses to None; treat it as "no current controller".
        return (output or {}).get('current-controller', '')

    def controllers(self):
        """Return the ``controllers`` mapping from controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the ``controllers`` mapping from models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the ``controllers`` mapping from accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Parse *filename* in the data directory and return its *key*."""
        filepath = os.path.join(self.path, filename)
        # Decode as UTF-8 explicitly instead of the locale's default encoding.
        with io.open(filepath, 'rt', encoding='utf-8') as f:
            return yaml.safe_load(f)[key]
362 with io.open(filepath, 'rt') as f:
363 return yaml.safe_load(f)[key]
364
365
def get_macaroons():
    """Decode and return macaroons from the default ``~/.go-cookies`` jar.

    Best effort: if the cookie file cannot be read or parsed as JSON, a
    warning is logged and an empty list is returned.

    :return: List of macaroons. Each is JSON-decoded from the base64
        ``Value`` of a cookie whose ``Name`` starts with ``macaroon-`` and
        whose value is non-empty.
    """
    cookie_file = os.path.expanduser('~/.go-cookies')
    try:
        with open(cookie_file, 'r', encoding='utf-8') as f:
            cookies = json.load(f)
    except (OSError, ValueError):
        log.warning("Couldn't load macaroons from %s", cookie_file)
        return []

    # A jar holding JSON null is treated as empty; cookies missing Name or
    # Value are skipped rather than raising KeyError.
    base64_macaroons = [
        cookie['Value'] for cookie in cookies or []
        if (cookie.get('Name') or '').startswith('macaroon-')
        and cookie.get('Value')
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]
385 for value in base64_macaroons
386 ]