#######################################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from datetime import timedelta
import traceback

from osm_common.temporal.activities.paas import (
    CreateModel,
    RemoveModel,
    CheckModelIsRemoved,
)
from osm_common.temporal.activities.ns import (
    DeleteNsRecord,
    GetVnfDetails,
    GetNsRecord,
    UpdateNsState,
)
from osm_common.temporal.activities.vnf import (
    GetModelNames,
)
from osm_common.temporal.workflows.lcm import LcmOperationWorkflow
from osm_common.temporal.workflows.ns import (
    NsInstantiateWorkflow,
    NsTerminateWorkflow,
    NsDeleteRecordsWorkflow,
)
from osm_common.temporal.workflows.vnf import (
    VnfInstantiateWorkflow,
    VnfTerminateWorkflow,
    VnfDeleteWorkflow,
)
from osm_common.temporal_task_queues.task_queues_mappings import LCM_TASK_QUEUE
from osm_common.temporal.states import NsState
from temporalio import workflow
from temporalio.common import RetryPolicy
from temporalio.converter import value_to_type
from temporalio.exceptions import (
    ActivityError,
    ApplicationError,
    ChildWorkflowError,
    FailureError,
)

# Module-wide workflow settings shared by the workflow implementations below.
_SANDBOXED = False  # value handed to workflow.defn(sandboxed=...)
retry_policy = RetryPolicy(maximum_attempts=3)  # at most three attempts
no_retry_policy = RetryPolicy(maximum_attempts=1)  # a single attempt, no retries
default_schedule_to_close_timeout = timedelta(seconds=10 * 60)  # ten minutes


