Pin black version in tox.ini to 23.12.1
[osm/N2VC.git] / n2vc / libjuju.py
index bca3665..f36ff39 100644 (file)
 
 import asyncio
 import logging
+import os
 import typing
+import yaml
 
 import time
 
 import juju.errors
+from juju.bundle import BundleHandler
 from juju.model import Model
 from juju.machine import Machine
 from juju.application import Application
 from juju.unit import Unit
+from juju.url import URL
+from juju.version import DEFAULT_ARCHITECTURE
 from juju.client._definitions import (
     FullStatus,
     QueryApplicationOffersResults,
@@ -56,11 +61,18 @@ from retrying_async import retry
 RBAC_LABEL_KEY_NAME = "rbac-id"
 
 
+@asyncio.coroutine
+def retry_callback(attempt, exc, args, kwargs, delay=0.5, *, loop):
+    # Specifically overridden from upstream implementation so it can
+    # continue to work with Python 3.10
+    yield from asyncio.sleep(attempt * delay)
+    return retry
+
+
 class Libjuju:
     def __init__(
         self,
         vca_connection: Connection,
-        loop: asyncio.AbstractEventLoop = None,
         log: logging.Logger = None,
         n2vc: N2VCConnector = None,
     ):
@@ -68,7 +80,6 @@ class Libjuju:
         Constructor
 
         :param: vca_connection:         n2vc.vca.connection object
-        :param: loop:                   Asyncio loop
         :param: log:                    Logger
         :param: n2vc:                   N2VC object
         """
@@ -77,15 +88,13 @@ class Libjuju:
         self.n2vc = n2vc
         self.vca_connection = vca_connection
 
-        self.loop = loop or asyncio.get_event_loop()
-        self.loop.set_exception_handler(self.handle_exception)
-        self.creating_model = asyncio.Lock(loop=self.loop)
+        self.creating_model = asyncio.Lock()
 
         if self.vca_connection.is_default:
             self.health_check_task = self._create_health_check_task()
 
     def _create_health_check_task(self):
-        return self.loop.create_task(self.health_check())
+        return asyncio.get_event_loop().create_task(self.health_check())
 
     async def get_controller(self, timeout: float = 60.0) -> Controller:
         """
@@ -122,7 +131,10 @@ class Libjuju:
             )
             if controller:
                 await self.disconnect_controller(controller)
-            raise JujuControllerFailedConnecting(e)
+
+            raise JujuControllerFailedConnecting(
+                f"Error connecting to Juju controller: {e}"
+            )
 
     async def disconnect(self):
         """Disconnect"""
@@ -147,7 +159,7 @@ class Libjuju:
         if controller:
             await controller.disconnect()
 
-    @retry(attempts=3, delay=5, timeout=None)
+    @retry(attempts=3, delay=5, timeout=None, callback=retry_callback)
     async def add_model(self, model_name: str, cloud: VcaCloud):
         """
         Create model
@@ -262,7 +274,7 @@ class Libjuju:
             await self.disconnect_controller(controller)
         return application_configs
 
-    @retry(attempts=3, delay=5)
+    @retry(attempts=3, delay=5, callback=retry_callback)
     async def get_model(self, controller: Controller, model_name: str) -> Model:
         """
         Get model from controller
@@ -546,27 +558,122 @@ class Libjuju:
         return machine_id
 
     async def deploy(
-        self, uri: str, model_name: str, wait: bool = True, timeout: float = 3600
+        self,
+        uri: str,
+        model_name: str,
+        wait: bool = True,
+        timeout: float = 3600,
+        instantiation_params: dict = None,
     ):
         """
         Deploy bundle or charm: Similar to the juju CLI command `juju deploy`
 
-        :param: uri:            Path or Charm Store uri in which the charm or bundle can be found
-        :param: model_name:     Model name
-        :param: wait:           Indicates whether to wait or not until all applications are active
-        :param: timeout:        Time in seconds to wait until all applications are active
+        :param uri:            Path or Charm Store uri in which the charm or bundle can be found
+        :param model_name:     Model name
+        :param wait:           Indicates whether to wait or not until all applications are active
+        :param timeout:        Time in seconds to wait until all applications are active
+        :param instantiation_params: To be applied as overlay bundle over primary bundle.
         """
         controller = await self.get_controller()
         model = await self.get_model(controller, model_name)
+        overlays = []
         try:
-            await model.deploy(uri, trust=True)
+            await self._validate_instantiation_params(uri, model, instantiation_params)
+            overlays = self._get_overlays(model_name, instantiation_params)
+            await model.deploy(uri, trust=True, overlays=overlays)
             if wait:
                 await JujuModelWatcher.wait_for_model(model, timeout=timeout)
                 self.log.debug("All units active in model {}".format(model_name))
         finally:
+            self._remove_overlay_file(overlays)
             await self.disconnect_model(model)
             await self.disconnect_controller(controller)
 
