Configurable and larger max message size (#146)
[osm/N2VC.git] / juju / client / connection.py
index c2c6b2d..6f2f2a2 100644 (file)
@@ -11,13 +11,13 @@ import subprocess
 import websockets
 from concurrent.futures import CancelledError
 from http.client import HTTPSConnection
 import websockets
 from concurrent.futures import CancelledError
 from http.client import HTTPSConnection
+from pathlib import Path
 
 import asyncio
 import yaml
 
 from juju import tag, utils
 from juju.client import client
 
 import asyncio
 import yaml
 
 from juju import tag, utils
 from juju.client import client
-from juju.client.version_map import VERSION_MAP
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
 from juju.utils import IdQueue
 
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
 from juju.utils import IdQueue
 
@@ -111,9 +111,14 @@ class Connection:
     Note: Any connection method or constructor can accept an optional `loop`
     argument to override the default event loop from `asyncio.get_event_loop`.
     """
     Note: Any connection method or constructor can accept an optional `loop`
     argument to override the default event loop from `asyncio.get_event_loop`.
     """
+
+    DEFAULT_FRAME_SIZE = 'default_frame_size'
+    MAX_FRAME_SIZE = 2**22
+    "Maximum size for a single frame.  Defaults to 4MB."
+
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None):
+            macaroons=None, loop=None, max_frame_size=DEFAULT_FRAME_SIZE):
         self.endpoint = endpoint
         self.uuid = uuid
         if macaroons:
         self.endpoint = endpoint
         self.uuid = uuid
         if macaroons:
@@ -133,6 +138,9 @@ class Connection:
         self.facades = {}
         self.messages = IdQueue(loop=self.loop)
         self.monitor = Monitor(connection=self)
         self.facades = {}
         self.messages = IdQueue(loop=self.loop)
         self.monitor = Monitor(connection=self)
+        if max_frame_size is self.DEFAULT_FRAME_SIZE:
+            max_frame_size = self.MAX_FRAME_SIZE
+        self.max_frame_size = max_frame_size
 
     @property
     def is_open(self):
 
     @property
     def is_open(self):
@@ -153,6 +161,7 @@ class Connection:
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
         kw['loop'] = self.loop
         kw = dict()
         kw['ssl'] = self._get_ssl(self.cacert)
         kw['loop'] = self.loop
+        kw['max_size'] = self.max_frame_size
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
         self.loop.create_task(self.receiver())
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
         self.loop.create_task(self.receiver())
@@ -321,6 +330,7 @@ class Connection:
             self.cacert,
             self.macaroons,
             self.loop,
             self.cacert,
             self.macaroons,
             self.loop,
+            self.max_frame_size,
         )
 
     async def controller(self):
         )
 
     async def controller(self):
@@ -372,7 +382,7 @@ class Connection:
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
     @classmethod
     async def connect(
             cls, endpoint, uuid, username, password, cacert=None,
-            macaroons=None, loop=None):
+            macaroons=None, loop=None, max_frame_size=None):
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
         """Connect to the websocket.
 
         If uuid is None, the connection will be to the controller. Otherwise it
@@ -380,7 +390,7 @@ class Connection:
 
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
 
         """
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
-                     loop)
+                     loop, max_frame_size)
         endpoints = [(endpoint, cacert)]
         while endpoints:
             _endpoint, _cacert = endpoints.pop(0)
         endpoints = [(endpoint, cacert)]
         while endpoints:
             _endpoint, _cacert = endpoints.pop(0)
@@ -402,19 +412,23 @@ class Connection:
         return client
 
     @classmethod
         return client
 
     @classmethod
-    async def connect_current(cls, loop=None):
+    async def connect_current(cls, loop=None, max_frame_size=None):
         """Connect to the currently active model.
 
         """
         jujudata = JujuData()
         """Connect to the currently active model.
 
         """
         jujudata = JujuData()
+
         controller_name = jujudata.current_controller()
         controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
         model_name = jujudata.current_model()
 
         return await cls.connect_model(
         model_name = jujudata.current_model()
 
         return await cls.connect_model(
-            '{}:{}'.format(controller_name, model_name), loop)
+            '{}:{}'.format(controller_name, model_name), loop, max_frame_size)
 
     @classmethod
 
     @classmethod
