Pin black version in tox.ini to 23.12.1
[osm/N2VC.git] / n2vc / juju_watcher.py
index 9f9520f..747f08e 100644 (file)
@@ -14,6 +14,7 @@
 
 import asyncio
 import time
+
 from juju.client import client
 from n2vc.exceptions import EntityInvalidException
 from n2vc.n2vc_conn import N2VCConnector
@@ -42,6 +43,7 @@ def entity_ready(entity: ModelEntity) -> bool:
 
     :returns: boolean saying if the entity is ready or not
     """
+
     entity_type = entity.entity_type
     if entity_type == "machine":
         return entity.agent_status in ["started"]
@@ -50,6 +52,8 @@ def entity_ready(entity: ModelEntity) -> bool:
     elif entity_type == "application":
         # Workaround for bug: https://github.com/juju/python-libjuju/issues/441
         return entity.status in ["active", "blocked"]
+    elif entity_type == "unit":
+        return entity.agent_status in ["idle"]
     else:
         raise EntityInvalidException("Unknown entity type: {}".format(entity_type))
 
@@ -143,7 +147,7 @@ class JujuModelWatcher:
             total_timeout = 3600.0
 
         entity_type = entity.entity_type
-        if entity_type not in ["application", "action", "machine"]:
+        if entity_type not in ["application", "action", "machine", "unit"]:
             raise EntityInvalidException("Unknown entity type: {}".format(entity_type))
 
         # Coroutine to wait until the entity reaches the final state
@@ -177,6 +181,113 @@ class JujuModelWatcher:
             for task in tasks:
                 task.cancel()
 
+    @staticmethod
+    async def wait_for_units_idle(
+        model: Model, application: Application, timeout: float = 60
+    ):
+        """
+        Waits for the application and all its units to transition back to idle
+
+        :param: model:          Model to observe
+        :param: application:    The application to be observed
+        :param: timeout:        Maximum time between two updates in the model
+
+        :raises: asyncio.TimeoutError when timeout reaches
+        """
+
+        ensure_units_idle = asyncio.ensure_future(
+            asyncio.wait_for(
+                JujuModelWatcher.ensure_units_idle(model, application), timeout
+            )
+        )
+        tasks = [
+            ensure_units_idle,
+        ]
+        (done, pending) = await asyncio.wait(
+            tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
+        )
+
+        if ensure_units_idle in pending:
+            ensure_units_idle.cancel()
+            raise TimeoutError(
+                "Application's units failed to return to idle after {} seconds".format(
+                    timeout
+                )
+            )
+        if ensure_units_idle.result():
+            pass
+
+    @staticmethod
+    async def ensure_units_idle(model: Model, application: Application):
+        """
+        Waits forever until the application's units to transition back to idle
+
+        :param: model:          Model to observe
+        :param: application:    The application to be observed
+        """
+
+        try:
+            allwatcher = client.AllWatcherFacade.from_connection(model.connection())
+            unit_wanted_state = "executing"
+            final_state_reached = False
+
+            units = application.units
+            final_state_seen = {unit.entity_id: False for unit in units}
+            agent_state_seen = {unit.entity_id: False for unit in units}
+            workload_state = {unit.entity_id: False for unit in units}
+
+            try:
+                while not final_state_reached:
+                    change = await allwatcher.Next()
+
+                    # Keep checking to see if new units were added during the change
+                    for unit in units:
+                        if unit.entity_id not in final_state_seen:
+                            final_state_seen[unit.entity_id] = False
+                            agent_state_seen[unit.entity_id] = False
+                            workload_state[unit.entity_id] = False
+
+                    for delta in change.deltas:
+                        await asyncio.sleep(0)
+                        if delta.entity != units[0].entity_type:
+                            continue
+
+                        final_state_reached = True
+                        for unit in units:
+                            if delta.data["name"] == unit.entity_id:
+                                status = delta.data["agent-status"]["current"]
+                                workload_state[unit.entity_id] = delta.data[
+                                    "workload-status"
+                                ]["current"]
+
+                                if status == unit_wanted_state:
+                                    agent_state_seen[unit.entity_id] = True
+                                    final_state_seen[unit.entity_id] = False
+
+                                if (
+                                    status == "idle"
+                                    and agent_state_seen[unit.entity_id]
+                                ):
+                                    final_state_seen[unit.entity_id] = True
+
+                            final_state_reached = (
+                                final_state_reached
+                                and final_state_seen[unit.entity_id]
+                                and workload_state[unit.entity_id]
+                                in [
+                                    "active",
+                                    "error",
+                                ]
+                            )
+
+            except ConnectionClosed:
+                pass
+                # This is expected to happen when the
+                # entity reaches its final state, because
+                # the model connection is closed afterwards
+        except Exception as e:
+            raise e
+
     @staticmethod
     async def model_watcher(
         model: Model,
@@ -201,69 +312,76 @@ class JujuModelWatcher:
         :raises: asyncio.TimeoutError when timeout reaches
         """
 