+    async def _validate_instantiation_params(
+        self, uri: str, model, instantiation_params: dict
+    ) -> None:
+        """Checks if all the applications in instantiation_params
+        exist ins the original bundle.
+
+        Raises:
+            JujuApplicationNotFound if there is an invalid app in
+            the instantiation params.
+        """
+        overlay_apps = self._get_apps_in_instantiation_params(instantiation_params)
+        if not overlay_apps:
+            return
+        original_apps = await self._get_apps_in_original_bundle(uri, model)
+        if not all(app in original_apps for app in overlay_apps):
+            raise JujuApplicationNotFound(
+                "Cannot find application {} in original bundle {}".format(
+                    overlay_apps, original_apps
+                )
+            )
+
+    async def _get_apps_in_original_bundle(self, uri: str, model) -> set:
+        """Bundle is downloaded in BundleHandler.fetch_plan.
+        That method takes care of opening and exception handling.
+
+        Resolve method gets all the information regarding the channel,
+        track, revision, type, source.
+
+        Returns:
+            Set with the names of the applications in original bundle.
+        """
+        url = URL.parse(uri)
+        architecture = DEFAULT_ARCHITECTURE  # only AMD64 is allowed
+        res = await model.deploy_types[str(url.schema)].resolve(
+            url, architecture, entity_url=uri
+        )
+        handler = BundleHandler(model, trusted=True, forced=False)
+        await handler.fetch_plan(url, res.origin)
+        return handler.applications
+
+    def _get_apps_in_instantiation_params(self, instantiation_params: dict) -> list:
+        """Extract applications key in instantiation params.
+
+        Returns:
+            List with the names of the applications in instantiation params.
+
+        Raises:
+            JujuError if applications key is not found.
+        """
+        if not instantiation_params:
+            return []
+        try:
+            return [key for key in instantiation_params.get("applications")]
+        except Exception as e:
+            raise JujuError("Invalid overlay format. {}".format(str(e)))
+
+    def _get_overlays(self, model_name: str, instantiation_params: dict) -> list:
+        """Creates a temporary overlay file which includes the instantiation params.
+        Only one overlay file is created.
+
+        Returns:
+            List with one overlay filename. Empty list if there are no instantiation params.
+        """
+        if not instantiation_params:
+            return []
+        file_name = model_name + "-overlay.yaml"
+        self._write_overlay_file(file_name, instantiation_params)
+        return [file_name]
+
+    def _write_overlay_file(self, file_name: str, instantiation_params: dict) -> None:
+        with open(file_name, "w") as file:
+            yaml.dump(instantiation_params, file)
+
+    def _remove_overlay_file(self, overlay: list) -> None:
+        """Overlay contains either one or zero file names."""
+        if not overlay:
+            return
+        try:
+            filename = overlay[0]
+            os.remove(filename)
+        except OSError as e:
+            self.log.warning(
+                "Overlay file {} could not be removed: {}".format(filename, e)
+            )
+
     async def add_unit(
         self,
         application_name: str,
@@ -595,7 +702,6 @@ class Libjuju:
             application = self._get_application(model, application_name)
 
             if application is not None:
-
                 # Checks if the given machine id in the model,
                 # otherwise function raises an error
                 _machine, _series = self._get_machine_info(model, machine_id)
@@ -750,7 +856,6 @@ class Libjuju:
 
         try:
             if application_name not in model.applications:
-
                 if machine_id is not None:
                     machine, series = self._get_machine_info(model, machine_id)
 
@@ -890,7 +995,6 @@ class Libjuju:
         return application
 
     async def resolve_application(self, model_name: str, application_name: str):
-
         controller = await self.get_controller()
         model = await self.get_model(controller, model_name)
 
@@ -923,6 +1027,34 @@ class Libjuju:
             await self.disconnect_model(model)
             await self.disconnect_controller(controller)
 
+    async def resolve(self, model_name: str):
+        controller = await self.get_controller()
+        model = await self.get_model(controller, model_name)
+        all_units_active = False
+        try:
+            while not all_units_active:
+                all_units_active = True
+                for application_name, application in model.applications.items():
+                    if application.status == "error":
+                        for unit in application.units:
+                            if unit.workload_status == "error":
+                                self.log.debug(
+                                    "Model {}, Application {}, Unit {} in error state, resolving".format(
+                                        model_name, application_name, unit.entity_id
+                                    )
+                                )
+                                try:
+                                    await unit.resolved(retry=False)
+                                    all_units_active = False
+                                except Exception:
+                                    pass
+
+                if not all_units_active:
+                    await asyncio.sleep(5)
+        finally:
+            await self.disconnect_model(model)
+            await self.disconnect_controller(controller)
+
     async def scale_application(
         self,
         model_name: str,
@@ -1235,10 +1367,10 @@ class Libjuju:
         try:
             await model.add_relation(endpoint_1, endpoint_2)
         except juju.errors.JujuAPIError as e:
-            if "not found" in e.message:
+            if self._relation_is_not_found(e):
                 self.log.warning("Relation not found: {}".format(e.message))
                 return
-            if "already exists" in e.message:
+            if self._relation_already_exist(e):
                 self.log.warning("Relation already exists: {}".format(e.message))
                 return
             # another exception, raise it
@@ -1247,6 +1379,18 @@ class Libjuju:
             await self.disconnect_model(model)
             await self.disconnect_controller(controller)
 
+    def _relation_is_not_found(self, juju_error):
+        text = "not found"
+        return (text in juju_error.message) or (
+            juju_error.error_code and text in juju_error.error_code
+        )
+
+    def _relation_already_exist(self, juju_error):
+        text = "already exists"
+        return (text in juju_error.message) or (
+            juju_error.error_code and text in juju_error.error_code
+        )
+
     async def offer(self, endpoint: RelationEndpoint) -> Offer:
         """
         Create an offer from a RelationEndpoint
@@ -1326,24 +1470,28 @@ class Libjuju:
         model = None
         try:
             if not await self.model_exists(model_name, controller=controller):
+                self.log.warn(f"Model {model_name} doesn't exist")
                 return
 
-            self.log.debug("Destroying model {}".format(model_name))
-
+            self.log.debug(f"Getting model {model_name} to be destroyed")
             model = await self.get_model(controller, model_name)
+            self.log.debug(f"Destroying manual machines in model {model_name}")
             # Destroy machines that are manually provisioned
             # and still are in pending state
             await self._destroy_pending_machines(model, only_manual=True)
             await self.disconnect_model(model)
 
-            await self._destroy_model(
-                model_name,
-                controller,
+            await asyncio.wait_for(
+                self._destroy_model(model_name, controller),
                 timeout=total_timeout,
             )
         except Exception as e:
             if not await self.model_exists(model_name, controller=controller):
+                self.log.warn(
+                    f"Failed deleting model {model_name}: model doesn't exist"
+                )
                 return
+            self.log.warn(f"Failed deleting model {model_name}: {e}")
             raise e
         finally:
             if model:
@@ -1351,7 +1499,9 @@ class Libjuju:
             await self.disconnect_controller(controller)
 
     async def _destroy_model(
-        self, model_name: str, controller: Controller, timeout: float = 1800
+        self,
+        model_name: str,
+        controller: Controller,
     ):
         """
         Destroy model from controller
@@ -1360,25 +1510,41 @@ class Libjuju:
         :param: controller: Controller object
         :param: timeout: Timeout in seconds
         """
+        self.log.debug(f"Destroying model {model_name}")
 
-        async def _destroy_model_loop(model_name: str, controller: Controller):
-            while await self.model_exists(model_name, controller=controller):
+        async def _destroy_model_gracefully(model_name: str, controller: Controller):
+            self.log.info(f"Gracefully deleting model {model_name}")
+            resolved = False
+            while model_name in await controller.list_models():
+                if not resolved:
+                    await self.resolve(model_name)
+                    resolved = True
+                await controller.destroy_model(model_name, destroy_storage=True)
+
+                await asyncio.sleep(5)
+            self.log.info(f"Model {model_name} deleted gracefully")
+
+        async def _destroy_model_forcefully(model_name: str, controller: Controller):
+            self.log.info(f"Forcefully deleting model {model_name}")
+            while model_name in await controller.list_models():
                 await controller.destroy_model(
-                    model_name, destroy_storage=True, force=True, max_wait=0
+                    model_name, destroy_storage=True, force=True, max_wait=60
                 )
                 await asyncio.sleep(5)
+            self.log.info(f"Model {model_name} deleted forcefully")
 
         try:
-            await asyncio.wait_for(
-                _destroy_model_loop(model_name, controller), timeout=timeout
-            )
-        except asyncio.TimeoutError:
-            raise Exception(
-                "Timeout waiting for model {} to be destroyed".format(model_name)
-            )
+            try:
+                await asyncio.wait_for(
+                    _destroy_model_gracefully(model_name, controller), timeout=120
+                )
+            except asyncio.TimeoutError:
+                await _destroy_model_forcefully(model_name, controller)
         except juju.errors.JujuError as e:
             if any("has been removed" in error for error in e.errors):
                 return
+            if any("model not found" in error for error in e.errors):
+                return
             raise e
 
     async def destroy_application(
@@ -1478,10 +1644,6 @@ class Libjuju:
                     await self.disconnect_model(model)
                 await self.disconnect_controller(controller)
 
-    def handle_exception(self, loop, context):
-        # All unhandled exceptions by libjuju are handled here.
-        pass
-
     async def health_check(self, interval: float = 300.0):
         """
         Health check to make sure controller and controller_model connections are OK
@@ -1686,7 +1848,9 @@ class Libjuju:
         finally:
             await self.disconnect_controller(controller)
 
-    @retry(attempts=20, delay=5, fallback=JujuLeaderUnitNotFound())
+    @retry(
+        attempts=20, delay=5, fallback=JujuLeaderUnitNotFound(), callback=retry_callback
+    )
     async def _get_leader_unit(self, application: Application) -> Unit:
         unit = None
         for u in application.units: