Refactor to use IdQueue pattern
[osm/N2VC.git] / juju / client / connection.py
1 import base64
2 import io
3 import json
4 import logging
5 import os
6 import random
7 import shlex
8 import ssl
9 import string
10 import subprocess
11 import websockets
12 from http.client import HTTPSConnection
13
14 import asyncio
15 import yaml
16
17 from juju import tag
18 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
19 from juju.utils import IdQueue
20
# Module-level logger shared by everything in this file. It is registered
# under the fixed name "websocket" rather than the module's own name.
log = logging.getLogger("websocket")
22
23
class Connection:
    """
    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    Note: Any connection method or constructor can accept an optional `loop`
    argument to override the default event loop from `asyncio.get_event_loop`.
    """
    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        """Store the connection parameters; no network I/O happens here.

        :param str endpoint: ``host:port`` of the API server.
        :param str uuid: Model UUID, or a falsy value to target the
            controller API instead of a model.
        :param str username: User name used for login and HTTP auth.
        :param str password: Password for ``username`` (may be ``None``).
        :param str cacert: CA certificate data used to build the SSL context.
        :param list macaroons: Macaroons to send at login.
        :param loop: Event loop to use; defaults to
            ``asyncio.get_event_loop()``.

        """
        self.endpoint = endpoint
        self.uuid = uuid
        self.username = username
        self.password = password
        self.macaroons = macaroons
        self.cacert = cacert
        self.loop = loop or asyncio.get_event_loop()

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        # Copy of the last login response; populated by login().
        self.info = None
        # Maps facade name -> version used when an rpc message has none.
        self.facades = {}
        # Replies read by receiver(), keyed by their 'request-id'.
        self.messages = IdQueue(loop=self.loop)
        # Strong reference to the background reader task so that it is not
        # garbage collected while the connection is in use.
        self._receiver_task = None

    @property
    def is_open(self):
        """``True`` when a websocket exists and reports itself as open."""
        if self.ws:
            return self.ws.open
        return False

    def _get_ssl(self, cert=None):
        """Build the SSL context used to talk to the API server.

        :param str cert: CA certificate data passed to
            :func:`ssl.create_default_context` as ``cadata``.
        :return: An :class:`ssl.SSLContext` suitable for client sockets.

        """
        # SERVER_AUTH is the purpose for *client-side* sockets (the server is
        # the party being authenticated). A CLIENT_AUTH context is meant for
        # servers: it does not verify the peer and cannot be used for
        # outgoing connections on recent Python versions.
        context = ssl.create_default_context(
            purpose=ssl.Purpose.SERVER_AUTH, cadata=cert)
        if cert:
            # The chain is still verified against the supplied CA; only the
            # host name match is skipped.
            # NOTE(review): presumably endpoints are often bare IP addresses
            # that the certificate does not name - confirm.
            context.check_hostname = False
        return context

    async def open(self):
        """Open the websocket to the model (or controller) API.

        The URL targets ``/model/<uuid>/api`` when ``self.uuid`` is set and
        ``/api`` otherwise.

        :return: ``self``, to allow chaining.

        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        kw = dict()
        kw['ssl'] = self._get_ssl(self.cacert)
        kw['loop'] = self.loop
        self.addr = url
        self.ws = await websockets.connect(url, **kw)
        log.info("Driver connected to juju %s", url)
        return self

    async def close(self):
        """Close the websocket; a no-op if it was never opened."""
        if self.ws is None:
            return
        await self.ws.close()

    async def recv(self, request_id):
        """Wait for and return the reply queued under ``request_id``."""
        return await self.messages.get(request_id)

    async def receiver(self):
        """Read frames from the websocket and queue them by request id.

        Runs for as long as the socket is open. Each non-``None`` frame is
        decoded as JSON and stored in ``self.messages`` under its
        ``'request-id'`` so that the matching :meth:`rpc` call can pick it up.

        """
        while self.is_open:
            result = await self.ws.recv()
            if result is not None:
                result = json.loads(result)
                await self.messages.put(result['request-id'], result)

    def _start_receiver(self):
        """Schedule :meth:`receiver` on this connection's event loop."""
        # create_task() needs the coroutine object, i.e. the *called* method.
        self._receiver_task = self.loop.create_task(self.receiver())

    async def rpc(self, msg, encoder=None):
        """Send one request and return the matching reply.

        ``msg`` is modified in place: it is given the next ``'request-id'``,
        an empty ``'params'`` dict if it has none, and - if ``'version'`` is
        missing - the version recorded in ``self.facades`` for ``msg['type']``.

        :param dict msg: The request to send.
        :param encoder: Optional :class:`json.JSONEncoder` subclass used to
            serialise ``msg``.
        :return: The decoded reply.
        :raises JujuAPIError: if the reply has a top-level ``'error'``.
        :raises JujuError: if the response, or any entry of its ``'results'``
            list, carries an error message.

        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        if 'params' not in msg:
            msg['params'] = {}
        if "version" not in msg:
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        await self.ws.send(outgoing)
        result = await self.recv(msg['request-id'])

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        response = result['response']
        if 'results' in response:
            # Check for errors in a result list. Both the list itself and an
            # individual 'error' entry may be null in the decoded JSON, so
            # fall back to empty containers instead of crashing on None.
            errors = [
                res['error']['message']
                for res in response['results'] or []
                if (res.get('error') or {}).get('message')
            ]
            if errors:
                raise JujuError(errors)

        elif (response.get('error') or {}).get('message'):
            raise JujuError(response['error']['message'])

        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        :return: Dictionary of headers (empty when no username is set)

        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        """
        endpoint = self.endpoint
        host, remainder = endpoint.split(':', 1)
        port = remainder
        if '/' in remainder:
            # Drop anything after the port (e.g. a path component).
            port, _ = remainder.split('/', 1)

        conn = HTTPSConnection(
            host, int(port),
            context=self._get_ssl(self.cacert),
        )

        path = (
            "/model/{}".format(self.uuid)
            if self.uuid else ""
        )
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.

        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint

        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
        )

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None):
        """Connect to the websocket.

        If uuid is None, the connection will be to the controller. Otherwise it
        will be to the model.

        If the server answers the redirect query with a list of servers, each
        server with ``scope == 'public'`` is tried in turn and the first one
        that accepts the login is returned.

        :return: A logged-in :class:`Connection`.
        :raises Exception: if the connection was redirected and none of the
            offered servers accepted the login.

        """
        client = cls(endpoint, uuid, username, password, cacert, macaroons,
                     loop)
        await client.open()
        client._start_receiver()

        redirect_info = await client.redirect_info()
        if not redirect_info:
            await client.login(username, password, macaroons)
            return client

        await client.close()
        candidates = [
            server
            for group in redirect_info['servers']
            for server in group
            if server["scope"] == 'public'
        ]
        for server in candidates:
            client = cls(
                "{value}:{port}".format(**server), uuid, username,
                password, redirect_info['ca-cert'], macaroons, loop)
            await client.open()
            # Replies are only delivered by the receiver task, so it has to
            # be running before login() can get an answer.
            client._start_receiver()
            try:
                result = await client.login(username, password, macaroons)
            except Exception as e:
                await client.close()
                log.exception(e)
                continue
            if 'discharge-required-error' not in result:
                return client
            # This server wants a macaroon discharge; release the socket
            # before moving on to the next candidate.
            await client.close()

        raise Exception(
            "Couldn't authenticate to %s" % endpoint)

    @classmethod
    async def connect_current(cls, loop=None):
        """Connect to the currently active model.

        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        model_name = jujudata.current_model()

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name), loop)

    @classmethod
    async def connect_current_controller(cls, loop=None):
        """Connect to the currently active controller.

        :raises JujuConnectionError: if there is no current controller.

        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name, loop)

    @classmethod
    async def connect_controller(cls, controller_name, loop=None):
        """Connect to a controller by name.

        The endpoint, CA certificate and credentials are read through
        :class:`JujuData`; macaroons are loaded only when the account has no
        password.

        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons, loop)

    @classmethod
    async def connect_model(cls, model, loop=None):
        """Connect to a model by name.

        :param str model: [<controller>:]<model>

        """
        jujudata = JujuData()

        if ':' in model:
            # explicit controller given
            controller_name, model_name = model.split(':')
        else:
            # use the current controller if one isn't explicitly given
            controller_name = jujudata.current_controller()
            model_name = model

        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        # model name must include a user prefix, so add it if it doesn't
        if '/' not in model_name:
            model_name = '{}/{}'.format(username, model_name)

        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons() if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons, loop)

    def build_facades(self, info):
        """Rebuild ``self.facades`` from a list of facade descriptions.

        :param info: Iterable of dicts with ``'name'`` and ``'versions'``
            keys; the last entry of ``'versions'`` is recorded for each name.

        """
        self.facades.clear()
        for facade in info:
            self.facades[facade['name']] = facade['versions'][-1]

    async def login(self, username, password, macaroons=None):
        """Log in through the ``Admin`` facade (version 3).

        When ``macaroons`` are supplied the user name and password are sent
        empty. A non-empty user name is prefixed with ``user-`` if it does
        not already start with it. On return, ``self.facades`` has been
        rebuilt from the response and ``self.info`` holds a copy of it.

        :return: The ``'response'`` part of the login reply.

        """
        if macaroons:
            username = ''
            password = ''

        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": password,
                "nonce": "".join(random.sample(string.printable, 12)),
                "macaroons": macaroons or []
            }})
        response = result['response']
        self.build_facades(response.get('facades', {}))
        self.info = response.copy()
        return response

    async def redirect_info(self):
        """Ask the server whether this connection should go elsewhere.

        :return: The ``'response'`` part of the ``RedirectInfo`` reply, or
            ``None`` when the API error message is ``'not redirected'``.
        :raises JujuAPIError: for any other API error.

        """
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
363
364
class JujuData:
    """Access to the Juju client's local data directory.

    The directory is taken from the ``JUJU_DATA`` environment variable and
    defaults to ``~/.local/share/juju``.
    """

    def __init__(self):
        """Resolve the data directory to an absolute, user-expanded path."""
        configured = os.environ.get('JUJU_DATA')
        location = configured if configured else '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(location))

    def current_controller(self):
        """Return the current controller name reported by the juju CLI.

        Runs ``juju list-controllers --format yaml`` and returns its
        ``current-controller`` entry, or ``''`` when that key is absent.

        """
        argv = shlex.split('juju list-controllers --format yaml')
        listing = yaml.safe_load(subprocess.check_output(argv))
        return listing.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the current model recorded for a controller.

        :param str controller_name: Controller to look up; the current
            controller is used when this is not given.
        :raises JujuError: if the controller has no ``current-model`` entry.

        """
        name = controller_name or self.current_controller()
        per_controller = self.models()[name]
        if 'current-model' in per_controller:
            return per_controller['current-model']
        raise JujuError('No current model')

    def controllers(self):
        """Return the ``controllers`` section of ``controllers.yaml``."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the ``controllers`` section of ``models.yaml``."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the ``controllers`` section of ``accounts.yaml``."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Parse ``filename`` from the data directory and return ``key``."""
        with io.open(os.path.join(self.path, filename), 'rt') as stream:
            document = yaml.safe_load(stream)
        return document[key]
397
398
def get_macaroons():
    """Decode and return macaroons from default ~/.go-cookies

    The cookie file is read as a JSON list of cookies. Every cookie whose
    ``Name`` starts with ``macaroon-`` and whose ``Value`` is non-empty is
    base64-decoded and parsed as JSON.

    :return: List of decoded macaroons; an empty list if the cookie file
        cannot be read or is not valid JSON.

    """
    # Resolved outside the try block so that the name is always bound when
    # the warning below is logged.
    cookie_file = os.path.expanduser('~/.go-cookies')
    try:
        with open(cookie_file, 'r') as f:
            cookies = json.load(f)
    except (OSError, ValueError):
        # Best effort: a missing or unreadable cookie file just means there
        # are no macaroons. Logger.warning is used because the `warn` alias
        # is deprecated and absent from recent Python versions.
        log.warning("Couldn't load macaroons from %s", cookie_file)
        return []

    base64_macaroons = [
        c['Value'] for c in cookies
        if c['Name'].startswith('macaroon-') and c['Value']
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]