Configurable and larger max message size (#146)
[osm/N2VC.git] / juju / client / connection.py
index 6851707..6f2f2a2 100644 (file)
@@ -111,9 +111,14 @@ class Connection:
     Note: Any connection method or constructor can accept an optional `loop`
     argument to override the default event loop from `asyncio.get_event_loop`.
     """
+
+    DEFAULT_FRAME_SIZE = 'default_frame_size'
+    MAX_FRAME_SIZE = 2**22
+    "Maximum size for a single frame.  Defaults to 4MB."
+
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None):
+            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
         self.endpoint = endpoint
         self.uuid = uuid
         if macaroons:
@@ -133,6 +138,9 @@ class Connection:
         self.facades = {}
         self.messages = IdQueue(loop=self.loop)
         self.monitor = Monitor(connection=self)
+        if max_frame_size is self.DEFAULT_FRAME_SIZE:
+            max_frame_size = self.MAX_FRAME_SIZE
+        self.max_frame_size = max_frame_size
 
     @property
     def is_open(self):
@@ -153,6 +161,7 @@ class Connection:
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
         kw['loop'] = self.loop
+        kw['max_size'] = self.max_frame_size
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
         self.loop.create_task(self.receiver())
@@ -321,6 +330,7 @@ class Connection:
             self.cacert,
             self.macaroons,
             self.loop,
+            self.max_frame_size,
         )
 
     async def controller(self):
@@ -372,7 +382,7 @@ class Connection:
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None):
+            macaroons=None, loop=None, max_frame_size=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
@@ -380,7 +390,7 @@ class Connection:
 
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
-                     loop)
+                     loop, max_frame_size)
         endpoints = [(endpoint, cacert)]
         while endpoints:
             _endpoint, _cacert = endpoints.pop(0)
@@ -402,7 +412,7 @@ class Connection:
         return client
 
     @classmethod
-    async def connect_current(cls, loop=None):
+    async def connect_current(cls, loop=None, max_frame_size=None):
         """Connect to the currently active model.
 
         """
@@ -415,10 +425,10 @@ class Connection:
         model_name = jujudata.current_model()
 
         return await cls.connect_model(
-            '{}:{}'.format(controller_name, model_name), loop)
+            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
 
     @classmethod
-    async def connect_current_controller(cls, loop=None):
+    async def connect_current_controller(cls, loop=None, max_frame_size=None):
         """Connect to the currently active controller.
 
         """
@@ -427,10 +437,12 @@ class Connection:
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
-        return await cls.connect_controller(controller_name, loop)
+        return await cls.connect_controller(controller_name, loop,
+                                            max_frame_size)
 
     @classmethod
-    async def connect_controller(cls, controller_name, loop=None):
+    async def connect_controller(cls, controller_name, loop=None,
+                                 max_frame_size=None):
         """Connect to a controller by name.
 
         """
@@ -444,10 +456,11 @@ class Connection:
         macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
-            endpoint, None, username, password, cacert, macaroons, loop)
+            endpoint, None, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
     @classmethod
-    async def connect_model(cls, model, loop=None):
+    async def connect_model(cls, model, loop=None, max_frame_size=None):
         """Connect to a model by name.
 
         :param str model: [<controller>:]<model>
@@ -478,7 +491,8 @@ class Connection:
         macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert, macaroons, loop)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
     def build_facades(self, facades):
         self.facades.clear()