Bump rev for release
[osm/N2VC.git] / juju / client / connection.py
index 2be360f..6851707 100644 (file)
@@ -9,14 +9,15 @@ import ssl
 import string
 import subprocess
 import websockets
+from concurrent.futures import CancelledError
 from http.client import HTTPSConnection
+from pathlib import Path
 
 import asyncio
 import yaml
 
-from juju import tag
+from juju import tag, utils
 from juju.client import client
-from juju.client.version_map import VERSION_MAP
 from juju.errors import JujuError, JujuAPIError, JujuConnectionError
 from juju.utils import IdQueue
 
@@ -44,8 +45,11 @@ class Monitor:
 
     def __init__(self, connection):
         self.connection = connection
-        self.receiver = None
-        self.pinger = None
+        self.close_called = asyncio.Event(loop=self.connection.loop)
+        self.receiver_stopped = asyncio.Event(loop=self.connection.loop)
+        self.pinger_stopped = asyncio.Event(loop=self.connection.loop)
+        self.receiver_stopped.set()
+        self.pinger_stopped.set()
 
     @property
     def status(self):
@@ -63,21 +67,21 @@ class Monitor:
         # DISCONNECTED: connection not yet open
         if not self.connection.ws:
             return self.DISCONNECTED
-        if not self.receiver:
+        if self.receiver_stopped.is_set():
             return self.DISCONNECTED
 
         # ERROR: Connection closed (or errored), but we didn't call
         # connection.close
-        if not self.connection.close_called and self.receiver_exceptions():
+        if not self.close_called.is_set() and self.receiver_stopped.is_set():
             return self.ERROR
-        if not self.connection.close_called and not self.connection.ws.open:
-            # The check for self.receiver existing above guards against the
-            # case where we're not open because we simply haven't
-            # setup the connection yet.
+        if not self.close_called.is_set() and not self.connection.ws.open:
+            # The check for self.receiver_stopped existing above guards
+            # against the case where we're not open because we simply
+            # haven't setup the connection yet.
             return self.ERROR
 
         # DISCONNECTED: cleanly disconnected.
-        if self.connection.close_called and not self.connection.ws.open:
+        if self.close_called.is_set() and not self.connection.ws.open:
             return self.DISCONNECTED
 
         # CONNECTED: everything is fine!
@@ -89,17 +93,6 @@ class Monitor:
         # know what state the connection is in.
         return self.UNKNOWN
 
-    def receiver_exceptions(self):
-        """
-        Return exceptions in the receiver, if any.
-
-        """
-        if not self.receiver:
-            return None
-        if not self.receiver.done():
-            return None
-        return self.receiver.exception()
-
 
 class Connection:
     """
@@ -139,7 +132,6 @@ class Connection:
         self.ws = None
         self.facades = {}
         self.messages = IdQueue(loop=self.loop)
-        self.close_called = False
         self.monitor = Monitor(connection=self)
 
     @property
@@ -163,19 +155,18 @@ class Connection:
         kw['loop'] = self.loop
         self.addr = url
         self.ws = await websockets.connect(url, **kw)
-        self.monitor.receiver = self.loop.create_task(self.receiver())
+        self.loop.create_task(self.receiver())
+        self.monitor.receiver_stopped.clear()
         log.info("Driver connected to juju %s", url)
+        self.monitor.close_called.clear()
         return self
 
     async def close(self):
         if not self.is_open:
             return
-        self.close_called = True
-        if self.monitor.pinger:
-            # might be closing due to login failure,
-            # in which case we won't have a pinger yet
-            self.monitor.pinger.cancel()
-        self.monitor.receiver.cancel()
+        self.monitor.close_called.set()
+        await self.monitor.pinger_stopped.wait()
+        await self.monitor.receiver_stopped.wait()
         await self.ws.close()
 
     async def recv(self, request_id):
@@ -184,19 +175,29 @@ class Connection:
         return await self.messages.get(request_id)
 
     async def receiver(self):
-        while self.is_open:
-            try:
-                result = await self.ws.recv()
+        try:
+            while self.is_open:
+                result = await utils.run_with_interrupt(
+                    self.ws.recv(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
                 if result is not None:
                     result = json.loads(result)
                     await self.messages.put(result['request-id'], result)
-            except Exception as e:
-                await self.messages.put_all(e)
-                if isinstance(e, websockets.ConnectionClosed):
-                    # ConnectionClosed is not really exceptional for us,
-                    # but it may be for any pending message listeners
-                    return
-                raise
+        except CancelledError:
+            pass
+        except Exception as e:
+            await self.messages.put_all(e)
+            if isinstance(e, websockets.ConnectionClosed):
+                # ConnectionClosed is not really exceptional for us,
+                # but it may be for any pending message listeners
+                return
+            log.exception("Error in receiver")
+            raise
+        finally:
+            self.monitor.receiver_stopped.set()
 
     async def pinger(self):
         '''
@@ -206,10 +207,24 @@ class Connection:
         To prevent timing out, we send a ping every ten seconds.
 
         '''
+        async def _do_ping():
+            try:
+                await pinger_facade.Ping()
+                await asyncio.sleep(10, loop=self.loop)
+            except CancelledError:
+                pass
+
         pinger_facade = client.PingerFacade.from_connection(self)
-        while self.is_open:
-            await pinger_facade.Ping()
-            await asyncio.sleep(10, loop=self.loop)
+        try:
+            while self.is_open:
+                await utils.run_with_interrupt(
+                    _do_ping(),
+                    self.monitor.close_called,
+                    loop=self.loop)
+                if self.monitor.close_called.is_set():
+                    break
+        finally:
+            self.monitor.pinger_stopped.set()
 
     async def rpc(self, msg, encoder=None):
         self.__request_id__ += 1
@@ -381,7 +396,8 @@ class Connection:
         response = result['response']
         client.info = response.copy()
         client.build_facades(response.get('facades', {}))
-        client.monitor.pinger = client.loop.create_task(client.pinger())
+        client.loop.create_task(client.pinger())
+        client.monitor.pinger_stopped.clear()
 
         return client
 
@@ -391,7 +407,11 @@ class Connection:
 
         """
         jujudata = JujuData()