@workflow.defn(name=NsInstantiateWorkflow.__name__, sandboxed=_SANDBOXED)
class NsInstantiateWorkflowImpl(LcmOperationWorkflow):
    """Temporal implementation of the NS instantiate operation.

    Creates the NS model on the VIM, fetches the VNF details and the NS
    record, starts one ``VnfInstantiateWorkflow`` child per VNF record, and
    finally records the outcome through the ``UpdateNsState`` activity.
    """

    @workflow.run
    async def wrap_nslcmop(self, workflow_input: NsInstantiateWorkflow.Input) -> None:
        # Temporal entry point: delegates to the base-class wrapper, which is
        # presumably what invokes ``run`` — TODO confirm in LcmOperationWorkflow.
        await super().wrap_nslcmop(workflow_input=workflow_input)

    async def run(self, workflow_input: NsInstantiateWorkflow.Input) -> None:
        """Instantiate the NS referenced by ``workflow_input.nslcmop``.

        On failure the error details are logged, written to the NS record via
        ``update_ns_state``, and the original exception is re-raised.

        Args:
            workflow_input: Carries the NS LCM operation; ``nsInstanceId`` and
                ``operationParams["vimAccountId"]`` are read from it.

        Raises:
            Exception: Whatever activity, child-workflow or other error made
                the instantiation fail (re-raised unchanged).
        """
        self.logger.info(
            f"Executing {NsInstantiateWorkflow.__name__} with {workflow_input}"
        )

        # TODO: Can we clean up the input? Perhaps this workflow could receive NsInstantiateInput directly.
        ns_uuid = workflow_input.nslcmop["nsInstanceId"]
        vim_uuid = workflow_input.nslcmop["operationParams"]["vimAccountId"]
        model_name = self._get_namespace(ns_uuid, vim_uuid)
        try:
            await workflow.execute_activity(
                activity=CreateModel.__name__,
                arg=CreateModel.Input(vim_uuid=vim_uuid, model_name=model_name),
                activity_id=f"{CreateModel.__name__}-{ns_uuid}",
                schedule_to_close_timeout=NsInstantiateWorkflow.default_schedule_to_close_timeout,
                retry_policy=NsInstantiateWorkflow.no_retry_policy,
            )
            # VNF details and the NS record are independent; fetch them concurrently.
            activities_results = await asyncio.gather(
                workflow.execute_activity(
                    activity=GetVnfDetails.__name__,
                    arg=GetVnfDetails.Input(ns_uuid=ns_uuid),
                    activity_id=f"{GetVnfDetails.__name__}-{ns_uuid}",
                    schedule_to_close_timeout=NsInstantiateWorkflow.default_schedule_to_close_timeout,
                    retry_policy=NsInstantiateWorkflow.no_retry_policy,
                ),
                workflow.execute_activity(
                    activity=GetNsRecord.__name__,
                    arg=GetNsRecord.Input(nsr_uuid=ns_uuid),
                    activity_id=f"{GetNsRecord.__name__}-{ns_uuid}",
                    schedule_to_close_timeout=NsInstantiateWorkflow.default_schedule_to_close_timeout,
                    retry_policy=NsInstantiateWorkflow.no_retry_policy,
                ),
            )
            get_vnf_details = value_to_type(GetVnfDetails.Output, activities_results[0])
            get_ns_record = value_to_type(GetNsRecord.Output, activities_results[1])

            await asyncio.gather(
                *(
                    workflow.execute_child_workflow(
                        workflow=VnfInstantiateWorkflow.__name__,
                        arg=VnfInstantiateWorkflow.Input(
                            vnfr_uuid=vnfr_uuid,
                            model_name=model_name,
                            instantiation_config=NsInstantiateWorkflowImpl.get_vnf_config(
                                vnf_member_index_ref, get_ns_record.nsr
                            ),
                        ),
                        id=f"{VnfInstantiateWorkflow.__name__}-{vnfr_uuid}",
                    )
                    for vnfr_uuid, vnf_member_index_ref in get_vnf_details.vnf_details
                )
            )

        except Exception as e:
            # Must run inside the except clause: it may use traceback.format_exc().
            err_details = NsInstantiateWorkflowImpl._failure_details(e)
            # Log first so the details are recorded even if the state update fails.
            self.logger.error(
                f"{NsInstantiateWorkflow.__name__} failed with {err_details}"
            )
            # NOTE(review): the state is INSTANTIATED on failure too; only the
            # message carries the error — confirm this is intended.
            await self.update_ns_state(ns_uuid, NsState.INSTANTIATED, err_details)
            raise

        await self.update_ns_state(ns_uuid, NsState.INSTANTIATED, "Done")

    @staticmethod
    def _failure_details(error: Exception) -> str:
        """Describe a failed run for the NS record and the logs.

        Activity and child-workflow errors carry the useful message in their
        ``cause``. ``cause`` may be ``None``, and for any other exception the
        current traceback is used instead. Must be called while ``error`` is
        being handled.
        """
        if isinstance(error, (ActivityError, ChildWorkflowError)):
            cause = error.cause
            if cause is not None:
                return str(cause)
        return str(traceback.format_exc())

    @staticmethod
    async def update_ns_state(
        ns_uuid: str,
        state: NsState,
        message: str,
    ) -> None:
        """Store ``state`` and ``message`` for NS ``ns_uuid`` via UpdateNsState.

        Uses ``NsInstantiateWorkflow.retry_policy``, so the update is retried
        according to that policy.
        """
        activity_input = UpdateNsState.Input(ns_uuid, state, message)
        await workflow.execute_activity(
            activity=UpdateNsState.__name__,
            arg=activity_input,
            activity_id=f"{UpdateNsState.__name__}-{ns_uuid}",
            schedule_to_close_timeout=NsInstantiateWorkflow.default_schedule_to_close_timeout,
            retry_policy=NsInstantiateWorkflow.retry_policy,
        )

    def _get_namespace(self, ns_id: str, vim_id: str) -> str:
        """The NS namespace is the combination of the NS ID and the VIM ID.

        It is built from the last 12 characters of each ID joined by ``-``.
        """
        return ns_id[-12:] + "-" + vim_id[-12:]

    @staticmethod
    def get_vnf_config(vnf_member_index_ref: str, nsr: dict) -> dict:
        """Get the VNF instantiation config

        Args:
            vnf_member_index_ref    (str):     VNF member-index-ref
            nsr                     (dict):     NS record

        Returns:
            vnf_config  (dict)  The first entry of
                ``nsr["instantiate_params"]["vnf"]`` whose ``member-vnf-index``
                matches, or ``{}`` when there is none (including when either
                key is missing or set to ``None``).

        """
        instantiate_params = nsr.get("instantiate_params") or {}
        for vnf_config in instantiate_params.get("vnf") or []:
            if vnf_config.get("member-vnf-index") == vnf_member_index_ref:
                return vnf_config
        return {}


