13 from concurrent
.futures
import CancelledError
14 from http
.client
import HTTPSConnection
15 from pathlib
import Path
20 from juju
import tag
, utils
21 from juju
.client
import client
22 from juju
.errors
import JujuError
, JujuAPIError
, JujuConnectionError
23 from juju
.utils
import IdQueue
25 log
= logging
.getLogger("websocket")
30 Monitor helper class for our Connection class.
32 Contains a reference to an instantiated Connection, along with a
33 reference to the Connection.receiver Future. Upon inspecttion of
34 these objects, this class determines whether the connection is in
35 an 'error', 'connected' or 'disconnected' state.
37 Use this class to stay up to date on the health of a connection,
38 and take appropriate action if the connection errors out due to
39 network issues or other unexpected circumstances.
43 CONNECTED
= 'connected'
44 DISCONNECTING
= 'disconnecting'
45 DISCONNECTED
= 'disconnected'
47 def __init__(self
, connection
):
48 self
.connection
= weakref
.ref(connection
)
49 self
.reconnecting
= asyncio
.Lock(loop
=connection
.loop
)
50 self
.close_called
= asyncio
.Event(loop
=connection
.loop
)
51 self
.receiver_stopped
= asyncio
.Event(loop
=connection
.loop
)
52 self
.pinger_stopped
= asyncio
.Event(loop
=connection
.loop
)
53 self
.receiver_stopped
.set()
54 self
.pinger_stopped
.set()
59 Determine the status of the connection and receiver, and return
60 ERROR, CONNECTED, or DISCONNECTED as appropriate.
62 For simplicity, we only consider ourselves to be connected
63 after the Connection class has setup a receiver task. This
64 only happens after the websocket is open, and the connection
65 isn't usable until that receiver has been started.
68 connection
= self
.connection()
70 # the connection instance was destroyed but someone kept
71 # a separate reference to the monitor for some reason
73 return self
.DISCONNECTED
75 # connection cleanly disconnected or not yet opened
77 return self
.DISCONNECTED
79 # close called but not yet complete
80 if self
.close_called
.is_set():
81 return self
.DISCONNECTING
83 # connection closed uncleanly (we didn't call connection.close)
84 if self
.receiver_stopped
.is_set() or not connection
.ws
.open:
95 # Connect to an arbitrary api server
96 client = await Connection.connect(
97 api_endpoint, model_uuid, username, password, cacert)
99 # Connect using a controller/model name
100 client = await Connection.connect_model('local.local:default')
102 # Connect to the currently active model
103 client = await Connection.connect_current()
105 Note: Any connection method or constructor can accept an optional `loop`
106 argument to override the default event loop from `asyncio.get_event_loop`.
109 DEFAULT_FRAME_SIZE
= 'default_frame_size'
110 MAX_FRAME_SIZE
= 2**22
111 "Maximum size for a single frame. Defaults to 4MB."
114 self
, endpoint
, uuid
, username
, password
, cacert
=None,
115 macaroons
=None, loop
=None, max_frame_size
=DEFAULT_FRAME_SIZE
):
116 self
.endpoint
= endpoint
117 self
._endpoint
= endpoint
120 self
.macaroons
= macaroons
125 self
.username
= username
126 self
.password
= password
128 self
._cacert
= cacert
129 self
.loop
= loop
or asyncio
.get_event_loop()
131 self
.__request
_id
__ = 0
135 self
.messages
= IdQueue(loop
=self
.loop
)
136 self
.monitor
= Monitor(connection
=self
)
137 if max_frame_size
is self
.DEFAULT_FRAME_SIZE
:
138 max_frame_size
= self
.MAX_FRAME_SIZE
139 self
.max_frame_size
= max_frame_size
143 return self
.monitor
.status
== Monitor
.CONNECTED
145 def _get_ssl(self
, cert
=None):
146 return ssl
.create_default_context(
147 purpose
=ssl
.Purpose
.CLIENT_AUTH
, cadata
=cert
)
149 async def open(self
):
151 url
= "wss://{}/model/{}/api".format(self
.endpoint
, self
.uuid
)
153 url
= "wss://{}/api".format(self
.endpoint
)
156 kw
['ssl'] = self
._get
_ssl
(self
.cacert
)
157 kw
['loop'] = self
.loop
158 kw
['max_size'] = self
.max_frame_size
160 self
.ws
= await websockets
.connect(url
, **kw
)
161 self
.loop
.create_task(self
.receiver())
162 self
.monitor
.receiver_stopped
.clear()
163 log
.info("Driver connected to juju %s", url
)
164 self
.monitor
.close_called
.clear()
167 async def close(self
):
170 self
.monitor
.close_called
.set()
171 await self
.monitor
.pinger_stopped
.wait()
172 await self
.monitor
.receiver_stopped
.wait()
173 await self
.ws
.close()
176 async def recv(self
, request_id
):
178 raise websockets
.exceptions
.ConnectionClosed(0, 'websocket closed')
179 return await self
.messages
.get(request_id
)
181 async def receiver(self
):
184 result
= await utils
.run_with_interrupt(
186 self
.monitor
.close_called
,
188 if self
.monitor
.close_called
.is_set():
190 if result
is not None:
191 result
= json
.loads(result
)
192 await self
.messages
.put(result
['request-id'], result
)
193 except CancelledError
:
195 except websockets
.ConnectionClosed
as e
:
196 log
.warning('Receiver: Connection closed, reconnecting')
197 await self
.messages
.put_all(e
)
198 # the reconnect has to be done as a task because the receiver will
199 # be cancelled by the reconnect and we don't want the reconnect
200 # to be aborted half-way through
201 self
.loop
.create_task(self
.reconnect())
203 except Exception as e
:
204 log
.exception("Error in receiver")
205 # make pending listeners aware of the error
206 await self
.messages
.put_all(e
)
209 self
.monitor
.receiver_stopped
.set()
211 async def pinger(self
):
213 A Controller can time us out if we are silent for too long. This
214 is especially true in JaaS, which has a fairly strict timeout.
216 To prevent timing out, we send a ping every ten seconds.
219 async def _do_ping():
221 await pinger_facade
.Ping()
222 await asyncio
.sleep(10, loop
=self
.loop
)
223 except CancelledError
:
226 pinger_facade
= client
.PingerFacade
.from_connection(self
)
229 await utils
.run_with_interrupt(
231 self
.monitor
.close_called
,
233 if self
.monitor
.close_called
.is_set():
236 self
.monitor
.pinger_stopped
.set()
239 async def rpc(self
, msg
, encoder
=None):
240 self
.__request
_id
__ += 1
241 msg
['request-id'] = self
.__request
_id
__
242 if'params' not in msg
:
244 if "version" not in msg
:
245 msg
['version'] = self
.facades
[msg
['type']]
246 outgoing
= json
.dumps(msg
, indent
=2, cls
=encoder
)
247 for attempt
in range(3):
249 await self
.ws
.send(outgoing
)
251 except websockets
.ConnectionClosed
:
254 log
.warning('RPC: Connection closed, reconnecting')
255 # the reconnect has to be done in a separate task because,
256 # if it is triggered by the pinger, then this RPC call will
257 # be cancelled when the pinger is cancelled by the reconnect,
258 # and we don't want the reconnect to be aborted halfway through
259 await asyncio
.wait([self
.reconnect()], loop
=self
.loop
)
260 result
= await self
.recv(msg
['request-id'])
265 if 'error' in result
:
267 raise JujuAPIError(result
)
269 if 'response' not in result
:
270 # This may never happen
273 if 'results' in result
['response']:
274 # Check for errors in a result list.
276 for res
in result
['response']['results']:
277 if res
.get('error', {}).get('message'):
278 errors
.append(res
['error']['message'])
280 raise JujuError(errors
)
282 elif result
['response'].get('error', {}).get('message'):
283 raise JujuError(result
['response']['error']['message'])
287 def http_headers(self
):
288 """Return dictionary of http headers necessary for making an http
289 connection to the endpoint of this Connection.
291 :return: Dictionary of headers
294 if not self
.username
:
297 creds
= u
'{}:{}'.format(
298 tag
.user(self
.username
),
301 token
= base64
.b64encode(creds
.encode())
303 'Authorization': 'Basic {}'.format(token
.decode())
306 def https_connection(self
):
307 """Return an https connection to this Connection's endpoint.
309 Returns a 3-tuple containing::
311 1. The :class:`HTTPSConnection` instance
312 2. Dictionary of auth headers to be used with the connection
313 3. The root url path (str) to be used for requests.
316 endpoint
= self
.endpoint
317 host
, remainder
= endpoint
.split(':', 1)
320 port
, _
= remainder
.split('/', 1)
322 conn
= HTTPSConnection(
324 context
=self
._get
_ssl
(self
.cacert
),
328 "/model/{}".format(self
.uuid
)
331 return conn
, self
.http_headers(), path
333 async def clone(self
):
334 """Return a new Connection, connected to the same websocket endpoint
338 return await Connection
.connect(
349 async def controller(self
):
350 """Return a Connection to the controller at self.endpoint
353 return await Connection
.connect(
363 async def _try_endpoint(self
, endpoint
, cacert
):
368 self
.endpoint
= endpoint
372 result
= await self
.login()
373 if 'discharge-required-error' in result
['response']:
374 log
.info('Macaroon discharge required, disconnecting')
377 log
.info('Authenticated')
379 except JujuAPIError
as e
:
380 if e
.error_code
!= 'redirection required':
382 log
.info('Controller requested redirect')
383 redirect_info
= await self
.redirect_info()
384 redir_cacert
= redirect_info
['ca-cert']
386 ("{value}:{port}".format(**s
), redir_cacert
)
387 for servers
in redirect_info
['servers']
388 for s
in servers
if s
["scope"] == 'public'
393 return success
, result
, new_endpoints
395 async def reconnect(self
):
396 """ Force a reconnection.
398 monitor
= self
.monitor
399 if monitor
.reconnecting
.locked() or monitor
.close_called
.is_set():
401 async with monitor
.reconnecting
:
403 await self
._connect
()
405 async def _connect(self
):
406 endpoints
= [(self
._endpoint
, self
._cacert
)]
408 _endpoint
, _cacert
= endpoints
.pop(0)
409 success
, result
, new_endpoints
= await self
._try
_endpoint
(
413 endpoints
.extend(new_endpoints
)
415 # ran out of endpoints without a successful login
416 raise Exception("Couldn't authenticate to {}".format(
419 response
= result
['response']
420 self
.info
= response
.copy()
421 self
.build_facades(response
.get('facades', {}))
422 self
.loop
.create_task(self
.pinger())
423 self
.monitor
.pinger_stopped
.clear()
427 cls
, endpoint
, uuid
, username
, password
, cacert
=None,
428 macaroons
=None, loop
=None, max_frame_size
=None):
429 """Connect to the websocket.
431 If uuid is None, the connection will be to the controller. Otherwise it
432 will be to the model.
435 client
= cls(endpoint
, uuid
, username
, password
, cacert
, macaroons
,
436 loop
, max_frame_size
)
437 await client
._connect
()
441 async def connect_current(cls
, loop
=None, max_frame_size
=None):
442 """Connect to the currently active model.
445 jujudata
= JujuData()
447 controller_name
= jujudata
.current_controller()
448 if not controller_name
:
449 raise JujuConnectionError('No current controller')
451 model_name
= jujudata
.current_model()
453 return await cls
.connect_model(
454 '{}:{}'.format(controller_name
, model_name
), loop
, max_frame_size
)
457 async def connect_current_controller(cls
, loop
=None, max_frame_size
=None):
458 """Connect to the currently active controller.
461 jujudata
= JujuData()
462 controller_name
= jujudata
.current_controller()
463 if not controller_name
:
464 raise JujuConnectionError('No current controller')
466 return await cls
.connect_controller(controller_name
, loop
,
470 async def connect_controller(cls
, controller_name
, loop
=None,
471 max_frame_size
=None):
472 """Connect to a controller by name.
475 jujudata
= JujuData()
476 controller
= jujudata
.controllers()[controller_name
]
477 endpoint
= controller
['api-endpoints'][0]
478 cacert
= controller
.get('ca-cert')
479 accounts
= jujudata
.accounts()[controller_name
]
480 username
= accounts
['user']
481 password
= accounts
.get('password')
482 macaroons
= get_macaroons(controller_name
) if not password
else None
484 return await cls
.connect(
485 endpoint
, None, username
, password
, cacert
, macaroons
, loop
,
489 async def connect_model(cls
, model
, loop
=None, max_frame_size
=None):
490 """Connect to a model by name.
492 :param str model: [<controller>:]<model>
495 jujudata
= JujuData()
498 # explicit controller given
499 controller_name
, model_name
= model
.split(':')
501 # use the current controller if one isn't explicitly given
502 controller_name
= jujudata
.current_controller()
505 accounts
= jujudata
.accounts()[controller_name
]
506 username
= accounts
['user']
507 # model name must include a user prefix, so add it if it doesn't
508 if '/' not in model_name
:
509 model_name
= '{}/{}'.format(username
, model_name
)
511 controller
= jujudata
.controllers()[controller_name
]
512 endpoint
= controller
['api-endpoints'][0]
513 cacert
= controller
.get('ca-cert')
514 password
= accounts
.get('password')
515 models
= jujudata
.models()[controller_name
]
516 model_uuid
= models
['models'][model_name
]['uuid']
517 macaroons
= get_macaroons(controller_name
) if not password
else None
519 return await cls
.connect(
520 endpoint
, model_uuid
, username
, password
, cacert
, macaroons
, loop
,
523 def build_facades(self
, facades
):
525 for facade
in facades
:
526 self
.facades
[facade
['name']] = facade
['versions'][-1]
528 async def login(self
):
529 username
= self
.username
530 if username
and not username
.startswith('user-'):
531 username
= 'user-{}'.format(username
)
533 result
= await self
.rpc({
538 "auth-tag": username
,
539 "credentials": self
.password
,
540 "nonce": "".join(random
.sample(string
.printable
, 12)),
541 "macaroons": self
.macaroons
545 async def redirect_info(self
):
547 result
= await self
.rpc({
549 "request": "RedirectInfo",
552 except JujuAPIError
as e
:
553 if e
.message
== 'not redirected':
556 return result
['response']
561 self
.path
= os
.environ
.get('JUJU_DATA') or '~/.local/share/juju'
562 self
.path
= os
.path
.abspath(os
.path
.expanduser(self
.path
))
564 def current_controller(self
):
565 cmd
= shlex
.split('juju list-controllers --format yaml')
566 output
= subprocess
.check_output(cmd
)
567 output
= yaml
.safe_load(output
)
568 return output
.get('current-controller', '')
570 def current_model(self
, controller_name
=None):
571 if not controller_name
:
572 controller_name
= self
.current_controller()
573 models
= self
.models()[controller_name
]
574 if 'current-model' not in models
:
575 raise JujuError('No current model')
576 return models
['current-model']
578 def controllers(self
):
579 return self
._load
_yaml
('controllers.yaml', 'controllers')
582 return self
._load
_yaml
('models.yaml', 'controllers')
585 return self
._load
_yaml
('accounts.yaml', 'controllers')
587 def _load_yaml(self
, filename
, key
):
588 filepath
= os
.path
.join(self
.path
, filename
)
589 with io
.open(filepath
, 'rt') as f
:
590 return yaml
.safe_load(f
)[key
]
593 def get_macaroons(controller_name
=None):
594 """Decode and return macaroons from default ~/.go-cookies
599 cookie_files
.append('~/.local/share/juju/cookies/{}.json'.format(
601 cookie_files
.append('~/.go-cookies')
602 for cookie_file
in cookie_files
:
603 cookie_file
= Path(cookie_file
).expanduser()
604 if cookie_file
.exists():
606 cookies
= json
.loads(cookie_file
.read_text())
608 except (OSError, ValueError):
609 log
.warn("Couldn't load macaroons from %s", cookie_file
)
612 log
.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files
))
616 c
['Value'] for c
in cookies
617 if c
['Name'].startswith('macaroon-') and c
['Value']
621 json
.loads(base64
.b64decode(value
).decode('utf-8'))
622 for value
in base64_macaroons