+
         controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
         model_name = jujudata.current_model()
 
         return await cls.connect_model(
@@ -421,7 +441,7 @@ class Connection:
         accounts = jujudata.accounts()[controller_name]
         username = accounts['user']
         password = accounts.get('password')
-        macaroons = get_macaroons() if not password else None
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
             endpoint, None, username, password, cacert, macaroons, loop)
@@ -455,29 +475,15 @@ class Connection:
         password = accounts.get('password')
         models = jujudata.models()[controller_name]
         model_uuid = models['models'][model_name]['uuid']
-        macaroons = get_macaroons() if not password else None
+        macaroons = get_macaroons(controller_name) if not password else None
 
         return await cls.connect(
             endpoint, model_uuid, username, password, cacert, macaroons, loop)
 
     def build_facades(self, facades):
         self.facades.clear()
-        # In order to work around an issue where the juju api is not
-        # returning a complete list of facades, we simply look up the
-        # juju version in a pregenerated map, and use that info to
-        # populate our list of facades.
-
-        # TODO: if a future version of juju fixes this bug, restore
-        # the following code for that version and higher:
-        # for facade in facades:
-        #     self.facades[facade['name']] = facade['versions'][-1]
-        try:
-            self.facades = VERSION_MAP[self.info['server-version']]
-        except KeyError:
-            log.warning("Could not find a set of facades for {}. Using "
-                        "the latest facade set instead".format(
-                            self.info['server-version']))
-            self.facades = VERSION_MAP['latest']
+        for facade in facades:
+            self.facades[facade['name']] = facade['versions'][-1]
 
     async def login(self):
         username = self.username
@@ -544,16 +550,26 @@ class JujuData:
             return yaml.safe_load(f)[key]
 
 
-def get_macaroons():
+def get_macaroons(controller_name=None):
     """Decode and return macaroons from default ~/.go-cookies
 
     """
-    try:
-        cookie_file = os.path.expanduser('~/.go-cookies')
-        with open(cookie_file, 'r') as f:
-            cookies = json.load(f)
-    except (OSError, ValueError):
-        log.warn("Couldn't load macaroons from %s", cookie_file)
+    cookie_files = []
+    if controller_name:
+        cookie_files.append('~/.local/share/juju/cookies/{}.json'.format(
+            controller_name))
+    cookie_files.append('~/.go-cookies')
+    for cookie_file in cookie_files:
+        cookie_file = Path(cookie_file).expanduser()
+        if cookie_file.exists():
+            try:
+                cookies = json.loads(cookie_file.read_text())
+                break
+            except (OSError, ValueError):
+                log.warn("Couldn't load macaroons from %s", cookie_file)
+                return []
+    else:
+        log.warn("Couldn't load macaroons from %s", ' or '.join(cookie_files))
         return []
 
     base64_macaroons = [