@workflow.defn(name=NsTerminateWorkflow.__name__, sandboxed=_SANDBOXED)
class NsTerminateWorkflowImpl(NsTerminateWorkflow):
    """Temporal implementation of the NS terminate operation.

    Terminates every VNF of the NS through ``VnfTerminateWorkflow`` children.
    It then removes the NS models from the VIM. If the removal cannot be
    confirmed, the models are removed again with ``force_remove=True``.
    """

    @workflow.run
    async def run(self, workflow_input: NsTerminateWorkflow.Input) -> None:
        """Terminate the NS identified by ``workflow_input.ns_uuid``.

        Raises:
            FailureError: Any activity or child-workflow failure, logged and
                re-raised unchanged.
        """
        ns_uuid = workflow_input.ns_uuid
        vim_uuid = workflow_input.vim_uuid
        try:
            vnfs_details = value_to_type(
                GetVnfDetails.Output,
                await workflow.execute_activity(
                    activity=GetVnfDetails.__name__,
                    arg=GetVnfDetails.Input(ns_uuid=ns_uuid),
                    activity_id=f"{GetVnfDetails.__name__}-{ns_uuid}",
                    schedule_to_close_timeout=default_schedule_to_close_timeout,
                    retry_policy=retry_policy,
                ),
            )
            await asyncio.gather(
                *(
                    workflow.execute_child_workflow(
                        workflow=VnfTerminateWorkflow.__name__,
                        arg=VnfTerminateWorkflow.Input(
                            vnfr_uuid=vnfr_uuid,
                        ),
                        id=f"{VnfTerminateWorkflow.__name__}-{vnfr_uuid}",
                    )
                    for vnfr_uuid, _ in vnfs_details.vnf_details
                )
            )
            models_names = value_to_type(
                GetModelNames.Output,
                await workflow.execute_activity(
                    activity=GetModelNames.__name__,
                    arg=GetModelNames.Input(ns_uuid=ns_uuid),
                    activity_id=f"{GetModelNames.__name__}-{ns_uuid}",
                    schedule_to_close_timeout=default_schedule_to_close_timeout,
                    retry_policy=retry_policy,
                ),
            )
            model_names = models_names.model_names

            await NsTerminateWorkflowImpl._remove_models(
                vim_uuid, model_names, force_remove=False
            )
            try:
                await NsTerminateWorkflowImpl._check_models_are_removed(
                    vim_uuid, model_names
                )
            except ActivityError:
                self.logger.info("Removing models forcefully.")
                await NsTerminateWorkflowImpl._remove_models(
                    vim_uuid, model_names, force_remove=True
                )
                try:
                    await NsTerminateWorkflowImpl._check_models_are_removed(
                        vim_uuid, model_names
                    )
                except ActivityError:
                    self.logger.error("Some models could not be removed.")
                    raise

        except FailureError as e:
            # ``cause`` may be None; fall back to the error itself in that case.
            cause = getattr(e, "cause", None)
            err_details = str(e if cause is None else cause)
            self.logger.error(
                f"{NsTerminateWorkflow.__name__} failed with {err_details}"
            )
            raise

    @staticmethod
    async def _remove_models(vim_uuid, model_names, force_remove) -> None:
        """Run ``remove_model`` concurrently for every name in ``model_names``."""
        await asyncio.gather(
            *(
                NsTerminateWorkflowImpl.remove_model(
                    vim_uuid=vim_uuid,
                    model_name=model_name,
                    force_remove=force_remove,
                )
                for model_name in model_names
            )
        )

    @staticmethod
    async def _check_models_are_removed(vim_uuid, model_names) -> None:
        """Run ``check_model_is_removed`` for every model and wait for all of them.

        Every check is allowed to finish before an error is reported. This
        avoids a problem with a plain ``asyncio.gather``, which raises on the
        first failure while sibling checks are still pending. A later round
        would then schedule ``CheckModelIsRemoved`` again with the same
        activity IDs. NOTE(review): this changes command ordering for failing
        checks; in-flight executions may need ``workflow.patched`` — confirm
        the deployment strategy.

        Raises:
            BaseException: The first failure that is not an ``ActivityError``,
                if any (e.g. cancellation).
            ActivityError: Otherwise, the first failed check, if any.
        """
        results = await asyncio.gather(
            *(
                NsTerminateWorkflowImpl.check_model_is_removed(
                    vim_uuid=vim_uuid,
                    model_name=model_name,
                )
                for model_name in model_names
            ),
            return_exceptions=True,
        )
        failures = [result for result in results if isinstance(result, BaseException)]
        for failure in failures:
            if not isinstance(failure, ActivityError):
                raise failure
        if failures:
            raise failures[0]

    @staticmethod
    async def remove_model(vim_uuid, model_name, force_remove):
        """Execute the RemoveModel activity for one model on ``LCM_TASK_QUEUE``."""
        await workflow.execute_activity(
            activity=RemoveModel.__name__,
            arg=RemoveModel.Input(
                vim_uuid=vim_uuid,
                model_name=model_name,
                force_remove=force_remove,
            ),
            activity_id=f"{RemoveModel.__name__}-{vim_uuid}-{model_name}",
            task_queue=LCM_TASK_QUEUE,
            schedule_to_close_timeout=default_schedule_to_close_timeout,
            retry_policy=retry_policy,
        )

    @staticmethod
    async def check_model_is_removed(vim_uuid, model_name):
        """Execute the CheckModelIsRemoved activity for one model on ``LCM_TASK_QUEUE``.

        Each attempt is bounded to 5 minutes and must heartbeat every 30 seconds.
        """
        await workflow.execute_activity(
            activity=CheckModelIsRemoved.__name__,
            arg=CheckModelIsRemoved.Input(
                vim_uuid=vim_uuid,
                model_name=model_name,
            ),
            activity_id=f"{CheckModelIsRemoved.__name__}-{vim_uuid}-{model_name}",
            task_queue=LCM_TASK_QUEUE,
            start_to_close_timeout=timedelta(minutes=5),
            heartbeat_timeout=timedelta(seconds=30),
            retry_policy=retry_policy,
        )


