import ssl
import string
import subprocess
+import weakref
import websockets
from concurrent.futures import CancelledError
from http.client import HTTPSConnection
"""
ERROR = 'error'
CONNECTED = 'connected'
+ DISCONNECTING = 'disconnecting'
DISCONNECTED = 'disconnected'
- UNKNOWN = 'unknown'
def __init__(self, connection):
- self.connection = connection
- self.close_called = asyncio.Event(loop=self.connection.loop)
- self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
- self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
+ self.connection = weakref.ref(connection)
+ self.reconnecting = asyncio.Lock(loop=connection.loop)
+ self.close_called = asyncio.Event(loop=connection.loop)
+ self.receiver_stopped = asyncio.Event(loop=connection.loop)
+ self.pinger_stopped = asyncio.Event(loop=connection.loop)
self.receiver_stopped.set()
self.pinger_stopped.set()
isn't usable until that receiver has been started.
"""
+ connection = self.connection()
- # DISCONNECTED: connection not yet open
- if not self.connection.ws:
+ # the connection instance was destroyed but someone kept
+ # a separate reference to the monitor for some reason
+ if not connection:
return self.DISCONNECTED
- if self.receiver_stopped.is_set():
- return self.DISCONNECTED
-
- # ERROR: Connection closed (or errored), but we didn't call
- # connection.close
- if not self.close_called.is_set() and self.receiver_stopped.is_set():
- return self.ERROR
- if not self.close_called.is_set() and not self.connection.ws.open:
- # The check for self.receiver_stopped existing above guards
- # against the case where we're not open because we simply
- # haven't setup the connection yet.
- return self.ERROR
- # DISCONNECTED: cleanly disconnected.
- if self.close_called.is_set() and not self.connection.ws.open:
+ # connection cleanly disconnected or not yet opened
+ if not connection.ws:
return self.DISCONNECTED
- # CONNECTED: everything is fine!
- if self.connection.ws.open:
- return self.CONNECTED
+ # close called but not yet complete
+ if self.close_called.is_set():
+ return self.DISCONNECTING
+
+ # connection closed uncleanly (we didn't call connection.close)
+ if self.receiver_stopped.is_set() or not connection.ws.open:
+ return self.ERROR
- # UNKNOWN: We should never hit this state -- if we do,
- # something went wrong with the logic above, and we do not
- # know what state the connection is in.
- return self.UNKNOWN
+ # everything is fine!
+ return self.CONNECTED
class Connection:
Note: Any connection method or constructor can accept an optional `loop`
argument to override the default event loop from `asyncio.get_event_loop`.
"""
+
+ DEFAULT_FRAME_SIZE = 'default_frame_size'
+ MAX_FRAME_SIZE = 2**22
+ "Maximum size for a single frame. Defaults to 4MB."
+
def __init__(
self, endpoint, uuid, username, password, cacert=None,
- macaroons=None, loop=None):
+ macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
self.endpoint = endpoint
+ self._endpoint = endpoint
self.uuid = uuid
if macaroons:
self.macaroons = macaroons
self.username = username
self.password = password
self.cacert = cacert
+ self._cacert = cacert
self.loop = loop or asyncio.get_event_loop()
self.__request_id__ = 0
self.facades = {}
self.messages = IdQueue(loop=self.loop)
self.monitor = Monitor(connection=self)
+ if max_frame_size is self.DEFAULT_FRAME_SIZE:
+ max_frame_size = self.MAX_FRAME_SIZE
+ self.max_frame_size = max_frame_size
@property
def is_open(self):
- if self.ws:
- return self.ws.open
- return False
+ return self.monitor.status == Monitor.CONNECTED
def _get_ssl(self, cert=None):
return ssl.create_default_context(
kw = dict()
kw['ssl'] = self._get_ssl(self.cacert)
kw['loop'] = self.loop
+ kw['max_size'] = self.max_frame_size
self.addr = url
self.ws = await websockets.connect(url, **kw)
self.loop.create_task(self.receiver())
return self
async def close(self):
- if not self.is_open:
+ if not self.ws:
return
self.monitor.close_called.set()
await self.monitor.pinger_stopped.wait()
await self.monitor.receiver_stopped.wait()
await self.ws.close()
+ self.ws = None
async def recv(self, request_id):
if not self.is_open:
await self.messages.put(result['request-id'], result)
except CancelledError:
pass
- except Exception as e:
+ except websockets.ConnectionClosed as e:
+ log.warning('Receiver: Connection closed, reconnecting')
await self.messages.put_all(e)
- if isinstance(e, websockets.ConnectionClosed):
- # ConnectionClosed is not really exceptional for us,
- # but it may be for any pending message listeners
- return
+ # the reconnect has to be done as a task because the receiver will
+ # be cancelled by the reconnect and we don't want the reconnect
+ # to be aborted half-way through
+ self.loop.create_task(self.reconnect())
+ return
+ except Exception as e:
log.exception("Error in receiver")
+ # make pending listeners aware of the error
+ await self.messages.put_all(e)
raise
finally:
self.monitor.receiver_stopped.set()
pinger_facade = client.PingerFacade.from_connection(self)
try:
- while self.is_open:
+ while True:
await utils.run_with_interrupt(
_do_ping(),
self.monitor.close_called,
break
finally:
self.monitor.pinger_stopped.set()
+ return
async def rpc(self, msg, encoder=None):
self.__request_id__ += 1
if "version" not in msg:
msg['version'] = self.facades[msg['type']]
outgoing = json.dumps(msg, indent=2, cls=encoder)
- await self.ws.send(outgoing)
+ for attempt in range(3):
+ try:
+ await self.ws.send(outgoing)
+ break
+ except websockets.ConnectionClosed:
+ if attempt == 2:
+ raise
+ log.warning('RPC: Connection closed, reconnecting')
+ # the reconnect has to be done in a separate task because,
+ # if it is triggered by the pinger, then this RPC call will
+ # be cancelled when the pinger is cancelled by the reconnect,
+ # and we don't want the reconnect to be aborted halfway through
+ await asyncio.wait([self.reconnect()], loop=self.loop)
result = await self.recv(msg['request-id'])
if not result:
self.cacert,
self.macaroons,
self.loop,
+ self.max_frame_size,
)
async def controller(self):
await self.close()
return success, result, new_endpoints
- @classmethod
- async def connect(
- cls, endpoint, uuid, username, password, cacert=None,
- macaroons=None, loop=None):
- """Connect to the websocket.
-
- If uuid is None, the connection will be to the controller. Otherwise it
- will be to the model.
-
+ async def reconnect(self):
+ """ Force a reconnection.
"""
- client = cls(endpoint, uuid, username, password, cacert, macaroons,
- loop)
- endpoints = [(endpoint, cacert)]
+ monitor = self.monitor
+ if monitor.reconnecting.locked() or monitor.close_called.is_set():
+ return
+ async with monitor.reconnecting:
+ await self.close()
+ await self._connect()
+
+ async def _connect(self):
+ endpoints = [(self._endpoint, self._cacert)]
while endpoints:
_endpoint, _cacert = endpoints.pop(0)
- success, result, new_endpoints = await client._try_endpoint(
+ success, result, new_endpoints = await self._try_endpoint(
_endpoint, _cacert)
if success:
break
endpoints.extend(new_endpoints)
else:
# ran out of endpoints without a successful login
- raise Exception("Couldn't authenticate to {}".format(endpoint))
+ raise Exception("Couldn't authenticate to {}".format(
+ self._endpoint))
response = result['response']
- client.info = response.copy()
- client.build_facades(response.get('facades', {}))
- client.loop.create_task(client.pinger())
- client.monitor.pinger_stopped.clear()
+ self.info = response.copy()
+ self.build_facades(response.get('facades', {}))
+ self.loop.create_task(self.pinger())
+ self.monitor.pinger_stopped.clear()
+ @classmethod
+ async def connect(
+ cls, endpoint, uuid, username, password, cacert=None,
+ macaroons=None, loop=None, max_frame_size=None):
+ """Connect to the websocket.
+
+ If uuid is None, the connection will be to the controller. Otherwise it
+ will be to the model.
+
+ """
+ client = cls(endpoint, uuid, username, password, cacert, macaroons,
+ loop, max_frame_size)
+ await client._connect()
return client
@classmethod
- async def connect_current(cls, loop=None):
+ async def connect_current(cls, loop=None, max_frame_size=None):
"""Connect to the currently active model.
"""
jujudata = JujuData()
+
controller_name = jujudata.current_controller()
+ if not controller_name:
+ raise JujuConnectionError('No current controller')
+
model_name = jujudata.current_model()
return await cls.connect_model(
- '{}:{}'.format(controller_name, model_name), loop)
+ '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
@classmethod
- async def connect_current_controller(cls, loop=None):
+ async def connect_current_controller(cls, loop=None, max_frame_size=None):
"""Connect to the currently active controller.
"""
if not controller_name:
raise JujuConnectionError('No current controller')
- return await cls.connect_controller(controller_name, loop)
+ return await cls.connect_controller(controller_name, loop,
+ max_frame_size)
@classmethod
- async def connect_controller(cls, controller_name, loop=None):
+ async def connect_controller(cls, controller_name, loop=None,
+ max_frame_size=None):
"""Connect to a controller by name.
"""
macaroons = get_macaroons(controller_name) if not password else None
return await cls.connect(
- endpoint, None, username, password, cacert, macaroons, loop)
+ endpoint, None, username, password, cacert, macaroons, loop,
+ max_frame_size)
@classmethod
- async def connect_model(cls, model, loop=None):
+ async def connect_model(cls, model, loop=None, max_frame_size=None):
"""Connect to a model by name.
:param str model: [<controller>:]<model>
macaroons = get_macaroons(controller_name) if not password else None
return await cls.connect(
- endpoint, model_uuid, username, password, cacert, macaroons, loop)
+ endpoint, model_uuid, username, password, cacert, macaroons, loop,
+ max_frame_size)
def build_facades(self, facades):
self.facades.clear()