#######################################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import logging
from dataclasses import dataclass

from juju.application import Application
from juju.controller import Controller
from juju.model import Model
from n2vc.config import EnvironConfig
from osm_common.temporal.activities.paas import (
    TestVimConnectivity,
    CheckCharmStatus,
    CheckCharmIsRemoved,
    CheckModelIsRemoved,
    CreateModel,
    DeployCharm,
    RemoveCharm,
    RemoveModel,
    ResolveCharmErrors,
)
from osm_common.temporal.dataclasses_common import (
    CharmInfo,
    VduComputeConstraints,
)
from osm_lcm.data_utils.database.database import Database
from temporalio import activity


@dataclass
class ConnectionInfo:
    """Endpoint and credentials needed to reach a Juju controller.

    Both textual representations mask the password and the CA certificate so
    that instances can be logged without leaking secrets.
    """

    endpoint: str
    user: str
    password: str
    cacert: str

    def __repr__(self):
        hidden = "******"
        return (
            f"{type(self).__name__}(endpoint: {self.endpoint}, "
            f"user: {self.user}, password: {hidden}, cacert: {hidden})"
        )

    # str() must mask secrets exactly like repr().
    __str__ = __repr__


class JujuPaasConnector:
    """Client for the Juju controller that backs a PaaS VIM account.

    Args:
        db  (Database):       Data Access Object used to read VIM accounts
                              and decrypt their stored passwords
    """

    def __init__(self, db: Database):
        self.db: Database = db
        logger_name = f"lcm.act.{type(self).__name__}"
        self.logger = logging.getLogger(logger_name)
        self.config = EnvironConfig()

    def _decrypt_password(self, vim_content: dict) -> str:
        """Return the plain-text password of a VIM account record.

        Args:
            vim_content     (dict):     VIM account document; its "_id" is used
                                        as the decryption salt

        Returns:
            plain text password (str)
        """
        ciphertext = vim_content["vim_password"]
        return self.db.decrypt(
            ciphertext,
            schema_version=vim_content["schema_version"],
            salt=vim_content["_id"],
        )

    def _get_connection_info(self, vim_id: str) -> ConnectionInfo:
        """Load a VIM account and build the info needed to reach its controller.

        Args:
            vim_id  (str):      VIM ID

        Returns:
            ConnectionInfo  (object)
        """
        vim_content = self.db.get_one("vim_accounts", {"_id": vim_id})
        # Keyword arguments are evaluated in source order: the CA certificate
        # is read before the password is decrypted.
        return ConnectionInfo(
            endpoint=vim_content["vim_url"],
            user=vim_content["vim_user"],
            cacert=vim_content["config"]["ca_cert_content"],
            password=self._decrypt_password(vim_content),
        )

    async def _get_controller(self, vim_uuid) -> Controller:
        """Open a connection to the Juju controller of a VIM account.

        Args:
            vim_uuid  (str):    VIM ID

        Returns:
            Controller: a connected controller; closing it is left to the caller.
        """
        info = self._get_connection_info(vim_uuid)
        controller = Controller()
        await controller.connect(
            endpoint=info.endpoint,
            username=info.user,
            password=info.password,
            cacert=info.cacert,
        )
        return controller

    @staticmethod
    def _get_application_constraints(
        constraints: VduComputeConstraints, cloud: str
    ) -> dict:
        """Translate VDU compute constraints into Juju application constraints.

        Args:
            constraints (VduComputeConstraints): requested memory (GB) and cores
            cloud (str): name of the Juju cloud the application is deployed to

        Returns:
            dict: "mem" in MB when memory is set; "cores" when set and the
            cloud is not "microk8s"/"kubernetes".
        """
        juju_constraints = {}
        if constraints.mem:
            # Memory is provided in GB; the constraint is expressed in MB.
            juju_constraints["mem"] = constraints.mem * 1024
        # Kubernetes clouds do not support a cores constraint:
        # https://juju.is/docs/olm/constraint#heading--kubernetes
        if constraints.cores and cloud not in ("microk8s", "kubernetes"):
            juju_constraints["cores"] = constraints.cores
        return juju_constraints

    def _check_units_ready(
        self, application: Application, last_unit_status: dict
    ) -> bool:
        """Return True when every unit's workload status is active or blocked.

        Each status change is logged and recorded in ``last_unit_status``,
        which is mutated in place. Iteration stops at the first unit that is
        not ready, so later units are neither logged nor recorded.
        """
        for unit in application.units:
            workload = unit.workload_status
            if workload != last_unit_status.get(unit):
                last_unit_status[unit] = workload
                self.logger.debug(
                    f"Application `{application.name}` Unit `{unit}` is {workload}"
                )
            if workload not in ("active", "blocked"):
                return False
        return True


