Pin black version in tox.ini to 23.12.1
[osm/N2VC.git] / n2vc / libjuju.py
index 053aaa8..f36ff39 100644 (file)
 
 import asyncio
 import logging
 
 import asyncio
 import logging
+import os
 import typing
 import typing
+import yaml
 
 import time
 
 import juju.errors
 
 import time
 
 import juju.errors
+from juju.bundle import BundleHandler
 from juju.model import Model
 from juju.machine import Machine
 from juju.application import Application
 from juju.unit import Unit
 from juju.model import Model
 from juju.machine import Machine
 from juju.application import Application
 from juju.unit import Unit
+from juju.url import URL
+from juju.version import DEFAULT_ARCHITECTURE
 from juju.client._definitions import (
     FullStatus,
     QueryApplicationOffersResults,
 from juju.client._definitions import (
     FullStatus,
     QueryApplicationOffersResults,
@@ -56,11 +61,18 @@ from retrying_async import retry
 RBAC_LABEL_KEY_NAME = "rbac-id"
 
 
 RBAC_LABEL_KEY_NAME = "rbac-id"
 
 
+@asyncio.coroutine
+def retry_callback(attempt, exc, args, kwargs, delay=0.5, *, loop):
+    # Specifically overridden from upstream implementation so it can
+    # continue to work with Python 3.10
+    yield from asyncio.sleep(attempt * delay)
+    return retry
+
+
 class Libjuju:
     def __init__(
         self,
         vca_connection: Connection,
 class Libjuju:
     def __init__(
         self,
         vca_connection: Connection,
-        loop: asyncio.AbstractEventLoop = None,
         log: logging.Logger = None,
         n2vc: N2VCConnector = None,
     ):
         log: logging.Logger = None,
         n2vc: N2VCConnector = None,
     ):
@@ -68,7 +80,6 @@ class Libjuju:
         Constructor
 
         :param: vca_connection:         n2vc.vca.connection object
         Constructor
 
         :param: vca_connection:         n2vc.vca.connection object
-        :param: loop:                   Asyncio loop
         :param: log:                    Logger
         :param: n2vc:                   N2VC object
         """
         :param: log:                    Logger
         :param: n2vc:                   N2VC object
         """
@@ -77,15 +88,13 @@ class Libjuju:
         self.n2vc = n2vc
         self.vca_connection = vca_connection
 
         self.n2vc = n2vc
         self.vca_connection = vca_connection
 
-        self.loop = loop or asyncio.get_event_loop()
-        self.loop.set_exception_handler(self.handle_exception)
-        self.creating_model = asyncio.Lock(loop=self.loop)
+        self.creating_model = asyncio.Lock()
 
         if self.vca_connection.is_default:
             self.health_check_task = self._create_health_check_task()
 
     def _create_health_check_task(self):
 
         if self.vca_connection.is_default:
             self.health_check_task = self._create_health_check_task()
 
     def _create_health_check_task(self):
-        return self.loop.create_task(self.health_check())
+        return asyncio.get_event_loop().create_task(self.health_check())
 
     async def get_controller(self, timeout: float = 60.0) -> Controller:
         """
 
     async def get_controller(self, timeout: float = 60.0) -> Controller:
         """
@@ -150,7 +159,7 @@ class Libjuju:
         if controller:
             await controller.disconnect()
 
         if controller:
             await controller.disconnect()
 
-    @retry(attempts=3, delay=5, timeout=None)
+    @retry(attempts=3, delay=5, timeout=None, callback=retry_callback)
     async def add_model(self, model_name: str, cloud: VcaCloud):
         """
         Create model
     async def add_model(self, model_name: str, cloud: VcaCloud):
         """
         Create model
@@ -265,7 +274,7 @@ class Libjuju:
             await self.disconnect_controller(controller)
         return application_configs
 
             await self.disconnect_controller(controller)
         return application_configs
 
-    @retry(attempts=3, delay=5)
+    @retry(attempts=3, delay=5, callback=retry_callback)
     async def get_model(self, controller: Controller, model_name: str) -> Model:
         """
         Get model from controller
     async def get_model(self, controller: Controller, model_name: str) -> Model:
         """
         Get model from controller
@@ -549,27 +558,122 @@ class Libjuju:
         return machine_id
 
     async def deploy(
         return machine_id
 
     async def deploy(
-        self, uri: str, model_name: str, wait: bool = True, timeout: float = 3600
+        self,
+        uri: str,
+        model_name: str,
+        wait: bool = True,
+        timeout: float = 3600,
+        instantiation_params: dict = None,
     ):
         """
         Deploy bundle or charm: Similar to the juju CLI command `juju deploy`
 
     ):
         """
         Deploy bundle or charm: Similar to the juju CLI command `juju deploy`
 
-        :param: uri:            Path or Charm Store uri in which the charm or bundle can be found
-        :param: model_name:     Model name
-        :param: wait:           Indicates whether to wait or not until all applications are active
-        :param: timeout:        Time in seconds to wait until all applications are active
+        :param uri:            Path or Charm Store uri in which the charm or bundle can be found
+        :param model_name:     Model name
+        :param wait:           Indicates whether to wait or not until all applications are active
+        :param timeout:        Time in seconds to wait until all applications are active
+        :param instantiation_params: To be applied as overlay bundle over primary bundle.
         """
         controller = await self.get_controller()
         model = await self.get_model(controller, model_name)
         """
         controller = await self.get_controller()
         model = await self.get_model(controller, model_name)
+        overlays = []
         try:
         try:
-            await model.deploy(uri, trust=True)
+            await self._validate_instantiation_params(uri, model, instantiation_params)
+            overlays = self._get_overlays(model_name, instantiation_params)
+            await model.deploy(uri, trust=True, overlays=overlays)
             if wait:
                 await JujuModelWatcher.wait_for_model(model, timeout=timeout)
                 self.log.debug("All units active in model {}".format(model_name))
         finally:
             if wait:
                 await JujuModelWatcher.wait_for_model(model, timeout=timeout)
                 self.log.debug("All units active in model {}".format(model_name))
         finally:
+            self._remove_overlay_file(overlays)
             await self.disconnect_model(model)
             await self.disconnect_controller(controller)
 
             await self.disconnect_model(model)
             await self.disconnect_controller(controller)
 
+    async def _validate_instantiation_params(
+        self, uri: str, model, instantiation_params: dict
+    ) -> None:
+        """Checks if all the applications in instantiation_params
+        exist ins the original bundle.
+
+        Raises:
+            JujuApplicationNotFound if there is an invalid app in
+            the instantiation params.
+        """
+        overlay_apps = self._get_apps_in_instantiation_params(instantiation_params)
+        if not overlay_apps:
+            return
+        original_apps = await self._get_apps_in_original_bundle(uri, model)
+        if not all(app in original_apps for app in overlay_apps):
+            raise JujuApplicationNotFound(
+                "Cannot find application {} in original bundle {}".format(
+                    overlay_apps, original_apps
+                )
+            )
+
+    async def _get_apps_in_original_bundle(self, uri: str, model) -> set:
+        """Bundle is downloaded in BundleHandler.fetch_plan.
+        That method takes care of opening and exception handling.
+
+        Resolve method gets all the information regarding the channel,
+        track, revision, type, source.
+
+        Returns:
+            Set with the names of the applications in original bundle.
+        """
+        url = URL.parse(uri)
+        architecture = DEFAULT_ARCHITECTURE  # only AMD64 is allowed
+        res = await model.deploy_types[str(url.schema)].resolve(
+            url, architecture, entity_url=uri
+        )
+        handler = BundleHandler(model, trusted=True, forced=False)
+        await handler.fetch_plan(url, res.origin)
+        return handler.applications
+
+    def _get_apps_in_instantiation_params(self, instantiation_params: dict) -> list:
+        """Extract applications key in instantiation params.
+
+        Returns:
+            List with the names of the applications in instantiation params.
+
+        Raises:
+            JujuError if applications key is not found.
+        """
+        if not instantiation_params:
+            return []
+        try:
+            return [key for key in instantiation_params.get("applications")]
+        except Exception as e:
+            raise JujuError("Invalid overlay format. {}".format(str(e)))
+
+    def _get_overlays(self, model_name: str, instantiation_params: dict) -> list:
+        """Creates a temporary overlay file which includes the instantiation params.
+        Only one overlay file is created.
+
+        Returns:
+            List with one overlay filename. Empty list if there are no instantiation params.
+        """
+        if not instantiation_params:
+            return []
+        file_name = model_name + "-overlay.yaml"
+        self._write_overlay_file(file_name, instantiation_params)
+        return [file_name]
+
+    def _write_overlay_file(self, file_name: str, instantiation_params: dict) -> None:
+        with open(file_name, "w") as file:
+            yaml.dump(instantiation_params, file)
+
+    def _remove_overlay_file(self, overlay: list) -> None:
+        """Overlay contains either one or zero file names."""
+        if not overlay:
+            return
+        try:
+            filename = overlay[0]
+            os.remove(filename)
+        except OSError as e:
+            self.log.warning(
+                "Overlay file {} could not be removed: {}".format(filename, e)
+            )
+
     async def add_unit(
         self,
         application_name: str,
     async def add_unit(
         self,
         application_name: str,
@@ -598,7 +702,6 @@ class Libjuju:
             application = self._get_application(model, application_name)
 
             if application is not None:
             application = self._get_application(model, application_name)
 
             if application is not None:
-
                 # Checks if the given machine id in the model,
                 # otherwise function raises an error
                 _machine, _series = self._get_machine_info(model, machine_id)
                 # Checks if the given machine id in the model,
                 # otherwise function raises an error
                 _machine, _series = self._get_machine_info(model, machine_id)
@@ -753,7 +856,6 @@ class Libjuju:
 
         try:
             if application_name not in model.applications:
 
         try:
             if application_name not in model.applications:
-
                 if machine_id is not None:
                     machine, series = self._get_machine_info(model, machine_id)
 
                 if machine_id is not None:
                     machine, series = self._get_machine_info(model, machine_id)
 
@@ -893,7 +995,6 @@ class Libjuju:
         return application
 
     async def resolve_application(self, model_name: str, application_name: str):
         return application
 
     async def resolve_application(self, model_name: str, application_name: str):
-
         controller = await self.get_controller()
         model = await self.get_model(controller, model_name)
 
         controller = await self.get_controller()
         model = await self.get_model(controller, model_name)
 
@@ -927,7 +1028,6 @@ class Libjuju:
             await self.disconnect_controller(controller)
 
     async def resolve(self, model_name: str):
             await self.disconnect_controller(controller)
 
     async def resolve(self, model_name: str):
-
         controller = await self.get_controller()
         model = await self.get_model(controller, model_name)
         all_units_active = False
         controller = await self.get_controller()
         model = await self.get_model(controller, model_name)
         all_units_active = False
@@ -1544,10 +1644,6 @@ class Libjuju:
                     await self.disconnect_model(model)
                 await self.disconnect_controller(controller)
 
                     await self.disconnect_model(model)
                 await self.disconnect_controller(controller)
 
-    def handle_exception(self, loop, context):
-        # All unhandled exceptions by libjuju are handled here.
-        pass
-
     async def health_check(self, interval: float = 300.0):
         """
         Health check to make sure controller and controller_model connections are OK
     async def health_check(self, interval: float = 300.0):
         """
         Health check to make sure controller and controller_model connections are OK
@@ -1752,7 +1848,9 @@ class Libjuju:
         finally:
             await self.disconnect_controller(controller)
 
         finally:
             await self.disconnect_controller(controller)
 
-    @retry(attempts=20, delay=5, fallback=JujuLeaderUnitNotFound())
+    @retry(
+        attempts=20, delay=5, fallback=JujuLeaderUnitNotFound(), callback=retry_callback
+    )
     async def _get_leader_unit(self, application: Application) -> Unit:
         unit = None
         for u in application.units:
     async def _get_leader_unit(self, application: Application) -> Unit:
         unit = None
         for u in application.units: