blob: 7457391877e481c1c237e3bdc8c68a8e6cdadd65 [file] [log] [blame]
import base64
import io
import json
import logging
import os
import random
import shlex
import ssl
import string
import subprocess
import weakref
import websockets
from concurrent.futures import CancelledError
from http.client import HTTPSConnection
from pathlib import Path
import asyncio
import yaml
from juju import tag, utils
from juju.client import client
from juju.errors import JujuError, JujuAPIError, JujuConnectionError
from juju.utils import IdQueue
# Module-wide logger; every message from this module goes to the
# "websocket" logger.
log = logging.getLogger(name="websocket")
class Monitor:
    """
    Health tracker for a :class:`Connection`.

    Holds a weak reference to the connection together with the asyncio
    primitives that the connection's receiver and pinger tasks toggle.
    From those, :attr:`status` reports whether the connection is in the
    'error', 'connected', 'disconnecting' or 'disconnected' state.

    Use it to keep up to date on a connection's health and to react when
    the connection fails because of network issues or other unexpected
    circumstances.
    """
    ERROR = 'error'
    CONNECTED = 'connected'
    DISCONNECTING = 'disconnecting'
    DISCONNECTED = 'disconnected'

    def __init__(self, connection):
        loop = connection.loop
        # Weak reference: the monitor must not keep its connection alive.
        self.connection = weakref.ref(connection)
        self.reconnecting = asyncio.Lock(loop=loop)
        self.close_called = asyncio.Event(loop=loop)
        self.receiver_stopped = asyncio.Event(loop=loop)
        self.pinger_stopped = asyncio.Event(loop=loop)
        # Both background tasks start out in the "stopped" state.
        for stopped in (self.receiver_stopped, self.pinger_stopped):
            stopped.set()

    @property
    def status(self):
        """
        Return one of ERROR, CONNECTED, DISCONNECTING or DISCONNECTED.

        The connection only counts as CONNECTED while its receiver task is
        running. The receiver is started only after the websocket is open,
        and the connection is not usable before then.
        """
        conn = self.connection()
        # Either the connection object is gone (someone kept the monitor
        # alive on its own) or there is no websocket (never opened, or
        # cleanly closed).
        if not conn or not conn.ws:
            return self.DISCONNECTED
        # close() has been requested but has not finished yet.
        if self.close_called.is_set():
            return self.DISCONNECTING
        # Without a close() request, a stopped receiver or a closed socket
        # means the connection broke unexpectedly.
        healthy = not self.receiver_stopped.is_set() and conn.ws.open
        return self.CONNECTED if healthy else self.ERROR
