self._watch_received = asyncio.Event(loop=self.loop)
self._charmstore = CharmStore(self.loop)
+ async def __aenter__(self):
+ await self.connect_current()
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.disconnect()
+
+ if exc_type is not None:
+ return False
+
async def connect(self, *args, **kw):
"""Connect to an arbitrary Juju model.
import unittest
import mock
+import asynctest
def _make_delta(entity, type_, data=None):
assert model._get_series('~foo/ubuntu', entity) == 'xenial'
assert model._get_series('ubuntu', entity) == 'xenial'
assert model._get_series('cs:ubuntu', entity) == 'xenial'
+
+
+class TestContextManager(asynctest.TestCase):
+ @asynctest.patch('juju.model.Model.disconnect')
+ @asynctest.patch('juju.model.Model.connect_current')
+ async def test_normal_use(self, mock_connect, mock_disconnect):
+ from juju.model import Model
+
+ async with Model() as model:
+ self.assertTrue(isinstance(model, Model))
+
+ self.assertTrue(mock_connect.called)
+ self.assertTrue(mock_disconnect.called)
+
+ @asynctest.patch('juju.model.Model.disconnect')
+ @asynctest.patch('juju.model.Model.connect_current')
+ async def test_exception(self, mock_connect, mock_disconnect):
+ from juju.model import Model
+
+ class SomeException(Exception):
+ pass
+
+ with self.assertRaises(SomeException):
+ async with Model():
+ raise SomeException()
+
+ self.assertTrue(mock_connect.called)
+ self.assertTrue(mock_disconnect.called)
+
+ @asynctest.patch('juju.client.connection.JujuData.current_controller')
+ async def test_no_current_connection(self, mock_current_controller):
+ from juju.model import Model
+ from juju.errors import JujuConnectionError
+
+ mock_current_controller.return_value = ""
+
+ with self.assertRaises(JujuConnectionError):
+ async with Model():
+ pass