-        allwatcher = client.AllWatcherFacade.from_connection(model.connection())
+        try:
+            allwatcher = client.AllWatcherFacade.from_connection(model.connection())
 
-        # Genenerate array with entity types to listen
-        entity_types = (
-            [entity_type, "unit"]
-            if entity_type == "application"  # TODO: Add "action" too
-            else [entity_type]
-        )
+            # Genenerate array with entity types to listen
+            entity_types = (
+                [entity_type, "unit"]
+                if entity_type == "application"  # TODO: Add "action" too
+                else [entity_type]
+            )
 
-        # Get time when it should timeout
-        timeout_end = time.time() + timeout
+            # Get time when it should timeout
+            timeout_end = time.time() + timeout
 
-        try:
-            while True:
-                change = await allwatcher.Next()
-                for delta in change.deltas:
-                    write = False
-                    delta_entity = None
-
-                    # Get delta EntityType
-                    delta_entity = delta.entity
-
-                    if delta_entity in entity_types:
-                        # Get entity id
-                        if entity_type == "application":
-                            id = (
-                                delta.data["application"]
-                                if delta_entity == "unit"
-                                else delta.data["name"]
-                            )
-                        else:
-                            id = delta.data["id"]
-
-                        # Write if the entity id match
-                        write = True if id == entity_id else False
-
-                        # Update timeout
-                        timeout_end = time.time() + timeout
-                        (
-                            status,
-                            status_message,
-                            vca_status,
-                        ) = JujuModelWatcher.get_status(delta)
-
-                        if write and n2vc is not None and db_dict:
-                            # Write status to DB
-                            status = n2vc.osm_status(delta_entity, status)
-                            await n2vc.write_app_status_to_db(
-                                db_dict=db_dict,
-                                status=status,
-                                detailed_status=status_message,
-                                vca_status=vca_status,
-                                entity_type=delta_entity,
-                                vca_id=vca_id,
-                            )
-                # Check if timeout
-                if time.time() > timeout_end:
-                    raise asyncio.TimeoutError()
-        except ConnectionClosed:
-            pass
-            # This is expected to happen when the
-            # entity reaches its final state, because
-            # the model connection is closed afterwards
+            try:
+                while True:
+                    change = await allwatcher.Next()
+                    for delta in change.deltas:
+                        write = False
+                        delta_entity = None
+
+                        # Get delta EntityType
+                        delta_entity = delta.entity
+
+                        if delta_entity in entity_types:
+                            # Get entity id
+                            id = None
+                            if entity_type == "application":
+                                id = (
+                                    delta.data["application"]
+                                    if delta_entity == "unit"
+                                    else delta.data["name"]
+                                )
+                            else:
+                                if "id" in delta.data:
+                                    id = delta.data["id"]
+                                else:
+                                    print("No id {}".format(delta.data))
+
+                            # Write if the entity id match
+                            write = True if id == entity_id else False
+
+                            # Update timeout
+                            timeout_end = time.time() + timeout
+                            (
+                                status,
+                                status_message,
+                                vca_status,
+                            ) = JujuModelWatcher.get_status(delta)
+
+                            if write and n2vc is not None and db_dict:
+                                # Write status to DB
+                                status = n2vc.osm_status(delta_entity, status)
+                                await n2vc.write_app_status_to_db(
+                                    db_dict=db_dict,
+                                    status=status,
+                                    detailed_status=status_message,
+                                    vca_status=vca_status,
+                                    entity_type=delta_entity,
+                                    vca_id=vca_id,
+                                )
+                    # Check if timeout
+                    if time.time() > timeout_end:
+                        raise asyncio.TimeoutError()
+            except ConnectionClosed:
+                pass
+                # This is expected to happen when the
+                # entity reaches its final state, because
+                # the model connection is closed afterwards
+        except Exception as e:
+            raise e
 
     @staticmethod
     def get_status(delta: Delta) -> (str, str, str):