@activity.defn(name=TestVimConnectivity.__name__)
class TestVimConnectivityImpl(TestVimConnectivity):
    """Check that the Juju controller of a VIM account can be reached."""

    async def __call__(self, activity_input: TestVimConnectivity.Input) -> None:
        vim_id = activity_input.vim_uuid
        controller = await self.juju_controller._get_controller(vim_id)
        # The connection exists only to prove reachability; close it so each
        # run of this activity does not leave a websocket open.
        await controller.disconnect()
        message = f"Connection to juju controller succeeded for {vim_id}"
        self.logger.info(message)


@activity.defn(name=CreateModel.__name__)
class CreateModelImpl(CreateModel):
    """Create a Juju model on the VIM's controller unless it already exists."""

    async def __call__(self, activity_input: CreateModel.Input) -> None:
        model_name = activity_input.model_name
        controller = await self.juju_controller._get_controller(activity_input.vim_uuid)
        try:
            if model_name in await controller.list_models():
                self.logger.debug(f"Model {model_name} already created")
                return

            vim_content = self.db.get_one(
                "vim_accounts", {"_id": activity_input.vim_uuid}
            )
            vim_config = vim_content["config"]

            # Model config carrying the controller endpoints, credentials,
            # base64-encoded CA certificate and authorized keys.
            config = {
                "endpoints": ",".join(await controller.api_endpoints),
                "user": vim_content["vim_user"],
                "secret": self.juju_controller._decrypt_password(vim_content),
                "cacert": base64.b64encode(
                    vim_config["ca_cert_content"].encode("utf-8")
                ).decode("utf-8"),
                "authorized-keys": vim_config["authorized_keys"],
            }

            self.logger.debug(f"Creating model {model_name}")
            model = await controller.add_model(
                model_name,
                config=config,
                cloud_name=vim_config["cloud"],
                credential_name=vim_config["cloud_credentials"],
            )
            # add_model hands back a connected Model that is not used here;
            # close it instead of leaking its connection.
            await model.disconnect()
            self.logger.debug(f"Model {model_name} created")
        finally:
            # Release the controller connection on every exit path, including
            # the early "already created" return and errors.
            await controller.disconnect()


@activity.defn(name=DeployCharm.__name__)
class DeployCharmImpl(DeployCharm):
    """Deploy a charm as a new application in an existing Juju model."""

    async def __call__(self, activity_input: DeployCharm.Input) -> None:
        model_name = activity_input.model_name
        charm_info = activity_input.charm_info
        application_name = charm_info.app_name
        constraints = self.juju_controller._get_application_constraints(
            activity_input.constraints, activity_input.cloud
        )
        controller = await self.juju_controller._get_controller(activity_input.vim_uuid)
        try:
            model = await controller.get_model(model_name)
            try:
                if application_name in model.applications:
                    raise Exception(
                        "Application {} already exists".format(application_name)
                    )
                await model.deploy(
                    entity_url=charm_info.entity_url,
                    application_name=application_name,
                    channel=charm_info.channel,
                    # An empty constraints dict is passed as None.
                    constraints=constraints if constraints else None,
                    config=activity_input.config,
                )
            finally:
                # Close the model and controller connections whether the
                # deployment succeeded or failed.
                await model.disconnect()
        finally:
            await controller.disconnect()


