Expand integration tests to use stable/edge versions of juju (#155)
[osm/N2VC.git] / juju / model.py
index 4db711b..bd8709a 100644 (file)
@@ -14,6 +14,7 @@ from concurrent.futures import CancelledError
 from functools import partial
 from pathlib import Path
 
+import websockets
 import yaml
 import theblues.charmstore
 import theblues.errors
@@ -21,6 +22,7 @@ import theblues.errors
 from . import tag, utils
 from .client import client
 from .client import connection
+from .client.client import ConfigValue
 from .constraints import parse as parse_constraints, normalize_key
 from .delta import get_entity_delta
 from .delta import get_entity_class
@@ -382,13 +384,17 @@ class Model(object):
     """
     The main API for interacting with a Juju model.
     """
-    def __init__(self, loop=None):
+    def __init__(self, loop=None,
+                 max_frame_size=connection.Connection.DEFAULT_FRAME_SIZE):
         """Instantiate a new connected Model.
 
         :param loop: an asyncio event loop
+        :param max_frame_size: See
+            `juju.client.connection.Connection.MAX_FRAME_SIZE`
 
         """
         self.loop = loop or asyncio.get_event_loop()
+        self.max_frame_size = max_frame_size
         self.connection = None
         self.observers = weakref.WeakValueDictionary()
         self.state = ModelState(self)
@@ -416,6 +422,8 @@ class Model(object):
         """
         if 'loop' not in kw:
             kw['loop'] = self.loop
+        if 'max_frame_size' not in kw:
+            kw['max_frame_size'] = self.max_frame_size
         self.connection = await connection.Connection.connect(*args, **kw)
         await self._after_connect()
 
@@ -424,7 +432,7 @@ class Model(object):
 
         """
         self.connection = await connection.Connection.connect_current(
-            self.loop)
+            self.loop, max_frame_size=self.max_frame_size)
         await self._after_connect()
 
     async def connect_model(self, model_name):
@@ -433,8 +441,8 @@ class Model(object):
         :param model_name:  Format [controller:][user/]model
 
         """
-        self.connection = await connection.Connection.connect_model(model_name,
-                                                                    self.loop)
+        self.connection = await connection.Connection.connect_model(
+            model_name, self.loop, self.max_frame_size)
         await self._after_connect()
 
     async def _after_connect(self):
@@ -544,6 +552,8 @@ class Model(object):
         """
         async def _block():
             while not all(c() for c in conditions):
+                if not (self.connection and self.connection.is_open):
+                    raise websockets.ConnectionClosed(1006, 'no reason')
                 await asyncio.sleep(wait_period, loop=self.loop)
         await asyncio.wait_for(_block(), timeout, loop=self.loop)
 
@@ -637,16 +647,45 @@ class Model(object):
         See :meth:`add_observer` to register an onchange callback.
 
         """
-        async def _start_watch():
+        async def _all_watcher():
             try:
                 allwatcher = client.AllWatcherFacade.from_connection(
                     self.connection)
                 while not self._watch_stopping.is_set():
-                    results = await utils.run_with_interrupt(
-                        allwatcher.Next(),
-                        self._watch_stopping,
-                        self.loop)
+                    try:
+                        results = await utils.run_with_interrupt(
+                            allwatcher.Next(),
+                            self._watch_stopping,
+                            self.loop)
+                    except JujuAPIError as e:
+                        if 'watcher was stopped' not in str(e):
+                            raise
+                        if self._watch_stopping.is_set():
+                            # this shouldn't ever actually happen, because
+                            # the event should trigger before the controller
+                            # has a chance to tell us the watcher is stopped
+                            # but handle it gracefully, just in case
+                            break
+                        # controller stopped our watcher for some reason
+                        # but we're not actually stopping, so just restart it
+                        log.warning(
+                            'Watcher: watcher stopped, restarting')
+                        del allwatcher.Id
+                        continue
+                    except websockets.ConnectionClosed:
+                        monitor = self.connection.monitor
+                        if monitor.status == monitor.ERROR:
+                            # closed unexpectedly, try to reopen
+                            log.warning(
+                                'Watcher: connection closed, reopening')
+                            await self.connection.reconnect()
+                            del allwatcher.Id
+                            continue
+                        else:
+                            # closed on request, go ahead and shutdown
+                            break
                     if self._watch_stopping.is_set():
+                        await allwatcher.Stop()
                         break
                     for delta in results.deltas:
                         delta = get_entity_delta(delta)
@@ -665,7 +704,7 @@ class Model(object):
         self._watch_received.clear()
         self._watch_stopping.clear()
         self._watch_stopped.clear()
-        self.loop.create_task(_start_watch())
+        self.loop.create_task(_all_watcher())
 
     async def _notify_observers(self, delta, old_obj, new_obj):
         """Call observing callbacks, notifying them of a change in model state
@@ -1217,11 +1256,20 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def get_config(self):
+    async def get_config(self):
         """Return the configuration settings for this model.
 
+        :returns: A ``dict`` mapping keys to `ConfigValue` instances,
+            which have `source` and `value` attributes.
         """
-        raise NotImplementedError()
+        config_facade = client.ModelConfigFacade.from_connection(
+            self.connection
+        )
+        result = await config_facade.ModelGet()
+        config = result.config
+        for key, value in config.items():
+            config[key] = ConfigValue.from_json(value)
+        return config
 
     def get_constraints(self):
         """Return the machine constraints for this model.
@@ -1402,13 +1450,19 @@ class Model(object):
         """
         raise NotImplementedError()
 
-    def set_config(self, **config):
+    async def set_config(self, config):
         """Set configuration keys on this model.
 
-        :param \*\*config: Config key/values
-
+        :param dict config: Mapping of config keys to either string values or
+            `ConfigValue` instances, as returned by `get_config`.
         """
-        raise NotImplementedError()
+        config_facade = client.ModelConfigFacade.from_connection(
+            self.connection
+        )
+        for key, value in config.items():
+            if isinstance(value, ConfigValue):
+                config[key] = value.value
+        await config_facade.ModelSet(config)
 
     def set_constraints(self, constraints):
         """Set machine constraints on this model.