import ssl
import string
import subprocess
+import weakref
import websockets
+from concurrent.futures import CancelledError
+from http.client import HTTPSConnection
+from pathlib import Path
+import asyncio
import yaml
-from juju.errors import JujuAPIError
+from juju import tag, utils
+from juju.client import client
+from juju.errors import JujuError, JujuAPIError, JujuConnectionError
+from juju.utils import IdQueue
log = logging.getLogger("websocket")
+class Monitor:
+ """
+ Monitor helper class for our Connection class.
+
+ Contains a reference to an instantiated Connection, along with a
+ reference to the Connection.receiver Future. Upon inspecttion of
+ these objects, this class determines whether the connection is in
+ an 'error', 'connected' or 'disconnected' state.
+
+ Use this class to stay up to date on the health of a connection,
+ and take appropriate action if the connection errors out due to
+ network issues or other unexpected circumstances.
+
+ """
+ ERROR = 'error'
+ CONNECTED = 'connected'
+ DISCONNECTING = 'disconnecting'
+ DISCONNECTED = 'disconnected'
+
+ def __init__(self, connection):
+ self.connection = weakref.ref(connection)
+ self.reconnecting = asyncio.Lock(loop=connection.loop)
+ self.close_called = asyncio.Event(loop=connection.loop)
+ self.receiver_stopped = asyncio.Event(loop=connection.loop)
+ self.pinger_stopped = asyncio.Event(loop=connection.loop)
+ self.receiver_stopped.set()
+ self.pinger_stopped.set()
+
+ @property
+ def status(self):
+ """
+ Determine the status of the connection and receiver, and return
+ ERROR, CONNECTED, or DISCONNECTED as appropriate.
+
+ For simplicity, we only consider ourselves to be connected
+ after the Connection class has setup a receiver task. This
+ only happens after the websocket is open, and the connection
+ isn't usable until that receiver has been started.
+
+ """
+ connection = self.connection()
+
+ # the connection instance was destroyed but someone kept
+ # a separate reference to the monitor for some reason
+ if not connection:
+ return self.DISCONNECTED
+
+ # connection cleanly disconnected or not yet opened
+ if not connection.ws:
+ return self.DISCONNECTED
+
+ # close called but not yet complete
+ if self.close_called.is_set():
+ return self.DISCONNECTING
+
+ # connection closed uncleanly (we didn't call connection.close)
+ if self.receiver_stopped.is_set() or not connection.ws.open:
+ return self.ERROR
+
+ # everything is fine!
+ return self.CONNECTED
+
+
class Connection:
"""
Usage::
# Connect to the currently active model
client = await Connection.connect_current()
+ Note: Any connection method or constructor can accept an optional `loop`
+ argument to override the default event loop from `asyncio.get_event_loop`.
"""
+
+ DEFAULT_FRAME_SIZE = 'default_frame_size'
+ MAX_FRAME_SIZE = 2**22
+ "Maximum size for a single frame. Defaults to 4MB."
+
def __init__(
self, endpoint, uuid, username, password, cacert=None,
- macaroons=None):
+ macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
self.endpoint = endpoint
+ self._endpoint = endpoint
self.uuid = uuid
- self.username = username
- self.password = password
- self.macaroons = macaroons
+ if macaroons:
+ self.macaroons = macaroons
+ self.username = ''
+ self.password = ''
+ else:
+ self.macaroons = []
+ self.username = username
+ self.password = password
self.cacert = cacert
+ self._cacert = cacert
+ self.loop = loop or asyncio.get_event_loop()
self.__request_id__ = 0
self.addr = None
self.ws = None
self.facades = {}
+ self.messages = IdQueue(loop=self.loop)
+ self.monitor = Monitor(connection=self)
+ if max_frame_size is self.DEFAULT_FRAME_SIZE:
+ max_frame_size = self.MAX_FRAME_SIZE
+ self.max_frame_size = max_frame_size
@property
def is_open(self):
- if self.ws:
- return self.ws.open
- return False
+ return self.monitor.status == Monitor.CONNECTED
def _get_ssl(self, cert=None):
return ssl.create_default_context(
kw = dict()
kw['ssl'] = self._get_ssl(self.cacert)
+ kw['loop'] = self.loop
+ kw['max_size'] = self.max_frame_size
self.addr = url
self.ws = await websockets.connect(url, **kw)
+ self.loop.create_task(self.receiver())
+ self.monitor.receiver_stopped.clear()
log.info("Driver connected to juju %s", url)
+ self.monitor.close_called.clear()
return self
async def close(self):
+ if not self.ws:
+ return
+ self.monitor.close_called.set()
+ await self.monitor.pinger_stopped.wait()
+ await self.monitor.receiver_stopped.wait()
await self.ws.close()
+ self.ws = None
- async def recv(self):
- result = await self.ws.recv()
- if result is not None:
- result = json.loads(result)
- return result
+ async def recv(self, request_id):
+ if not self.is_open:
+ raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
+ return await self.messages.get(request_id)
+
+ async def receiver(self):
+ try:
+ while self.is_open:
+ result = await utils.run_with_interrupt(
+ self.ws.recv(),
+ self.monitor.close_called,
+ loop=self.loop)
+ if self.monitor.close_called.is_set():
+ break
+ if result is not None:
+ result = json.loads(result)
+ await self.messages.put(result['request-id'], result)
+ except CancelledError:
+ pass
+ except websockets.ConnectionClosed as e:
+ log.warning('Receiver: Connection closed, reconnecting')
+ await self.messages.put_all(e)
+ # the reconnect has to be done as a task because the receiver will
+ # be cancelled by the reconnect and we don't want the reconnect
+ # to be aborted half-way through
+ self.loop.create_task(self.reconnect())
+ return
+ except Exception as e:
+ log.exception("Error in receiver")
+ # make pending listeners aware of the error
+ await self.messages.put_all(e)
+ raise
+ finally:
+ self.monitor.receiver_stopped.set()
+
+ async def pinger(self):
+ '''
+ A Controller can time us out if we are silent for too long. This
+ is especially true in JaaS, which has a fairly strict timeout.
+
+ To prevent timing out, we send a ping every ten seconds.
+
+ '''
+ async def _do_ping():
+ try:
+ await pinger_facade.Ping()
+ await asyncio.sleep(10, loop=self.loop)
+ except CancelledError:
+ pass
+
+ pinger_facade = client.PingerFacade.from_connection(self)
+ try:
+ while True:
+ await utils.run_with_interrupt(
+ _do_ping(),
+ self.monitor.close_called,
+ loop=self.loop)
+ if self.monitor.close_called.is_set():
+ break
+ finally:
+ self.monitor.pinger_stopped.set()
+ return
async def rpc(self, msg, encoder=None):
self.__request_id__ += 1
if "version" not in msg:
msg['version'] = self.facades[msg['type']]
outgoing = json.dumps(msg, indent=2, cls=encoder)
- await self.ws.send(outgoing)
- result = await self.recv()
- if result and 'error' in result:
+ for attempt in range(3):
+ try:
+ await self.ws.send(outgoing)
+ break
+ except websockets.ConnectionClosed:
+ if attempt == 2:
+ raise
+ log.warning('RPC: Connection closed, reconnecting')
+ # the reconnect has to be done in a separate task because,
+ # if it is triggered by the pinger, then this RPC call will
+ # be cancelled when the pinger is cancelled by the reconnect,
+ # and we don't want the reconnect to be aborted halfway through
+ await asyncio.wait([self.reconnect()], loop=self.loop)
+ result = await self.recv(msg['request-id'])
+
+ if not result:
+ return result
+
+ if 'error' in result:
+ # API Error Response
raise JujuAPIError(result)
+
+ if 'response' not in result:
+ # This may never happen
+ return result
+
+ if 'results' in result['response']:
+ # Check for errors in a result list.
+ errors = []
+ for res in result['response']['results']:
+ if res.get('error', {}).get('message'):
+ errors.append(res['error']['message'])
+ if errors:
+ raise JujuError(errors)
+
+ elif result['response'].get('error', {}).get('message'):
+ raise JujuError(result['response']['error']['message'])
+
return result
+ def http_headers(self):
+ """Return dictionary of http headers necessary for making an http
+ connection to the endpoint of this Connection.
+
+ :return: Dictionary of headers
+
+ """
+ if not self.username:
+ return {}
+
+ creds = u'{}:{}'.format(
+ tag.user(self.username),
+ self.password or ''
+ )
+ token = base64.b64encode(creds.encode())
+ return {
+ 'Authorization': 'Basic {}'.format(token.decode())
+ }
+
+ def https_connection(self):
+ """Return an https connection to this Connection's endpoint.
+
+ Returns a 3-tuple containing::
+
+ 1. The :class:`HTTPSConnection` instance
+ 2. Dictionary of auth headers to be used with the connection
+ 3. The root url path (str) to be used for requests.
+
+ """
+ endpoint = self.endpoint
+ host, remainder = endpoint.split(':', 1)
+ port = remainder
+ if '/' in remainder:
+ port, _ = remainder.split('/', 1)
+
+ conn = HTTPSConnection(
+ host, int(port),
+ context=self._get_ssl(self.cacert),
+ )
+
+ path = (
+ "/model/{}".format(self.uuid)
+ if self.uuid else ""
+ )
+ return conn, self.http_headers(), path
+
async def clone(self):
"""Return a new Connection, connected to the same websocket endpoint
as this one.
self.password,
self.cacert,
self.macaroons,
+ self.loop,
+ self.max_frame_size,
)
async def controller(self):
self.password,
self.cacert,
self.macaroons,
+ self.loop,
)
+ async def _try_endpoint(self, endpoint, cacert):
+ success = False
+ result = None
+ new_endpoints = []
+
+ self.endpoint = endpoint
+ self.cacert = cacert
+ await self.open()
+ try:
+ result = await self.login()
+ if 'discharge-required-error' in result['response']:
+ log.info('Macaroon discharge required, disconnecting')
+ else:
+ # successful login!
+ log.info('Authenticated')
+ success = True
+ except JujuAPIError as e:
+ if e.error_code != 'redirection required':
+ raise
+ log.info('Controller requested redirect')
+ redirect_info = await self.redirect_info()
+ redir_cacert = redirect_info['ca-cert']
+ new_endpoints = [
+ ("{value}:{port}".format(**s), redir_cacert)
+ for servers in redirect_info['servers']
+ for s in servers if s["scope"] == 'public'
+ ]
+ finally:
+ if not success:
+ await self.close()
+ return success, result, new_endpoints
+
+ async def reconnect(self):
+ """ Force a reconnection.
+ """
+ monitor = self.monitor
+ if monitor.reconnecting.locked() or monitor.close_called.is_set():
+ return
+ async with monitor.reconnecting:
+ await self.close()
+ await self._connect()
+
+ async def _connect(self):
+ endpoints = [(self._endpoint, self._cacert)]
+ while endpoints:
+ _endpoint, _cacert = endpoints.pop(0)
+ success, result, new_endpoints = await self._try_endpoint(
+ _endpoint, _cacert)
+ if success:
+ break
+ endpoints.extend(new_endpoints)
+ else:
+ # ran out of endpoints without a successful login
+ raise Exception("Couldn't authenticate to {}".format(
+ self._endpoint))
+
+ response = result['response']
+ self.info = response.copy()
+ self.build_facades(response.get('facades', {}))
+ self.loop.create_task(self.pinger())
+ self.monitor.pinger_stopped.clear()
+
@classmethod
async def connect(
cls, endpoint, uuid, username, password, cacert=None,
- macaroons=None):
+ macaroons=None, loop=None, max_frame_size=None):
"""Connect to the websocket.
If uuid is None, the connection will be to the controller. Otherwise it
will be to the model.
"""
- client = cls(endpoint, uuid, username, password, cacert, macaroons)
- await client.open()
-
- redirect_info = await client.redirect_info()
- if not redirect_info:
- server_info = await client.login(username, password, macaroons)
- client.build_facades(server_info['facades'])
- return client
-
- await client.close()
- servers = [
- s for servers in redirect_info['servers']
- for s in servers if s["scope"] == 'public'
- ]
- for server in servers:
- client = cls(
- "{value}:{port}".format(**server), uuid, username,
- password, redirect_info['ca-cert'], macaroons)
- await client.open()
- try:
- result = await client.login(username, password, macaroons)
- if 'discharge-required-error' in result:
- continue
- client.build_facades(result['facades'])
- return client
- except Exception as e:
- await client.close()
- log.exception(e)
-
- raise Exception(
- "Couldn't authenticate to %s", endpoint)
+ client = cls(endpoint, uuid, username, password, cacert, macaroons,
+ loop, max_frame_size)
+ await client._connect()
+ return client
@classmethod
- async def connect_current(cls):
+ async def connect_current(cls, loop=None, max_frame_size=None):
"""Connect to the currently active model.
"""
jujudata = JujuData()
+
controller_name = jujudata.current_controller()
- models = jujudata.models()[controller_name]
- model_name = models['current-model']
+ if not controller_name:
+ raise JujuConnectionError('No current controller')
+
+ model_name = jujudata.current_model()
return await cls.connect_model(
- '{}:{}'.format(controller_name, model_name))
+ '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
@classmethod
- async def connect_model(cls, model):
- """Connect to a model by name.
-
- :param str model: <controller>:<model>
+ async def connect_current_controller(cls, loop=None, max_frame_size=None):
+ """Connect to the currently active controller.
"""
- controller_name, model_name = model.split(':')
+ jujudata = JujuData()
+ controller_name = jujudata.current_controller()
+ if not controller_name:
+ raise JujuConnectionError('No current controller')
+
+ return await cls.connect_controller(controller_name, loop,
+ max_frame_size)
+
+ @classmethod
+ async def connect_controller(cls, controller_name, loop=None,
+ max_frame_size=None):
+ """Connect to a controller by name.
+ """
jujudata = JujuData()
controller = jujudata.controllers()[controller_name]
endpoint = controller['api-endpoints'][0]
accounts = jujudata.accounts()[controller_name]
username = accounts['user']
password = accounts.get('password')
+ macaroons = get_macaroons(controller_name) if not password else None
+
+ return await cls.connect(
+ endpoint, None, username, password, cacert, macaroons, loop,
+ max_frame_size)
+
+ @classmethod
+ async def connect_model(cls, model, loop=None, max_frame_size=None):
+ """Connect to a model by name.
+
+ :param str model: [<controller>:]<model>
+
+ """
+ jujudata = JujuData()
+
+ if ':' in model:
+ # explicit controller given
+ controller_name, model_name = model.split(':')
+ else:
+ # use the current controller if one isn't explicitly given
+ controller_name = jujudata.current_controller()
+ model_name = model
+
+ accounts = jujudata.accounts()[controller_name]
+ username = accounts['user']
+ # model name must include a user prefix, so add it if it doesn't
+ if '/' not in model_name:
+ model_name = '{}/{}'.format(username, model_name)
+
+ controller = jujudata.controllers()[controller_name]
+ endpoint = controller['api-endpoints'][0]
+ cacert = controller.get('ca-cert')
+ password = accounts.get('password')
models = jujudata.models()[controller_name]
model_uuid = models['models'][model_name]['uuid']
- macaroons = get_macaroons() if not password else None
+ macaroons = get_macaroons(controller_name) if not password else None
return await cls.connect(
- endpoint, model_uuid, username, password, cacert, macaroons)
+ endpoint, model_uuid, username, password, cacert, macaroons, loop,
+ max_frame_size)
- def build_facades(self, info):
+ def build_facades(self, facades):
self.facades.clear()
- for facade in info:
+ for facade in facades:
self.facades[facade['name']] = facade['versions'][-1]
- async def login(self, username, password, macaroons=None):
- if macaroons:
- username = ''
- password = ''
-
+ async def login(self):
+ username = self.username
if username and not username.startswith('user-'):
username = 'user-{}'.format(username)
"version": 3,
"params": {
"auth-tag": username,
- "credentials": password,
+ "credentials": self.password,
"nonce": "".join(random.sample(string.printable, 12)),
- "macaroons": macaroons or []
+ "macaroons": self.macaroons
}})
- return result['response']
+ return result
async def redirect_info(self):
try:
self.path = os.path.abspath(os.path.expanduser(self.path))
def current_controller(self):
- cmd = shlex.split('juju show-controller --format yaml')
+ cmd = shlex.split('juju list-controllers --format yaml')
output = subprocess.check_output(cmd)
output = yaml.safe_load(output)
- return list(output.keys())[0]
+ return output.get('current-controller', '')
+
+ def current_model(self, controller_name=None):
+ if not controller_name:
+ controller_name = self.current_controller()
+ models = self.models()[controller_name]
+ if 'current-model' not in models:
+ raise JujuError('No current model')
+ return models['current-model']
def controllers(self):
return self._load_yaml('controllers.yaml', 'controllers')
return yaml.safe_load(f)[key]
-def get_macaroons():
+def get_macaroons(controller_name=None):
"""Decode and return macaroons from default ~/.go-cookies
"""
- try:
- cookie_file = os.path.expanduser('~/.go-cookies')
- with open(cookie_file, 'r') as f:
- cookies = json.load(f)
- except (OSError, ValueError) as e:
- log.warn("Couldn't load macaroons from %s", cookie_file)
+ cookie_files = []
+ if controller_name:
+ cookie_files.append('~/.local/share/juju/cookies/{}.json'.format(
+ controller_name))
+ cookie_files.append('~/.go-cookies')
+ for cookie_file in cookie_files:
+ cookie_file = Path(cookie_file).expanduser()
+ if cookie_file.exists():
+ try:
+ cookies = json.loads(cookie_file.read_text())
+ break
+ except (OSError, ValueError):
+ log.warn("Couldn't load macaroons from %s", cookie_file)
+ return []
+ else:
+ log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files))
return []
base64_macaroons = [