@activity.defn(name=CheckCharmStatus.__name__)
class CheckCharmStatusImpl(CheckCharmStatus):
    """Wait until an application and all its units are active or blocked."""

    async def __call__(self, activity_input: CheckCharmStatus.Input) -> None:
        controller = await self.juju_controller._get_controller(activity_input.vim_uuid)
        try:
            model = await controller.get_model(activity_input.model_name)
            try:
                application = model.applications[activity_input.application_name]
                await self._wait_until_ready(application, activity_input)
            finally:
                # The model must stay connected while polling; close it (and
                # the controller below) once polling ends or fails.
                await model.disconnect()
        finally:
            await controller.disconnect()

    async def _wait_until_ready(
        self, application: Application, activity_input: CheckCharmStatus.Input
    ) -> None:
        """Poll ``application`` every ``poll_interval`` seconds until ready.

        Sends a Temporal heartbeat before each sleep and logs every change of
        the application status. Returns once the application status is active
        or blocked and every unit's workload status is too.
        """
        last_status = None
        last_unit_status = {}
        while True:
            activity.heartbeat()
            await asyncio.sleep(activity_input.poll_interval)
            # Fetch the status only once per iteration and keep it locally.
            application_status = application.status
            if application_status != last_status:
                last_status = application_status
                self.logger.debug(
                    f"Application `{activity_input.application_name}` is {application_status}"
                )
            if application_status not in ("active", "blocked"):
                continue
            # Units are only inspected once the application itself is ready.
            if self.juju_controller._check_units_ready(
                application=application, last_unit_status=last_unit_status
            ):
                return


class CharmInfoUtils:
    """Helpers that derive charm deployment data from VNF descriptor parts."""

    @staticmethod
    def get_charm_info(vdu: dict, sw_image_descs: list) -> CharmInfo:
        """Build the charm info of a VDU.

        Args:
            vdu (dict): VDU descriptor; its "id" becomes the application name
                and its "sw-image-desc" selects the charm image.
            sw_image_descs (list): software image descriptors of the VNF.

        Returns:
            CharmInfo  (object)
        """
        app_name = vdu.get("id")
        image_id = vdu.get("sw-image-desc")
        entity_url, channel = CharmInfoUtils._get_entity_url_and_channel(
            image_id, sw_image_descs
        )
        return CharmInfo(app_name, channel, entity_url)

    @staticmethod
    def _get_entity_url_and_channel(
        sw_image_desc_id: str, sw_image_descs: list
    ) -> tuple:
        """Look up the image and version of a software image descriptor.

        Args:
            sw_image_desc_id (str): ID of the image used by a VDU.
            sw_image_descs (list): software image descriptors of the VNF.

        Returns:
            tuple: ("image", "version") values of the first descriptor whose
            "id" matches, or (None, None) when none matches.
        """
        for image in sw_image_descs:
            if image.get("id") == sw_image_desc_id:
                return image.get("image"), image.get("version")
        return None, None


@activity.defn(name=RemoveCharm.__name__)
class RemoveCharmImpl(RemoveCharm):
    """Request removal of an application; a missing model or app is a no-op."""

    async def __call__(self, activity_input: RemoveCharm.Input) -> None:
        app_name = activity_input.application_name
        model_name = activity_input.model_name
        force_remove = activity_input.force_remove
        controller: Controller = await self.juju_controller._get_controller(
            activity_input.vim_uuid
        )
        try:
            if model_name not in (await controller.list_models()):
                return
            model = await controller.get_model(model_name)
            try:
                if app_name not in model.applications:
                    return
                # Removal is not awaited to completion here
                # (block_until_done=False).
                await model.remove_application(
                    app_name=app_name,
                    block_until_done=False,
                    force=force_remove,
                    no_wait=force_remove,
                    destroy_storage=True,
                )
            finally:
                await model.disconnect()
        finally:
            # Close the controller connection on every exit path.
            await controller.disconnect()


