Add model context manager (#128)
authorsimonklb <simonkollberg@gmail.com>
Fri, 2 Jun 2017 16:28:35 +0000 (18:28 +0200)
committerCory Johns <johnsca@gmail.com>
Fri, 2 Jun 2017 16:28:35 +0000 (12:28 -0400)
juju/client/connection.py
juju/model.py
tests/unit/test_model.py
tox.ini

index 02e9d43..6851707 100644 (file)
@@ -407,7 +407,11 @@ class Connection:
 
         """
         jujudata = JujuData()
+
         controller_name = jujudata.current_controller()
+        if not controller_name:
+            raise JujuConnectionError('No current controller')
+
         model_name = jujudata.current_model()
 
         return await cls.connect_model(
index e56ad66..b97798a 100644 (file)
@@ -390,6 +390,16 @@ class Model(object):
         self._watch_received = asyncio.Event(loop=self.loop)
         self._charmstore = CharmStore(self.loop)
 
+    async def __aenter__(self):
+        await self.connect_current()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.disconnect()
+
+        if exc_type is not None:
+            return False
+
     async def connect(self, *args, **kw):
         """Connect to an arbitrary Juju model.
 
index 67db5ae..222d881 100644 (file)
@@ -1,6 +1,7 @@
 import unittest
 
 import mock
+import asynctest
 
 
 def _make_delta(entity, type_, data=None):
@@ -113,3 +114,42 @@ def test_get_series():
     assert model._get_series('~foo/ubuntu', entity) == 'xenial'
     assert model._get_series('ubuntu', entity) == 'xenial'
     assert model._get_series('cs:ubuntu', entity) == 'xenial'
+
+
+class TestContextManager(asynctest.TestCase):
+    @asynctest.patch('juju.model.Model.disconnect')
+    @asynctest.patch('juju.model.Model.connect_current')
+    async def test_normal_use(self, mock_connect, mock_disconnect):
+        from juju.model import Model
+
+        async with Model() as model:
+            self.assertTrue(isinstance(model, Model))
+
+        self.assertTrue(mock_connect.called)
+        self.assertTrue(mock_disconnect.called)
+
+    @asynctest.patch('juju.model.Model.disconnect')
+    @asynctest.patch('juju.model.Model.connect_current')
+    async def test_exception(self, mock_connect, mock_disconnect):
+        from juju.model import Model
+
+        class SomeException(Exception):
+            pass
+
+        with self.assertRaises(SomeException):
+            async with Model():
+                raise SomeException()
+
+        self.assertTrue(mock_connect.called)
+        self.assertTrue(mock_disconnect.called)
+
+    @asynctest.patch('juju.client.connection.JujuData.current_controller')
+    async def test_no_current_connection(self, mock_current_controller):
+        from juju.model import Model
+        from juju.errors import JujuConnectionError
+
+        mock_current_controller.return_value = ""
+
+        with self.assertRaises(JujuConnectionError):
+            async with Model():
+                pass
diff --git a/tox.ini b/tox.ini
index 1ac6356..5b2d82e 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -16,10 +16,11 @@ deps =
     pytest-asyncio
     pytest-xdist
     mock
+    asynctest
 
 [testenv:py35]
 # default tox env excludes integration tests
-commands = py.test -ra -s -x -n auto -k 'not integration'
+commands = py.test -ra -s -x -n auto -k 'not integration' {posargs}
 
 [testenv:integration]
 basepython=python3