class Connection:
    """
    Usage::

        # Connect to an arbitrary api server
        client = await Connection.connect(
            api_endpoint, model_uuid, username, password, cacert)

        # Connect using a controller/model name
        client = await Connection.connect_model('local.local:default')

        # Connect to the currently active model
        client = await Connection.connect_current()

    Note: Any connection method or constructor can accept an optional `loop`
    argument to override the default event loop from `asyncio.get_event_loop`.
    """

    # Sentinel meaning "use MAX_FRAME_SIZE". It is distinct from None, which
    # is passed to websockets unchanged as its max_size.
    DEFAULT_FRAME_SIZE = 'default_frame_size'
    MAX_FRAME_SIZE = 2**22
    "Maximum size for a single frame. Defaults to 4MB."

    def __init__(
            self, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
        # endpoint/cacert hold the values currently in use; _try_endpoint
        # replaces them when following a redirect. _endpoint/_cacert keep
        # the originally requested values so _connect always starts there.
        self.endpoint = endpoint
        self._endpoint = endpoint
        self.uuid = uuid
        if macaroons:
            # Macaroon auth: clear username/password so login() relies on
            # the macaroons.
            self.macaroons = macaroons
            self.username = ''
            self.password = ''
        else:
            self.macaroons = []
            self.username = username
            self.password = password
        self.cacert = cacert
        self._cacert = cacert
        self.loop = loop or asyncio.get_event_loop()

        self.__request_id__ = 0
        self.addr = None
        self.ws = None
        self.facades = {}
        self.messages = IdQueue(loop=self.loop)
        self.monitor = Monitor(connection=self)
        if max_frame_size is self.DEFAULT_FRAME_SIZE:
            max_frame_size = self.MAX_FRAME_SIZE
        self.max_frame_size = max_frame_size

    @property
    def is_open(self):
        """True when the monitor reports the connection as CONNECTED."""
        return self.monitor.status == Monitor.CONNECTED

    def _get_ssl(self, cert=None):
        """Return an SSL context for connecting to the API server as a client.

        The server's certificate chain is always verified. When ``cert``
        (CA certificate data) is given, it is the only trust anchor and
        hostname checking is turned off. Controller endpoints are often
        bare addresses that the controller's certificate may not name, so
        only the chain is checked against the supplied CA. Without
        ``cert``, the system's default CAs and normal hostname checking
        apply.
        """
        # SERVER_AUTH is the purpose for client-side sockets. CLIENT_AUTH
        # would build a server-side context that never verifies the peer.
        context = ssl.create_default_context(
            purpose=ssl.Purpose.SERVER_AUTH, cadata=cert)
        if cert:
            context.check_hostname = False
        return context

    async def open(self):
        """Open the websocket and start the receiver task.

        Connects to the model API when ``self.uuid`` is set, otherwise to the
        controller API.

        :return: self
        """
        if self.uuid:
            url = "wss://{}/model/{}/api".format(self.endpoint, self.uuid)
        else:
            url = "wss://{}/api".format(self.endpoint)

        kw = dict()
        kw['ssl'] = self._get_ssl(self.cacert)
        kw['loop'] = self.loop
        kw['max_size'] = self.max_frame_size
        self.addr = url
        self.ws = await websockets.connect(url, **kw)
        self.loop.create_task(self.receiver())
        self.monitor.receiver_stopped.clear()
        log.info("Driver connected to juju %s", url)
        self.monitor.close_called.clear()
        return self

    async def close(self):
        """Stop the background tasks and close the websocket.

        Does nothing if the websocket is not open. Waits for the pinger and
        receiver to acknowledge the close request before closing the socket.
        """
        if not self.ws:
            return
        self.monitor.close_called.set()
        await self.monitor.pinger_stopped.wait()
        await self.monitor.receiver_stopped.wait()
        try:
            await self.ws.close()
        finally:
            # Drop the socket even if closing it failed. Otherwise the
            # monitor would report DISCONNECTING forever.
            self.ws = None

    async def recv(self, request_id):
        """Wait for and return the message whose 'request-id' is request_id.

        :raises websockets.exceptions.ConnectionClosed: if not connected.
        """
        if not self.is_open:
            raise websockets.exceptions.ConnectionClosed(0, 'websocket closed')
        return await self.messages.get(request_id)

    async def receiver(self):
        """Background task that routes incoming messages by 'request-id'.

        Runs until the connection is no longer open or close() is requested.
        On ConnectionClosed it fails every pending waiter and schedules a
        reconnect. Any other error also fails the waiters and is re-raised.
        """
        try:
            while self.is_open:
                result = await utils.run_with_interrupt(
                    self.ws.recv(),
                    self.monitor.close_called,
                    loop=self.loop)
                if self.monitor.close_called.is_set():
                    break
                if result is not None:
                    result = json.loads(result)
                    await self.messages.put(result['request-id'], result)
        except CancelledError:
            pass
        except websockets.ConnectionClosed as e:
            log.warning('Receiver: Connection closed, reconnecting')
            await self.messages.put_all(e)
            # the reconnect has to be done as a task because the receiver will
            # be cancelled by the reconnect and we don't want the reconnect
            # to be aborted half-way through
            self.loop.create_task(self.reconnect())
            return
        except Exception as e:
            log.exception("Error in receiver")
            # make pending listeners aware of the error
            await self.messages.put_all(e)
            raise
        finally:
            self.monitor.receiver_stopped.set()

    async def pinger(self):
        '''
        A Controller can time us out if we are silent for too long. This
        is especially true in JaaS, which has a fairly strict timeout.

        To prevent timing out, we send a ping every ten seconds.
        '''
        async def _do_ping():
            try:
                await pinger_facade.Ping()
                await asyncio.sleep(10, loop=self.loop)
            except CancelledError:
                pass

        pinger_facade = client.PingerFacade.from_connection(self)
        try:
            while True:
                await utils.run_with_interrupt(
                    _do_ping(),
                    self.monitor.close_called,
                    loop=self.loop)
                if self.monitor.close_called.is_set():
                    break
        finally:
            self.monitor.pinger_stopped.set()

    async def rpc(self, msg, encoder=None):
        """Send ``msg`` as an RPC request and return the decoded response.

        Fills in 'request-id', a default empty 'params' and, if missing,
        'version' taken from the negotiated facade versions. Sending is
        attempted up to three times, reconnecting after each
        ConnectionClosed.

        :param dict msg: request; must contain 'type' when 'version' is absent.
        :param encoder: optional ``json.JSONEncoder`` subclass.
        :raises JujuAPIError: when the reply has a top-level 'error'.
        :raises JujuError: when the response, or any entry in its 'results'
            list, carries an error message.
        """
        self.__request_id__ += 1
        msg['request-id'] = self.__request_id__
        if 'params' not in msg:
            msg['params'] = {}
        if "version" not in msg:
            msg['version'] = self.facades[msg['type']]
        outgoing = json.dumps(msg, indent=2, cls=encoder)
        for attempt in range(3):
            try:
                await self.ws.send(outgoing)
                break
            except websockets.ConnectionClosed:
                if attempt == 2:
                    raise
                log.warning('RPC: Connection closed, reconnecting')
                # the reconnect has to be done in a separate task because,
                # if it is triggered by the pinger, then this RPC call will
                # be cancelled when the pinger is cancelled by the reconnect,
                # and we don't want the reconnect to be aborted halfway through
                await asyncio.wait([self.reconnect()], loop=self.loop)
        result = await self.recv(msg['request-id'])

        if not result:
            return result

        if 'error' in result:
            # API Error Response
            raise JujuAPIError(result)

        if 'response' not in result:
            # This may never happen
            return result

        # 'response', 'results' and per-entry 'error' may each be JSON null;
        # treat null as empty instead of crashing while looking for errors.
        response = result['response'] or {}
        if 'results' in response:
            # Check for errors in a result list.
            errors = [
                res['error']['message']
                for res in response['results'] or []
                if (res.get('error') or {}).get('message')
            ]
            if errors:
                raise JujuError(errors)
        elif (response.get('error') or {}).get('message'):
            raise JujuError(response['error']['message'])
        return result

    def http_headers(self):
        """Return dictionary of http headers necessary for making an http
        connection to the endpoint of this Connection.

        :return: Dictionary of headers (empty when there is no username)
        """
        if not self.username:
            return {}

        creds = u'{}:{}'.format(
            tag.user(self.username),
            self.password or ''
        )
        token = base64.b64encode(creds.encode())
        return {
            'Authorization': 'Basic {}'.format(token.decode())
        }

    def https_connection(self):
        """Return an https connection to this Connection's endpoint.

        Returns a 3-tuple containing::

            1. The :class:`HTTPSConnection` instance
            2. Dictionary of auth headers to be used with the connection
            3. The root url path (str) to be used for requests.

        The endpoint is parsed as ``host[:port][/path]``. Bracketed IPv6
        literals such as ``[::1]:17070`` are supported. Without a port,
        HTTPSConnection's default (443) is used.
        """
        from urllib.parse import urlsplit

        # A plain split on ':' would cut inside an IPv6 address. urlsplit
        # understands brackets and strips them from .hostname.
        parts = urlsplit('//' + self.endpoint)
        conn = HTTPSConnection(
            parts.hostname, parts.port,
            context=self._get_ssl(self.cacert),
        )
        path = (
            "/model/{}".format(self.uuid)
            if self.uuid else ""
        )
        return conn, self.http_headers(), path

    async def clone(self):
        """Return a new Connection, connected to the same websocket endpoint
        as this one.
        """
        return await Connection.connect(
            self.endpoint,
            self.uuid,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
            self.max_frame_size,
        )

    async def controller(self):
        """Return a Connection to the controller at self.endpoint.

        Uses the same credentials, CA certificate, loop and frame-size limit
        as this connection.
        """
        return await Connection.connect(
            self.endpoint,
            None,
            self.username,
            self.password,
            self.cacert,
            self.macaroons,
            self.loop,
            self.max_frame_size,
        )

    async def _try_endpoint(self, endpoint, cacert):
        """Open ``endpoint`` and attempt to log in.

        :return: ``(success, login_result, new_endpoints)``. new_endpoints
            lists the public ``(endpoint, cacert)`` pairs to try next when
            the controller requested a redirect. The connection is closed
            again unless login succeeded.
        :raises JujuAPIError: for login errors other than a redirect, or a
            redirect whose details cannot be fetched.
        """
        success = False
        result = None
        new_endpoints = []

        self.endpoint = endpoint
        self.cacert = cacert
        await self.open()
        try:
            result = await self.login()
            if 'discharge-required-error' in result['response']:
                log.info('Macaroon discharge required, disconnecting')
            else:
                # successful login!
                log.info('Authenticated')
                success = True
        except JujuAPIError as e:
            if e.error_code != 'redirection required':
                raise
            log.info('Controller requested redirect')
            redirect_info = await self.redirect_info()
            if redirect_info is None:
                # redirect_info() reported 'not redirected': there is nothing
                # to follow, so surface the original error.
                raise e
            redir_cacert = redirect_info['ca-cert']
            new_endpoints = [
                ("{value}:{port}".format(**s), redir_cacert)
                for servers in redirect_info['servers']
                for s in servers if s["scope"] == 'public'
            ]
        finally:
            if not success:
                await self.close()
        return success, result, new_endpoints

    async def reconnect(self):
        """ Force a reconnection.

        Does nothing if a reconnect is already in progress or close() has
        been requested.
        """
        monitor = self.monitor
        if monitor.reconnecting.locked() or monitor.close_called.is_set():
            return
        async with monitor.reconnecting:
            await self.close()
            await self._connect()

    async def _connect(self):
        """Log in, starting from the original endpoint and following redirects.

        On success, stores the login response in ``self.info``, records the
        facade versions, and starts the pinger task.

        :raises JujuConnectionError: if no endpoint accepted the login.
        """
        endpoints = [(self._endpoint, self._cacert)]
        while endpoints:
            _endpoint, _cacert = endpoints.pop(0)
            success, result, new_endpoints = await self._try_endpoint(
                _endpoint, _cacert)
            if success:
                break
            endpoints.extend(new_endpoints)
        else:
            # ran out of endpoints without a successful login
            raise JujuConnectionError("Couldn't authenticate to {}".format(
                self._endpoint))

        response = result['response']
        self.info = response.copy()
        self.build_facades(response.get('facades', {}))
        self.loop.create_task(self.pinger())
        self.monitor.pinger_stopped.clear()

    @classmethod
    async def connect(
            cls, endpoint, uuid, username, password, cacert=None,
            macaroons=None, loop=None, max_frame_size=None):
        """Connect to the websocket.

        If uuid is None, the connection will be to the controller. Otherwise it
        will be to the model.

        NOTE(review): max_frame_size defaults to None here, not
        DEFAULT_FRAME_SIZE, so it reaches websockets as max_size=None rather
        than MAX_FRAME_SIZE. Confirm that is intended.
        """
        client = cls(endpoint, uuid, username, password, cacert, macaroons,
                     loop, max_frame_size)
        await client._connect()
        return client

    @classmethod
    async def connect_current(cls, loop=None, max_frame_size=None):
        """Connect to the currently active model.

        :raises JujuConnectionError: if there is no current controller.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        model_name = jujudata.current_model()

        return await cls.connect_model(
            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)

    @classmethod
    async def connect_current_controller(cls, loop=None, max_frame_size=None):
        """Connect to the currently active controller.

        :raises JujuConnectionError: if there is no current controller.
        """
        jujudata = JujuData()
        controller_name = jujudata.current_controller()
        if not controller_name:
            raise JujuConnectionError('No current controller')

        return await cls.connect_controller(controller_name, loop,
                                            max_frame_size)

    @classmethod
    async def connect_controller(cls, controller_name, loop=None,
                                 max_frame_size=None):
        """Connect to a controller by name.

        Uses the controller's first API endpoint. If the account has no
        password, macaroons from the local cookie jar are used instead.
        """
        jujudata = JujuData()
        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        password = accounts.get('password')
        macaroons = get_macaroons(controller_name) if not password else None

        return await cls.connect(
            endpoint, None, username, password, cacert, macaroons, loop,
            max_frame_size)

    @classmethod
    async def connect_model(cls, model, loop=None, max_frame_size=None):
        """Connect to a model by name.

        :param str model: [<controller>:]<model>
        """
        jujudata = JujuData()

        if ':' in model:
            # explicit controller given
            controller_name, model_name = model.split(':')
        else:
            # use the current controller if one isn't explicitly given
            controller_name = jujudata.current_controller()
            model_name = model

        accounts = jujudata.accounts()[controller_name]
        username = accounts['user']
        # model name must include a user prefix, so add it if it doesn't
        if '/' not in model_name:
            model_name = '{}/{}'.format(username, model_name)

        controller = jujudata.controllers()[controller_name]
        endpoint = controller['api-endpoints'][0]
        cacert = controller.get('ca-cert')
        password = accounts.get('password')
        models = jujudata.models()[controller_name]
        model_uuid = models['models'][model_name]['uuid']
        macaroons = get_macaroons(controller_name) if not password else None

        return await cls.connect(
            endpoint, model_uuid, username, password, cacert, macaroons, loop,
            max_frame_size)

    def build_facades(self, facades):
        """Record the newest version of each facade from a login response."""
        self.facades.clear()
        for facade in facades:
            self.facades[facade['name']] = facade['versions'][-1]

    async def login(self):
        """Send an Admin.Login request and return the raw RPC result."""
        username = self.username
        if username and not username.startswith('user-'):
            username = 'user-{}'.format(username)

        result = await self.rpc({
            "type": "Admin",
            "request": "Login",
            "version": 3,
            "params": {
                "auth-tag": username,
                "credentials": self.password,
                # NOTE(review): the nonce comes from `random`, not
                # `secrets`; confirm it does not need to be unpredictable.
                "nonce": "".join(random.sample(string.printable, 12)),
                "macaroons": self.macaroons
            }})
        return result

    async def redirect_info(self):
        """Return the Admin.RedirectInfo response, or None if not redirected."""
        try:
            result = await self.rpc({
                "type": "Admin",
                "request": "RedirectInfo",
                "version": 3,
            })
        except JujuAPIError as e:
            if e.message == 'not redirected':
                return None
            raise
        return result['response']
class JujuData:
    """Access to the local Juju client configuration.

    The data directory is ``$JUJU_DATA`` when that variable is set and
    non-empty, otherwise ``~/.local/share/juju``.
    """

    def __init__(self):
        self.path = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
        self.path = os.path.abspath(os.path.expanduser(self.path))

    def current_controller(self):
        """Return the current controller's name, or '' if there is none.

        Runs ``juju list-controllers --format yaml``, so the juju CLI must be
        installed.
        """
        cmd = shlex.split('juju list-controllers --format yaml')
        output = subprocess.check_output(cmd)
        # An empty YAML document loads as None; treat it as "no controller".
        output = yaml.safe_load(output) or {}
        return output.get('current-controller', '')

    def current_model(self, controller_name=None):
        """Return the current model of ``controller_name``.

        Defaults to the current controller.

        :raises JujuError: if that controller has no current model.
        """
        if not controller_name:
            controller_name = self.current_controller()
        models = self.models()[controller_name]
        if 'current-model' not in models:
            raise JujuError('No current model')
        return models['current-model']

    def controllers(self):
        """Return the 'controllers' section of controllers.yaml."""
        return self._load_yaml('controllers.yaml', 'controllers')

    def models(self):
        """Return the 'controllers' section of models.yaml."""
        return self._load_yaml('models.yaml', 'controllers')

    def accounts(self):
        """Return the 'controllers' section of accounts.yaml."""
        return self._load_yaml('accounts.yaml', 'controllers')

    def _load_yaml(self, filename, key):
        """Load ``filename`` from the data directory and return its ``key``.

        The file is decoded as UTF-8 explicitly (assumed to be what juju
        writes) rather than with the locale's default encoding, which can be
        ASCII and would fail on non-ASCII content.
        """
        filepath = os.path.join(self.path, filename)
        with io.open(filepath, 'rt', encoding='utf-8') as f:
            return yaml.safe_load(f)[key]
def get_macaroons(controller_name=None):
    """Decode and return the macaroons stored in a local cookie file.

    When ``controller_name`` is given,
    ``$JUJU_DATA/cookies/<controller_name>.json`` is tried first. JUJU_DATA
    falls back to ``~/.local/share/juju``, matching :class:`JujuData`.
    ``~/.go-cookies`` is tried after that. Only the first existing file is
    read.

    :param str controller_name: controller whose cookie file to try first.
    :return: list of decoded macaroons. The list is empty if no cookie file
        exists or the chosen one cannot be loaded.
    """
    juju_data = os.environ.get('JUJU_DATA') or '~/.local/share/juju'
    cookie_files = []
    if controller_name:
        cookie_files.append(os.path.join(
            juju_data, 'cookies', '{}.json'.format(controller_name)))
    cookie_files.append('~/.go-cookies')
    for cookie_file in cookie_files:
        cookie_file = Path(cookie_file).expanduser()
        if cookie_file.exists():
            try:
                # A file holding JSON null (or any falsy document) is
                # treated as an empty cookie list instead of crashing below.
                cookies = json.loads(
                    cookie_file.read_text(encoding='utf-8')) or []
                break
            except (OSError, ValueError):
                log.warning("Couldn't load macaroons from %s", cookie_file)
                return []
    else:
        log.warning("Couldn't load macaroons from %s",
                    ' or '.join(cookie_files))
        return []

    base64_macaroons = [
        c['Value'] for c in cookies
        if c['Name'].startswith('macaroon-') and c['Value']
    ]

    return [
        json.loads(base64.b64decode(value).decode('utf-8'))
        for value in base64_macaroons
    ]