@activity.defn(name=CheckCharmIsRemoved.__name__)
class CheckCharmIsRemovedImpl(CheckCharmIsRemoved):
    """Wait until an application no longer appears in its model."""

    async def __call__(self, activity_input: CheckCharmIsRemoved.Input) -> None:
        controller = await self.juju_controller._get_controller(activity_input.vim_uuid)
        try:
            if activity_input.model_name not in (await controller.list_models()):
                return
            model = await controller.get_model(activity_input.model_name)
            try:
                app_name = activity_input.application_name
                # Poll the model's application list, heartbeating each round.
                while app_name in model.applications:
                    activity.heartbeat()
                    await asyncio.sleep(activity_input.poll_interval)
            finally:
                await model.disconnect()
        finally:
            # Close the controller connection on every exit path.
            await controller.disconnect()


@activity.defn(name=RemoveModel.__name__)
class RemoveModelImpl(RemoveModel):
    """Request destruction of a Juju model; a missing model is a no-op."""

    async def __call__(self, activity_input: RemoveModel.Input) -> None:
        model_name = activity_input.model_name
        controller: Controller = await self.juju_controller._get_controller(
            activity_input.vim_uuid
        )
        try:
            if model_name not in (await controller.list_models()):
                return
            await controller.destroy_models(
                model_name,
                destroy_storage=True,
                force=activity_input.force_remove,
            )
        finally:
            # Close the controller connection on every exit path.
            await controller.disconnect()


@activity.defn(name=CheckModelIsRemoved.__name__)
class CheckModelIsRemovedImpl(CheckModelIsRemoved):
    """Wait until a model is no longer listed by the controller."""

    async def __call__(self, activity_input: CheckModelIsRemoved.Input) -> None:
        model_name = activity_input.model_name
        controller: Controller = await self.juju_controller._get_controller(
            activity_input.vim_uuid
        )
        try:
            # Poll the controller's model list, heartbeating each round.
            while model_name in (await controller.list_models()):
                activity.heartbeat()
                await asyncio.sleep(activity_input.poll_interval)
        finally:
            # Close the controller connection even if polling is cancelled.
            await controller.disconnect()


@activity.defn(name=ResolveCharmErrors.__name__)
class ResolveCharmErrorsImpl(ResolveCharmErrors):
    """Repeatedly mark errored units resolved until the app leaves "error"."""

    async def __call__(self, activity_input: ResolveCharmErrors.Input) -> None:
        model_name: str = activity_input.model_name
        application_name: str = activity_input.application_name
        controller: Controller = await self.juju_controller._get_controller(
            activity_input.vim_uuid
        )
        try:
            if model_name not in (await controller.list_models()):
                return
            model: Model = await controller.get_model(model_name)
            try:
                if application_name not in model.applications:
                    return
                application: Application = model.applications[application_name]
                while not await ResolveCharmErrorsImpl.is_error_resolved(application):
                    await self.resolve_error(application)
                    activity.heartbeat()
                    await asyncio.sleep(activity_input.poll_interval)
            finally:
                await model.disconnect()
        finally:
            # Close the controller connection on every exit path.
            await controller.disconnect()

    @staticmethod
    async def is_error_resolved(application) -> bool:
        """Return True when the application status is anything but "error"."""
        return application.status != "error"

    async def resolve_error(self, application) -> None:
        """Mark every unit whose workload status is "error" as resolved.

        Uses ``retry=False``, so hooks are not re-run.
        """
        for unit in application.units:
            if unit.workload_status == "error":
                self.logger.debug(
                    f"Application `{application.entity_id}`, Unit `{unit.entity_id}` in error state, resolving"
                )
                await unit.resolved(retry=False)