@workflow.defn(name=NsDeleteRecordsWorkflow.__name__, sandboxed=_SANDBOXED)
class NsDeleteRecordsWorkflowImpl(NsDeleteRecordsWorkflow):
    """Temporal implementation that deletes the records of a non-instantiated NS."""

    @workflow.run
    async def run(self, workflow_input: NsDeleteRecordsWorkflow.Input) -> None:
        """Delete the VNF records and then the NS record of ``workflow_input.ns_uuid``.

        The NS record's ``nsState`` must be ``NOT_INSTANTIATED``. Every VNF
        record is deleted through a ``VnfDeleteWorkflow`` child before the NS
        record itself is removed.

        Raises:
            ApplicationError: Non-retryable, when the NS is not in the
                NOT_INSTANTIATED state.
            FailureError: Any activity or child-workflow failure, logged and
                re-raised unchanged.
        """
        ns_uuid = workflow_input.ns_uuid
        try:
            nsr = value_to_type(
                GetNsRecord.Output,
                await workflow.execute_activity(
                    activity=GetNsRecord.__name__,
                    arg=GetNsRecord.Input(ns_uuid),
                    activity_id=f"{GetNsRecord.__name__}-{ns_uuid}",
                    task_queue=LCM_TASK_QUEUE,
                    schedule_to_close_timeout=default_schedule_to_close_timeout,
                    retry_policy=retry_policy,
                ),
            ).nsr
            instantiation_state = nsr.get("nsState")
            if instantiation_state != NsState.NOT_INSTANTIATED.name:
                raise ApplicationError(
                    f"NS must be in {NsState.NOT_INSTANTIATED.name} state",
                    non_retryable=True,
                )
            vnf_details = value_to_type(
                GetVnfDetails.Output,
                await workflow.execute_activity(
                    activity=GetVnfDetails.__name__,
                    arg=GetVnfDetails.Input(ns_uuid=ns_uuid),
                    activity_id=f"{GetVnfDetails.__name__}-{ns_uuid}",
                    schedule_to_close_timeout=default_schedule_to_close_timeout,
                    retry_policy=retry_policy,
                ),
            )

            await asyncio.gather(
                *(
                    workflow.execute_child_workflow(
                        workflow=VnfDeleteWorkflow.__name__,
                        arg=VnfDeleteWorkflow.Input(
                            vnfr_uuid=vnfr_uuid,
                        ),
                        id=f"{VnfDeleteWorkflow.__name__}-{vnfr_uuid}",
                    )
                    for vnfr_uuid, _ in vnf_details.vnf_details
                )
            )
            await workflow.execute_activity(
                activity=DeleteNsRecord.__name__,
                arg=DeleteNsRecord.Input(ns_uuid),
                activity_id=f"{DeleteNsRecord.__name__}-{ns_uuid}",
                task_queue=LCM_TASK_QUEUE,
                schedule_to_close_timeout=default_schedule_to_close_timeout,
                retry_policy=retry_policy,
            )

        except FailureError as e:
            # Report the underlying cause when there is one. Unlike the old
            # ``with_traceback`` call, this does not overwrite the cause's traceback.
            cause = getattr(e, "cause", None)
            err_details = str(e if cause is None else cause)
            self.logger.error(
                f"{NsDeleteRecordsWorkflow.__name__} failed with {err_details}"
            )
            raise
