+import asyncio
import base64
-import io
import json
import logging
-import os
-import random
-import shlex
import ssl
-import string
-import subprocess
+import urllib.request
import weakref
-import websockets
from concurrent.futures import CancelledError
from http.client import HTTPSConnection
-from pathlib import Path
-
-import asyncio
-import yaml
-from juju import tag, utils
+import macaroonbakery.httpbakery as httpbakery
+import macaroonbakery.bakery as bakery
+import websockets
+from juju import errors, tag, utils
from juju.client import client
-from juju.errors import JujuError, JujuAPIError, JujuConnectionError
from juju.utils import IdQueue
-log = logging.getLogger("websocket")
+log = logging.getLogger('juju.client.connection')
class Monitor:
Monitor helper class for our Connection class.
Contains a reference to an instantiated Connection, along with a
- reference to the Connection.receiver Future. Upon inspecttion of
+ reference to the Connection.receiver Future. Upon inspection of
these objects, this class determines whether the connection is in
an 'error', 'connected' or 'disconnected' state.
self.connection = weakref.ref(connection)
self.reconnecting = asyncio.Lock(loop=connection.loop)
self.close_called = asyncio.Event(loop=connection.loop)
- self.receiver_stopped = asyncio.Event(loop=connection.loop)
- self.pinger_stopped = asyncio.Event(loop=connection.loop)
- self.receiver_stopped.set()
- self.pinger_stopped.set()
@property
def status(self):
return self.DISCONNECTING
# connection closed uncleanly (we didn't call connection.close)
- if self.receiver_stopped.is_set() or not connection.ws.open:
+ stopped = connection._receiver_task.stopped.is_set()
+ if stopped or not connection.ws.open:
return self.ERROR
# everything is fine!
client = await Connection.connect(
api_endpoint, model_uuid, username, password, cacert)
- # Connect using a controller/model name
- client = await Connection.connect_model('local.local:default')
-
- # Connect to the currently active model
- client = await Connection.connect_current()
-
Note: Any connection method or constructor can accept an optional `loop`
argument to override the default event loop from `asyncio.get_event_loop`.
"""
- DEFAULT_FRAME_SIZE = 'default_frame_size'
MAX_FRAME_SIZE = 2**22
"Maximum size for a single frame. Defaults to 4MB."
- def __init__(
- self, endpoint, uuid, username, password, cacert=None,
- macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
- self.endpoint = endpoint
- self._endpoint = endpoint
+ @classmethod
+ async def connect(
+ cls,
+ endpoint=None,
+ uuid=None,
+ username=None,
+ password=None,
+ cacert=None,
+ bakery_client=None,
+ loop=None,
+ max_frame_size=None,
+ ):
+ """Connect to the websocket.
+
+ If uuid is None, the connection will be to the controller. Otherwise it
+ will be to the model.
+ :param str endpoint The hostname:port of the controller to connect to.
+ :param str uuid The model UUID to connect to (None for a
+ controller-only connection).
+ :param str username The username for controller-local users (or None
+ to use macaroon-based login.)
+ :param str password The password for controller-local users.
+ :param str cacert The CA certificate of the controller (PEM formatted).
+ :param httpbakery.Client bakery_client The macaroon bakery client to
+ to use when performing macaroon-based login. Macaroon tokens
+ acquired when logging will be saved to bakery_client.cookies.
+ If this is None, a default bakery_client will be used.
+ :param loop asyncio.BaseEventLoop The event loop to use for async
+ operations.
+ :param max_frame_size The maximum websocket frame size to allow.
+ """
+ self = cls()
+ if endpoint is None:
+ raise ValueError('no endpoint provided')
self.uuid = uuid
- if macaroons:
- self.macaroons = macaroons
- self.username = ''
- self.password = ''
- else:
- self.macaroons = []
- self.username = username
- self.password = password
- self.cacert = cacert
- self._cacert = cacert
+ if bakery_client is None:
+ bakery_client = httpbakery.Client()
+ self.bakery_client = bakery_client
+ if username and '@' in username and not username.endswith('@local'):
+ # We're trying to log in as an external user - we need to use
+ # macaroon authentication with no username or password.
+ if password is not None:
+ raise errors.JujuAuthError('cannot log in as external '
+ 'user with a password')
+ username = None
+ self.usertag = tag.user(username)
+ self.password = password
self.loop = loop or asyncio.get_event_loop()
self.__request_id__ = 0
+
+ # The following instance variables are initialized by the
+ # _connect_with_redirect method, but create them here
+ # as a reminder that they will exist.
self.addr = None
self.ws = None
+ self.endpoint = None
+ self.cacert = None
+ self.info = None
+
+ # Create that _Task objects but don't start the tasks yet.
+ self._pinger_task = _Task(self._pinger, self.loop)
+ self._receiver_task = _Task(self._receiver, self.loop)
+
self.facades = {}
self.messages = IdQueue(loop=self.loop)
self.monitor = Monitor(connection=self)
- if max_frame_size is self.DEFAULT_FRAME_SIZE:
+ if max_frame_size is None:
max_frame_size = self.MAX_FRAME_SIZE
self.max_frame_size = max_frame_size
+ await self._connect_with_redirect([(endpoint, cacert)])
+ return self
+
+ @property
+ def username(self):
+ if not self.usertag:
+ return None
+ return self.usertag[len('user-'):]
@property
def is_open(self):
return ssl.create_default_context(
purpose=ssl.Purpose.CLIENT_AUTH, cadata=cert)
- async def open(self):
+ async def _open(self, endpoint, cacert):
if self.uuid:
- url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
+ url = "wss://{}/model/{}/api".format(endpoint, self.uuid)
else:
- url = "wss://{}/api".format(self.endpoint)
-
- kw = dict()
- kw['ssl'] = self._get_ssl(self.cacert)
- kw['loop'] = self.loop
- kw['max_size'] = self.max_frame_size
- self.addr = url
- self.ws = await websockets.connect(url, **kw)
- self.loop.create_task(self.receiver())
- self.monitor.receiver_stopped.clear()
- log.info("Driver connected to juju %s", url)
- self.monitor.close_called.clear()
- return self
+ url = "wss://{}/api".format(endpoint)
+
+ return (await websockets.connect(
+ url,
+ ssl=self._get_ssl(cacert),
+ loop=self.loop,
+ max_size=self.max_frame_size,
+ ), url, endpoint, cacert)
async def close(self):
if not self.ws:
return
self.monitor.close_called.set()
- await self.monitor.pinger_stopped.wait()
- await self.monitor.receiver_stopped.wait()
+ await self._pinger_task.stopped.wait()
+ await self._receiver_task.stopped.wait()
await self.ws.close()
self.ws = None
- async def recv(self, request_id):
+ async def _recv(self, request_id):
if not self.is_open:
raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
return await self.messages.get(request_id)
- async def receiver(self):
+ async def _receiver(self):
try:
while self.is_open:
result = await utils.run_with_interrupt(
# make pending listeners aware of the error
await self.messages.put_all(e)
raise
- finally:
- self.monitor.receiver_stopped.set()
- async def pinger(self):
+ async def _pinger(self):
'''
A Controller can time us out if we are silent for too long. This
is especially true in JaaS, which has a fairly strict timeout.
loop=self.loop)
if self.monitor.close_called.is_set():
break
- finally:
- self.monitor.pinger_stopped.set()
- return
+ except websockets.exceptions.ConnectionClosed:
+ # The connection has closed - we can't do anything
+ # more until the connection is restarted.
+ log.debug('ping failed because of closed connection')
+ pass
async def rpc(self, msg, encoder=None):
+ '''Make an RPC to the API. The message is encoded as JSON
+ using the given encoder if any.
+ :param msg: Parameters for the call (will be encoded as JSON).
+ :param encoder: Encoder to be used when encoding the message.
+ :return: The result of the call.
+ :raises JujuAPIError: When there's an error returned.
+ :raises JujuError:
+ '''
self.__request_id__ += 1
msg['request-id'] = self.__request_id__
if'params' not in msg:
if "version" not in msg:
msg['version'] = self.facades[msg['type']]
outgoing = json.dumps(msg, indent=2, cls=encoder)
+ log.debug('connection {} -> {}'.format(id(self), outgoing))
for attempt in range(3):
+ if self.monitor.status == Monitor.DISCONNECTED:
+ # closed cleanly; shouldn't try to reconnect
+ raise websockets.exceptions.ConnectionClosed(
+ 0, 'websocket closed')
try:
await self.ws.send(outgoing)
break
# be cancelled when the pinger is cancelled by the reconnect,
# and we don't want the reconnect to be aborted halfway through
await asyncio.wait([self.reconnect()], loop=self.loop)
- result = await self.recv(msg['request-id'])
+ if self.monitor.status != Monitor.CONNECTED:
+ # reconnect failed; abort and shutdown
+ log.error('RPC: Automatic reconnect failed')
+ raise
+ result = await self._recv(msg['request-id'])
+ log.debug('connection {} <- {}'.format(id(self), result))
if not result:
return result
if 'error' in result:
# API Error Response
- raise JujuAPIError(result)
+ raise errors.JujuAPIError(result)
if 'response' not in result:
# This may never happen
if 'results' in result['response']:
# Check for errors in a result list.
- errors = []
+ # TODO This loses the results that might have succeeded.
+ # Perhaps JujuError should return all the results including
+ # errors, or perhaps a keyword parameter to the rpc method
+ # could be added to trigger this behaviour.
+ err_results = []
for res in result['response']['results']:
if res.get('error', {}).get('message'):
- errors.append(res['error']['message'])
- if errors:
- raise JujuError(errors)
+ err_results.append(res['error']['message'])
+ if err_results:
+ raise errors.JujuError(err_results)
elif result['response'].get('error', {}).get('message'):
- raise JujuError(result['response']['error']['message'])
+ raise errors.JujuError(result['response']['error']['message'])
return result
- def http_headers(self):
+ def _http_headers(self):
"""Return dictionary of http headers necessary for making an http
connection to the endpoint of this Connection.
:return: Dictionary of headers
"""
- if not self.username:
+ if not self.usertag:
return {}
creds = u'{}:{}'.format(
- tag.user(self.username),
+ self.usertag,
self.password or ''
)
token = base64.b64encode(creds.encode())
"/model/{}".format(self.uuid)
if self.uuid else ""
)
- return conn, self.http_headers(), path
+ return conn, self._http_headers(), path
async def clone(self):
"""Return a new Connection, connected to the same websocket endpoint
as this one.
"""
- return await Connection.connect(
- self.endpoint,
- self.uuid,
- self.username,
- self.password,
- self.cacert,
- self.macaroons,
- self.loop,
- self.max_frame_size,
- )
+ return await Connection.connect(**self.connect_params())
+
+ def connect_params(self):
+ """Return a tuple of parameters suitable for passing to
+ Connection.connect that can be used to make a new connection
+ to the same controller (and model if specified. The first
+ element in the returned tuple holds the endpoint argument;
+ the other holds a dict of the keyword args.
+ """
+ return {
+ 'endpoint': self.endpoint,
+ 'uuid': self.uuid,
+ 'username': self.username,
+ 'password': self.password,
+ 'cacert': self.cacert,
+ 'bakery_client': self.bakery_client,
+ 'loop': self.loop,
+ 'max_frame_size': self.max_frame_size,
+ }
async def controller(self):
"""Return a Connection to the controller at self.endpoint
-
"""
return await Connection.connect(
self.endpoint,
- None,
- self.username,
- self.password,
- self.cacert,
- self.macaroons,
- self.loop,
+ username=self.username,
+ password=self.password,
+ cacert=self.cacert,
+ bakery_client=self.bakery_client,
+ loop=self.loop,
+ max_frame_size=self.max_frame_size,
)
- async def _try_endpoint(self, endpoint, cacert):
- success = False
- result = None
- new_endpoints = []
-
- self.endpoint = endpoint
- self.cacert = cacert
- await self.open()
- try:
- result = await self.login()
- if 'discharge-required-error' in result['response']:
- log.info('Macaroon discharge required, disconnecting')
- else:
- # successful login!
- log.info('Authenticated')
- success = True
- except JujuAPIError as e:
- if e.error_code != 'redirection required':
- raise
- log.info('Controller requested redirect')
- redirect_info = await self.redirect_info()
- redir_cacert = redirect_info['ca-cert']
- new_endpoints = [
- ("{value}:{port}".format(**s), redir_cacert)
- for servers in redirect_info['servers']
- for s in servers if s["scope"] == 'public'
- ]
- finally:
- if not success:
- await self.close()
- return success, result, new_endpoints
-
async def reconnect(self):
""" Force a reconnection.
"""
return
async with monitor.reconnecting:
await self.close()
- await self._connect()
-
- async def _connect(self):
- endpoints = [(self._endpoint, self._cacert)]
- while endpoints:
- _endpoint, _cacert = endpoints.pop(0)
- success, result, new_endpoints = await self._try_endpoint(
- _endpoint, _cacert)
- if success:
+ await self._connect_with_login([(self.endpoint, self.cacert)])
+
+ async def _connect(self, endpoints):
+ if len(endpoints) == 0:
+ raise errors.JujuConnectionError('no endpoints to connect to')
+
+ async def _try_endpoint(endpoint, cacert, delay):
+ if delay:
+ await asyncio.sleep(delay)
+ return await self._open(endpoint, cacert)
+
+ # Try all endpoints in parallel, with slight increasing delay (+100ms
+ # for each subsequent endpoint); the delay allows us to prefer the
+ # earlier endpoints over the latter. Use first successful connection.
+ tasks = [self.loop.create_task(_try_endpoint(endpoint, cacert,
+ 0.1 * i))
+ for i, (endpoint, cacert) in enumerate(endpoints)]
+ for task in asyncio.as_completed(tasks, loop=self.loop):
+ try:
+ result = await task
break
- endpoints.extend(new_endpoints)
+ except ConnectionError:
+ continue # ignore; try another endpoint
else:
- # ran out of endpoints without a successful login
- raise Exception("Couldn't authenticate to {}".format(
- self._endpoint))
-
- response = result['response']
- self.info = response.copy()
- self.build_facades(response.get('facades', {}))
- self.loop.create_task(self.pinger())
- self.monitor.pinger_stopped.clear()
+ raise errors.JujuConnectionError(
+ 'Unable to connect to any endpoint: {}'.format(', '.join([
+ endpoint for endpoint, cacert in endpoints])))
+ for task in tasks:
+ task.cancel()
+ self.ws = result[0]
+ self.addr = result[1]
+ self.endpoint = result[2]
+ self.cacert = result[3]
+ self._receiver_task.start()
+ log.info("Driver connected to juju %s", self.addr)
+ self.monitor.close_called.clear()
- @classmethod
- async def connect(
- cls, endpoint, uuid, username, password, cacert=None,
- macaroons=None, loop=None, max_frame_size=None):
+ async def _connect_with_login(self, endpoints):
"""Connect to the websocket.
If uuid is None, the connection will be to the controller. Otherwise it
will be to the model.
-
- """
- client = cls(endpoint, uuid, username, password, cacert, macaroons,
- loop, max_frame_size)
- await client._connect()
- return client
-
- @classmethod
- async def connect_current(cls, loop=None, max_frame_size=None):
- """Connect to the currently active model.
-
+ :return: The response field of login response JSON object.
"""
- jujudata = JujuData()
-
- controller_name = jujudata.current_controller()
- if not controller_name:
- raise JujuConnectionError('No current controller')
-
- model_name = jujudata.current_model()
-
- return await cls.connect_model(
- '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
-
- @classmethod
- async def connect_current_controller(cls, loop=None, max_frame_size=None):
- """Connect to the currently active controller.
-
- """
- jujudata = JujuData()
- controller_name = jujudata.current_controller()
- if not controller_name:
- raise JujuConnectionError('No current controller')
-
- return await cls.connect_controller(controller_name, loop,
- max_frame_size)
-
- @classmethod
- async def connect_controller(cls, controller_name, loop=None,
- max_frame_size=None):
- """Connect to a controller by name.
-
- """
- jujudata = JujuData()
- controller = jujudata.controllers()[controller_name]
- endpoint = controller['api-endpoints'][0]
- cacert = controller.get('ca-cert')
- accounts = jujudata.accounts()[controller_name]
- username = accounts['user']
- password = accounts.get('password')
- macaroons = get_macaroons(controller_name) if not password else None
-
- return await cls.connect(
- endpoint, None, username, password, cacert, macaroons, loop,
- max_frame_size)
-
- @classmethod
- async def connect_model(cls, model, loop=None, max_frame_size=None):
- """Connect to a model by name.
-
- :param str model: [<controller>:]<model>
+ success = False
+ try:
+ await self._connect(endpoints)
+ # It's possible that we may get several discharge-required errors,
+ # corresponding to different levels of authentication, so retry
+ # a few times.
+ for i in range(0, 2):
+ result = (await self.login())['response']
+ macaroonJSON = result.get('discharge-required')
+ if macaroonJSON is None:
+ self.info = result
+ success = True
+ return result
+ macaroon = bakery.Macaroon.from_dict(macaroonJSON)
+ self.bakery_client.handle_error(
+ httpbakery.Error(
+ code=httpbakery.ERR_DISCHARGE_REQUIRED,
+ message=result.get('discharge-required-error'),
+ version=macaroon.version,
+ info=httpbakery.ErrorInfo(
+ macaroon=macaroon,
+ macaroon_path=result.get('macaroon-path'),
+ ),
+ ),
+ # note: remove the port number.
+ 'https://' + self.endpoint + '/',
+ )
+ raise errors.JujuAuthError('failed to authenticate '
+ 'after several attempts')
+ finally:
+ if not success:
+ await self.close()
- """
- jujudata = JujuData()
+ async def _connect_with_redirect(self, endpoints):
+ try:
+ login_result = await self._connect_with_login(endpoints)
+ except errors.JujuRedirectException as e:
+ login_result = await self._connect_with_login(e.endpoints)
+ self._build_facades(login_result.get('facades', {}))
+ self._pinger_task.start()
- if ':' in model:
- # explicit controller given
- controller_name, model_name = model.split(':')
- else:
- # use the current controller if one isn't explicitly given
- controller_name = jujudata.current_controller()
- model_name = model
-
- accounts = jujudata.accounts()[controller_name]
- username = accounts['user']
- # model name must include a user prefix, so add it if it doesn't
- if '/' not in model_name:
- model_name = '{}/{}'.format(username, model_name)
-
- controller = jujudata.controllers()[controller_name]
- endpoint = controller['api-endpoints'][0]
- cacert = controller.get('ca-cert')
- password = accounts.get('password')
- models = jujudata.models()[controller_name]
- model_uuid = models['models'][model_name]['uuid']
- macaroons = get_macaroons(controller_name) if not password else None
-
- return await cls.connect(
- endpoint, model_uuid, username, password, cacert, macaroons, loop,
- max_frame_size)
-
- def build_facades(self, facades):
+ def _build_facades(self, facades):
self.facades.clear()
for facade in facades:
self.facades[facade['name']] = facade['versions'][-1]
async def login(self):
- username = self.username
- if username and not username.startswith('user-'):
- username = 'user-{}'.format(username)
-
- result = await self.rpc({
- "type": "Admin",
- "request": "Login",
- "version": 3,
- "params": {
- "auth-tag": username,
- "credentials": self.password,
- "nonce": "".join(random.sample(string.printable, 12)),
- "macaroons": self.macaroons
- }})
- return result
+ params = {}
+ if self.password:
+ params['auth-tag'] = self.usertag
+ params['credentials'] = self.password
+ else:
+ macaroons = _macaroons_for_domain(self.bakery_client.cookies,
+ self.endpoint)
+ params['macaroons'] = [[bakery.macaroon_to_dict(m) for m in ms]
+ for ms in macaroons]
- async def redirect_info(self):
try:
- result = await self.rpc({
+ return await self.rpc({
"type": "Admin",
- "request": "RedirectInfo",
+ "request": "Login",
"version": 3,
+ "params": params,
})
- except JujuAPIError as e:
- if e.message == 'not redirected':
- return None
- raise
- return result['response']
-
-
-class JujuData:
- def __init__(self):
- self.path = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
- self.path = os.path.abspath(os.path.expanduser(self.path))
-
- def current_controller(self):
- cmd = shlex.split('juju list-controllers --format yaml')
- output = subprocess.check_output(cmd)
- output = yaml.safe_load(output)
- return output.get('current-controller', '')
-
- def current_model(self, controller_name=None):
- if not controller_name:
- controller_name = self.current_controller()
- models = self.models()[controller_name]
- if 'current-model' not in models:
- raise JujuError('No current model')
- return models['current-model']
-
- def controllers(self):
- return self._load_yaml('controllers.yaml', 'controllers')
-
- def models(self):
- return self._load_yaml('models.yaml', 'controllers')
-
- def accounts(self):
- return self._load_yaml('accounts.yaml', 'controllers')
-
- def _load_yaml(self, filename, key):
- filepath = os.path.join(self.path, filename)
- with io.open(filepath, 'rt') as f:
- return yaml.safe_load(f)[key]
+ except errors.JujuAPIError as e:
+ if e.error_code != 'redirection required':
+ raise
+ log.info('Controller requested redirect')
+ # Fetch additional redirection information now so that
+ # we can safely close the connection after login
+ # fails.
+ redirect_info = (await self.rpc({
+ "type": "Admin",
+ "request": "RedirectInfo",
+ "version": 3,
+ }))['response']
+ raise errors.JujuRedirectException(redirect_info) from e
-def get_macaroons(controller_name=None):
- """Decode and return macaroons from default ~/.go-cookies
+class _Task:
+ def __init__(self, task, loop):
+ self.stopped = asyncio.Event(loop=loop)
+ self.stopped.set()
+ self.task = task
+ self.loop = loop
- """
- cookie_files = []
- if controller_name:
- cookie_files.append('~/.local/share/juju/cookies/{}.json'.format(
- controller_name))
- cookie_files.append('~/.go-cookies')
- for cookie_file in cookie_files:
- cookie_file = Path(cookie_file).expanduser()
- if cookie_file.exists():
+ def start(self):
+ async def run():
try:
- cookies = json.loads(cookie_file.read_text())
- break
- except (OSError, ValueError):
- log.warn("Couldn't load macaroons from %s", cookie_file)
- return []
- else:
- log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files))
- return []
-
- base64_macaroons = [
- c['Value'] for c in cookies
- if c['Name'].startswith('macaroon-') and c['Value']
- ]
-
- return [
- json.loads(base64.b64decode(value).decode('utf-8'))
- for value in base64_macaroons
- ]
+ return await(self.task())
+ finally:
+ self.stopped.set()
+ self.stopped.clear()
+ self.loop.create_task(run())
+
+
+def _macaroons_for_domain(cookies, domain):
+ '''Return any macaroons from the given cookie jar that
+ apply to the given domain name.'''
+ req = urllib.request.Request('https://' + domain + '/')
+ cookies.add_cookie_header(req)
+ return httpbakery.extract_macaroons(req)