Fix issue where we do not check to make sure that we are receiving the correct response.
[osm/N2VC.git] / juju / client / connection.py
index 3011a8a..486a57f 100644 (file)
@@ -15,7 +15,7 @@ import asyncio
 import yaml
 
 from juju import tag
-from juju.errors import JujuAPIError, JujuConnectionError
+from juju.errors import JujuError, JujuAPIError, JujuConnectionError
 
 log = logging.getLogger("websocket")
 
@@ -34,6 +34,8 @@ class Connection:
         # Connect to the currently active model
         client = await Connection.connect_current()
 
+    Note: Any connection method or constructor can accept an optional `loop`
+    argument to override the default event loop from `asyncio.get_event_loop`.
     """
     def __init__(
             self, endpoint, uuid, username, password, cacert=None,
@@ -50,6 +52,7 @@ class Connection:
         self.addr = None
         self.ws = None
         self.facades = {}
+        self.messages = {}
 
     @property
     def is_open(self):
@@ -78,12 +81,22 @@ class Connection:
     async def close(self):
         await self.ws.close()
 
-    async def recv(self):
-        result = await self.ws.recv()
-        if result is not None:
-            result = json.loads(result)
+    async def recv(self, request_id):
+        while not self.messages.get(request_id):
+            await asyncio.sleep(0)
+
+        result = self.messages[request_id]
+
+        del self.messages[request_id]
         return result
 
+    async def receiver(self):
+        while self.is_open:
+            result = await self.ws.recv()
+            if result is not None:
+                result = json.loads(result)
+                self.messages[result['request-id']] = result
+
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
         msg['request-id'] = self.__request_id__
@@ -93,9 +106,31 @@ class Connection:
             msg['version'] = self.facades[msg['type']]
         outgoing = json.dumps(msg, indent=2, cls=encoder)
         await self.ws.send(outgoing)
-        result = await self.recv()
-        if result and 'error' in result:
+        result = await self.recv(msg['request-id'])
+
+        if not result:
+            return result
+
+        if 'error' in result:
+            # API Error Response
             raise JujuAPIError(result)
+
+        if not 'response' in result:
+            # This may never happen
+            return result
+
+        if 'results' in result['response']:
+            # Check for errors in a result list.
+            errors = []
+            for res in result['response']['results']:
+                if res.get('error', {}).get('message'):
+                    errors.append(res['error']['message'])
+            if errors:
+                raise JujuError(errors)
+
+        elif result['response'].get('error', {}).get('message'):
+            raise JujuError(result['response']['error']['message'])
+
         return result
 
     def http_headers(self):
@@ -186,6 +221,7 @@ class Connection:
         client = cls(endpoint, uuid, username, password, cacert, macaroons,
                      loop)
         await client.open()
+        self.loop.create_task(self.receiver)
 
         redirect_info = await client.redirect_info()
         if not redirect_info:
@@ -221,8 +257,7 @@ class Connection:
         """
         jujudata = JujuData()
         controller_name = jujudata.current_controller()
-        models = jujudata.models()[controller_name]
-        model_name = models['current-model']
+        model_name = jujudata.current_model()
 
         return await cls.connect_model(
             '{}:{}'.format(controller_name, model_name), loop)
@@ -260,20 +295,30 @@ class Connection:
     async def connect_model(cls, model, loop=None):
         """Connect to a model by name.
 
-        :param str model: <controller>:<model>
+        :param str model: [<controller>:]<model>
 
         """
-        controller_name, model_name = model.split(':')
-
         jujudata = JujuData()
+
+        if ':' in model:
+            # explicit controller given
+            controller_name, model_name = model.split(':')
+        else:
+            # use the current controller if one isn't explicitly given
+            controller_name = jujudata.current_controller()
+            model_name = model
+
+        accounts = jujudata.accounts()[controller_name]
+        username = accounts['user']
+        # model name must include a user prefix, so add it if it doesn't
+        if '/' not in model_name:
+            model_name = '{}/{}'.format(username, model_name)
+
         controller = jujudata.controllers()[controller_name]
         endpoint = controller['api-endpoints'][0]
         cacert = controller.get('ca-cert')
-        accounts = jujudata.accounts()[controller_name]
-        username = accounts['user']
         password = accounts.get('password')
         models = jujudata.models()[controller_name]
-        model_name = '{}/{}'.format(username, model_name)
         model_uuid = models['models'][model_name]['uuid']
         macaroons = get_macaroons() if not password else None
 
@@ -333,6 +378,14 @@ class JujuData:
         output = yaml.safe_load(output)
         return output.get('current-controller', '')
 
+    def current_model(self, controller_name=None):
+        if not controller_name:
+            controller_name = self.current_controller()
+        models = self.models()[controller_name]
+        if 'current-model' not in models:
+            raise JujuError('No current model')
+        return models['current-model']
+
     def controllers(self):
         return self._load_yaml('controllers.yaml', 'controllers')