-    async def connect_current_controller(cls, loop=None):
+    async def connect_current_controller(cls, loop=None, max_frame_size=None):
         """Connect to the currently active controller.
 
         """
         """Connect to the currently active controller.
 
         """
@@ -423,10 +437,12 @@ class Connection:
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
         if not controller_name:
             raise JujuConnectionError('No current controller')
 
-        return await cls.connect_controller(controller_name, loop)
+        return await cls.connect_controller(controller_name, loop,
+                                            max_frame_size)
 
     @classmethod
 
     @classmethod
-    async def connect_controller(cls, controller_name, loop=None):
+    async def connect_controller(cls, controller_name, loop=None,
+                                 max_frame_size=None):
         """Connect to a controller by name.
 
         """
         """Connect to a controller by name.
 
         """
@@ -437,13 +453,14 @@ class Connection:
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
         password = accounts.get('password')
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
         password = accounts.get('password')
-        macaroons = get_macaroons() if not password else None
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
 
         return await cls.connect(
-            endpoint, None, username, password, cacert, macaroons, loop)
+            endpoint, None, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
     @classmethod
 
     @classmethod
-    async def connect_model(cls, model, loop=None):
+    async def connect_model(cls, model, loop=None, max_frame_size=None):
         """Connect to a model by name.
 
         :param str model: [<controller>:]<model>
         """Connect to a model by name.
 
         :param str model: [<controller>:]<model>
@@ -471,29 +488,16 @@ class Connection:
         password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
         password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
-        macaroons = get_macaroons() if not password else None
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
 
         return await cls.connect(
-            endpoint, model_uuid, username, password, cacert, macaroons, loop)
+            endpoint, model_uuid, username, password, cacert, macaroons, loop,
+            max_frame_size)
 
     def build_facades(self, facades):
         self.facades.clear()
 
     def build_facades(self, facades):
         self.facades.clear()
-        # In order to work around an issue where the juju api is not
-        # returning a complete list of facades, we simply look up the
-        # juju version in a pregenerated map, and use that info to
-        # populate our list of facades.
-
-        # TODO: if a future version of juju fixes this bug, restore
-        # the following code for that version and higher:
-        # for facade in facades:
-        #     self.facades[facade['name']] = facade['versions'][-1]
-        try:
-            self.facades = VERSION_MAP[self.info['server-version']]
-        except KeyError:
-            log.warning("Could not find a set of facades for {}. Using "
-                        "the latest facade set instead".format(
-                            self.info['server-version']))
-            self.facades = VERSION_MAP['latest']
+        for facade in facades:
+            self.facades[facade['name']] = facade['versions'][-1]
 
     async def login(self):
         username = self.username
 
     async def login(self):
         username = self.username
@@ -560,16 +564,26 @@ class JujuData:
             return yaml.safe_load(f)[key]
 
 
             return yaml.safe_load(f)[key]
 
 
-def get_macaroons():
+def get_macaroons(controller_name=None):
     """Decode and return macaroons from default ~/.go-cookies
 
     """
     """Decode and return macaroons from default ~/.go-cookies
 
     """
-    try:
-        cookie_file = os.path.expanduser('~/.go-cookies')
-        with open(cookie_file, 'r') as f:
-            cookies = json.load(f)
-    except (OSError, ValueError):
-        log.warn("Couldn't load macaroons from %s", cookie_file)
+    cookie_files = []
+    if controller_name:
+        cookie_files.append('~/.local/share/juju/cookies/{}.json'.format(
+            controller_name))
+    cookie_files.append('~/.go-cookies')
+    for cookie_file in cookie_files:
+        cookie_file = Path(cookie_file).expanduser()
+        if cookie_file.exists():
+            try:
+                cookies = json.loads(cookie_file.read_text())
+                break
+            except (OSError, ValueError):
+                log.warn("Couldn't load macaroons from %s", cookie_file)
+                return []
+    else:
+        log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files))
         return []
 
     base64_macaroons = [
         return []
 
     base64_macaroons = [