Initial refactor of N2VC
[osm/N2VC.git] / n2vc / juju_watcher.py
diff --git a/n2vc/juju_watcher.py b/n2vc/juju_watcher.py
new file mode 100644 (file)
index 0000000..815abf9
--- /dev/null
@@ -0,0 +1,209 @@
+# Copyright 2020 Canonical Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+#     Unless required by applicable law or agreed to in writing, software
+#     distributed under the License is distributed on an "AS IS" BASIS,
+#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#     See the License for the specific language governing permissions and
+#     limitations under the License.
+
+import asyncio
+import time
+from juju.client import client
+from n2vc.utils import FinalStatus, EntityType
+from n2vc.exceptions import EntityInvalidException
+from n2vc.n2vc_conn import N2VCConnector
+from juju.model import ModelEntity, Model
+from juju.client.overrides import Delta
+
+import logging
+
+logger = logging.getLogger("__main__")
+
+
+class JujuModelWatcher:
+    @staticmethod
+    async def wait_for(
+        model,
+        entity: ModelEntity,
+        progress_timeout: float = 3600,
+        total_timeout: float = 3600,
+        db_dict: dict = None,
+        n2vc: N2VCConnector = None,
+    ):
+        """
+        Wait for entity to reach its final state.
+
+        :param: model:              Model to observe
+        :param: entity:             Entity object
+        :param: progress_timeout:   Maximum time between two updates in the model
+        :param: total_timeout:      Timeout for the entity to be active
+        :param: db_dict:            Dictionary with data of the DB to write the updates
+        :param: n2vc:               N2VC Connector objector
+
+        :raises: asyncio.TimeoutError when timeout reaches
+        """
+
+        if progress_timeout is None:
+            progress_timeout = 3600.0
+        if total_timeout is None:
+            total_timeout = 3600.0
+
+        entity_type = EntityType.get_entity(type(entity))
+        if entity_type not in FinalStatus:
+            raise EntityInvalidException("Entity type not found")
+
+        # Get final states
+        final_states = FinalStatus[entity_type].status
+        field_to_check = FinalStatus[entity_type].field
+
+        # Coroutine to wait until the entity reaches the final state
+        wait_for_entity = asyncio.ensure_future(
+            asyncio.wait_for(
+                model.block_until(
+                    lambda: entity.__getattribute__(field_to_check) in final_states
+                ),
+                timeout=total_timeout,
+            )
+        )
+
+        # Coroutine to watch the model for changes (and write them to DB)
+        watcher = asyncio.ensure_future(
+            JujuModelWatcher.model_watcher(
+                model,
+                entity_id=entity.entity_id,
+                entity_type=entity_type,
+                timeout=progress_timeout,
+                db_dict=db_dict,
+                n2vc=n2vc,
+            )
+        )
+
+        tasks = [wait_for_entity, watcher]
+        try:
+            # Execute tasks, and stop when the first is finished
+            # The watcher task won't never finish (unless it timeouts)
+            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+        except Exception as e:
+            raise e
+        finally:
+            # Cancel tasks
+            for task in tasks:
+                task.cancel()
+
+    @staticmethod
+    async def model_watcher(
+        model: Model,
+        entity_id: str,
+        entity_type: EntityType,
+        timeout: float,
+        db_dict: dict = None,
+        n2vc: N2VCConnector = None,
+    ):
+        """
+        Observes the changes related to an specific entity in a model
+
+        :param: model:          Model to observe
+        :param: entity_id:      ID of the entity to be observed
+        :param: entity_type:    EntityType (p.e. .APPLICATION, .MACHINE, and .ACTION)
+        :param: timeout:        Maximum time between two updates in the model
+        :param: db_dict:        Dictionary with data of the DB to write the updates
+        :param: n2vc:           N2VC Connector objector
+
+        :raises: asyncio.TimeoutError when timeout reaches
+        """
+
+        allwatcher = client.AllWatcherFacade.from_connection(model.connection())
+
+        # Genenerate array with entity types to listen
+        entity_types = (
+            [entity_type, EntityType.UNIT]
+            if entity_type == EntityType.APPLICATION  # TODO: Add .ACTION too
+            else [entity_type]
+        )
+
+        # Get time when it should timeout
+        timeout_end = time.time() + timeout
+
+        while True:
+            change = await allwatcher.Next()
+            for delta in change.deltas:
+                write = False
+                delta_entity = None
+
+                # Get delta EntityType
+                delta_entity = EntityType.get_entity_from_delta(delta.entity)
+
+                if delta_entity in entity_types:
+                    # Get entity id
+                    if entity_type == EntityType.APPLICATION:
+                        id = (
+                            delta.data["application"]
+                            if delta_entity == EntityType.UNIT
+                            else delta.data["name"]
+                        )
+                    else:
+                        id = delta.data["id"]
+
+                    # Write if the entity id match
+                    write = True if id == entity_id else False
+
+                    # Update timeout
+                    timeout_end = time.time() + timeout
+                    (status, status_message, vca_status) = JujuModelWatcher.get_status(
+                        delta, entity_type=delta_entity
+                    )
+
+                    if write and n2vc is not None and db_dict:
+                        # Write status to DB
+                        status = n2vc.osm_status(delta_entity, status)
+                        await n2vc.write_app_status_to_db(
+                            db_dict=db_dict,
+                            status=status,
+                            detailed_status=status_message,
+                            vca_status=vca_status,
+                            entity_type=delta_entity.value.__name__.lower(),
+                        )
+            # Check if timeout
+            if time.time() > timeout_end:
+                raise asyncio.TimeoutError()
+
+    @staticmethod
+    def get_status(delta: Delta, entity_type: EntityType) -> (str, str, str):
+        """
+        Get status from delta
+
+        :param: delta:          Delta generated by the allwatcher
+        :param: entity_type:    EntityType (p.e. .APPLICATION, .MACHINE, and .ACTION)
+
+        :return (status, message, vca_status)
+        """
+        if entity_type == EntityType.MACHINE:
+            return (
+                delta.data["agent-status"]["current"],
+                delta.data["instance-status"]["message"],
+                delta.data["instance-status"]["current"],
+            )
+        elif entity_type == EntityType.ACTION:
+            return (
+                delta.data["status"],
+                delta.data["status"],
+                delta.data["status"],
+            )
+        elif entity_type == EntityType.APPLICATION:
+            return (
+                delta.data["status"]["current"],
+                delta.data["status"]["message"],
+                delta.data["status"]["current"],
+            )
+        elif entity_type == EntityType.UNIT:
+            return (
+                delta.data["workload-status"]["current"],
+                delta.data["workload-status"]["message"],
+                delta.data["workload-status"]["current"],
+            )