import unittest
import mock
+import asynctest
def _make_delta(entity, type_, data=None):
assert model._get_series('~foo/ubuntu', entity) == 'xenial'
assert model._get_series('ubuntu', entity) == 'xenial'
assert model._get_series('cs:ubuntu', entity) == 'xenial'
+
+
+class TestContextManager(asynctest.TestCase):
+ @asynctest.patch('juju.model.Model.disconnect')
+ @asynctest.patch('juju.model.Model.connect_current')
+ async def test_normal_use(self, mock_connect, mock_disconnect):
+ from juju.model import Model
+
+ async with Model() as model:
+ self.assertTrue(isinstance(model, Model))
+
+ self.assertTrue(mock_connect.called)
+ self.assertTrue(mock_disconnect.called)
+
+ @asynctest.patch('juju.model.Model.disconnect')
+ @asynctest.patch('juju.model.Model.connect_current')
+ async def test_exception(self, mock_connect, mock_disconnect):
+ from juju.model import Model
+
+ class SomeException(Exception):
+ pass
+
+ with self.assertRaises(SomeException):
+ async with Model():
+ raise SomeException()
+
+ self.assertTrue(mock_connect.called)
+ self.assertTrue(mock_disconnect.called)
+
+ @asynctest.patch('juju.client.connection.JujuData.current_controller')
+ async def test_no_current_connection(self, mock_current_controller):
+ from juju.model import Model
+ from juju.errors import JujuConnectionError
+
+ mock_current_controller.return_value = ""
+
+ with self.assertRaises(JujuConnectionError):
+ async with Model():
+ pass