#########################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import timedelta
from unittest.mock import Mock

import asynctest
from osm_common.temporal.activities.paas import TestVimConnectivity
from osm_common.temporal.activities.vim import (
    DeleteVimRecord,
    UpdateVimOperationState,
    UpdateVimState,
)
from osm_common.temporal.states import VimOperationState, VimState
from osm_common.temporal_task_queues.task_queues_mappings import LCM_TASK_QUEUE
from parameterized import parameterized_class
from temporalio import activity
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker

from osm_lcm.temporal.vim_workflows import (
    VimCreateWorkflow,
    VimCreateWorkflowImpl,
    VimDeleteWorkflow,
    VimDeleteWorkflowImpl,
    VimUpdateWorkflow,
    VimUpdateWorkflowImpl,
)
from osm_lcm.tests.utils import validate_workflow_failure_error_type

# Upper bound for a single workflow task, so a stuck task fails fast.
TASK_TIMEOUT = timedelta(milliseconds=100)
# Upper bound for a whole workflow run, so a test can never hang forever.
EXECUTION_TIMEOUT = timedelta(seconds=5)


class TestException(Exception):
    """Exception raised on purpose by the failing mock activities."""

    # Despite its ``Test`` prefix this is not a test class: tell pytest not to
    # try collecting it (it would otherwise warn about the inherited __init__).
    __test__ = False


@activity.defn(name=TestVimConnectivity.__name__)
async def mock_test_vim_connectivity(
    test_connectivity_input: TestVimConnectivity.Input,
) -> None:
    """Stand-in for the connectivity check that always succeeds.

    Registered under the ``TestVimConnectivity`` activity name so a worker
    using it runs this instead of the real activity.
    """
    return None


@activity.defn(name=TestVimConnectivity.__name__)
async def mock_test_vim_connectivity_raises(
    test_connectivity_input: TestVimConnectivity.Input,
) -> None:
    """Stand-in for the connectivity check that always fails.

    Shares the ``TestVimConnectivity`` activity name with the succeeding mock;
    only one of the two is registered on a given worker.
    """
    failure = TestException("Test exception")
    raise failure


@activity.defn(name=DeleteVimRecord.__name__)
async def mock_delete_vim_record_raises(data: DeleteVimRecord.Input) -> None:
    """Stand-in for ``DeleteVimRecord`` that fails without recording anything."""
    failure = TestException("Test exception")
    raise failure


class TestVimWorkflowsBase(asynctest.TestCase):
    """Shared fixtures for the VIM workflow tests.

    Provides a time-skipping Temporal test environment, mock activities that
    record their input on ``Mock`` trackers (some of them raising afterwards),
    a worker factory, and helpers asserting the recorded state transitions.
    """

    task_queue_name = LCM_TASK_QUEUE
    vim_id = "some-vim-uuid"
    # NOTE: the spelling is kept as-is because the test classes reference it.
    worflow_id = vim_id
    vim_operation_input = VimCreateWorkflow.Input(vim_id, "op_id")

    @activity.defn(name=UpdateVimState.__name__)
    async def mock_update_vim_state(self, data: UpdateVimState.Input) -> None:
        """Record the requested VIM state update and succeed."""
        self.mock_update_vim_state_tracker(data)

    @activity.defn(name=UpdateVimState.__name__)
    async def mock_update_vim_state_raises(self, data: UpdateVimState.Input) -> None:
        """Record the requested VIM state update, then fail."""
        self.mock_update_vim_state_tracker(data)
        raise TestException("Test exception")

    @activity.defn(name=UpdateVimOperationState.__name__)
    async def mock_update_vim_operation_state(
        self, data: UpdateVimOperationState.Input
    ) -> None:
        """Record the requested VIM operation state update and succeed."""
        self.mock_update_vim_operation_state_tracker(data)

    @activity.defn(name=UpdateVimOperationState.__name__)
    async def mock_update_vim_operation_state_raises(
        self, data: UpdateVimOperationState.Input
    ) -> None:
        """Record the requested VIM operation state update, then fail."""
        self.mock_update_vim_operation_state_tracker(data)
        raise TestException("Test exception")

    @activity.defn(name=DeleteVimRecord.__name__)
    async def mock_delete_vim_record(self, data: DeleteVimRecord.Input) -> None:
        """Record the requested VIM record deletion and succeed."""
        self.mock_delete_vim_record_tracker(data)

    async def setUp(self):
        # Trackers capture every input handed to the mock activities above.
        self.mock_update_vim_state_tracker = Mock()
        self.mock_update_vim_operation_state_tracker = Mock()
        self.mock_delete_vim_record_tracker = Mock()
        self.env = await WorkflowEnvironment.start_time_skipping()
        self.client = self.env.client

    def get_worker(self, activities: list) -> Worker:
        """Build a worker serving every VIM workflow with the given activities."""
        vim_workflows = [
            VimCreateWorkflowImpl,
            VimUpdateWorkflowImpl,
            VimDeleteWorkflowImpl,
        ]
        return Worker(
            self.client,
            task_queue=self.task_queue_name,
            workflows=vim_workflows,
            activities=activities,
            debug_mode=True,
        )

    def _assert_recorded_states(self, tracker, read_state, expected_states):
        """Assert ``tracker`` saw calls for ``self.vim_id`` carrying, in order,
        the states (extracted with ``read_state``) listed in ``expected_states``."""
        recorded_calls = tracker.call_args_list
        self.assertTrue(recorded_calls)
        observed_states = []
        for recorded_call in recorded_calls:
            activity_input = recorded_call.args[0]
            self.assertEqual(activity_input.vim_uuid, self.vim_id)
            observed_states.append(read_state(activity_input))
        self.assertEqual(observed_states, expected_states)

    def check_vim_state_is_updated(self, expected_states):
        """Assert the VIM state went through ``expected_states``, in order."""
        self._assert_recorded_states(
            self.mock_update_vim_state_tracker,
            lambda data: data.operational_state,
            expected_states,
        )

    def check_vim_op_state_is_updated(self, expected_states):
        """Assert the VIM operation state went through ``expected_states``, in order."""
        self._assert_recorded_states(
            self.mock_update_vim_operation_state_tracker,
            lambda data: data.op_state,
            expected_states,
        )


@parameterized_class(
    [
        {"workflow_name": VimCreateWorkflow.__name__},
        {"workflow_name": VimUpdateWorkflow.__name__},
    ]
)
class TestVimWorkflow(TestVimWorkflowsBase):
    """Behaviour shared by the VIM create and update workflows.

    ``parameterized_class`` runs every test once per workflow name below.
    """

    workflow_name: str
    # Number of times a failing state-update activity is expected to be invoked.
    # NOTE(review): presumably equals the maximum attempts of the retry policy
    # used in osm_lcm.temporal.vim_workflows — keep in sync if that changes.
    expected_attempts = 3

    async def _run_workflow(
        self, activities, *, expected_error=None, with_timeouts=True
    ):
        """Execute ``self.workflow_name`` on a worker registering ``activities``.

        Args:
            activities: mock activities to register on the worker.
            expected_error: when given, the run must fail with this error type
                (checked via ``validate_workflow_failure_error_type``).
            with_timeouts: whether to apply TASK_TIMEOUT and EXECUTION_TIMEOUT.
        """
        async with self.env, self.get_worker(activities):
            if expected_error is None:
                await self._execute(with_timeouts)
            else:
                with validate_workflow_failure_error_type(self, expected_error):
                    await self._execute(with_timeouts)

    async def _execute(self, with_timeouts):
        """Start the workflow under test and wait for it to finish."""
        timeouts = (
            {"task_timeout": TASK_TIMEOUT, "execution_timeout": EXECUTION_TIMEOUT}
            if with_timeouts
            else {}
        )
        await self.client.execute_workflow(
            self.workflow_name,
            arg=self.vim_operation_input,
            id=self.worflow_id,
            task_queue=self.task_queue_name,
            **timeouts,
        )

    def _check_states(self, expected_vim_state, expected_vim_op_state):
        """Assert the recorded VIM state and VIM operation state sequences."""
        self.check_vim_state_is_updated(expected_vim_state)
        self.check_vim_op_state_is_updated(expected_vim_op_state)

    async def test_nominal_case_updates_vim_state_and_vim_op_state(self):
        await self._run_workflow(
            [
                mock_test_vim_connectivity,
                self.mock_update_vim_state,
                self.mock_update_vim_operation_state,
            ]
        )
        self._check_states([VimState.ENABLED], [VimOperationState.COMPLETED])

    async def test_fail_update_vim_state_activity__updates_vim_operation_state(self):
        await self._run_workflow(
            [
                mock_test_vim_connectivity,
                self.mock_update_vim_state_raises,
                self.mock_update_vim_operation_state,
            ],
            expected_error=TestException,
        )
        self._check_states(
            [VimState.ENABLED] * self.expected_attempts,
            [VimOperationState.COMPLETED],
        )

    async def test_fail_update_vim_op_state_activity__updates_vim_state(self):
        await self._run_workflow(
            [
                mock_test_vim_connectivity,
                self.mock_update_vim_state,
                self.mock_update_vim_operation_state_raises,
            ],
            expected_error=TestException,
        )
        self._check_states(
            [VimState.ENABLED],
            [VimOperationState.COMPLETED] * self.expected_attempts,
        )

    async def test_connectivity_activity_failure__updates_vim_state_and_vim_op_state_to_failure(
        self,
    ):
        await self._run_workflow(
            [
                mock_test_vim_connectivity_raises,
                self.mock_update_vim_state,
                self.mock_update_vim_operation_state,
            ],
            expected_error=TestException,
        )
        self._check_states([VimState.ERROR], [VimOperationState.FAILED])

    # NOTE(review): the three multi-failure tests below run without
    # TASK_TIMEOUT/EXECUTION_TIMEOUT, unlike the tests above. Confirm that the
    # accumulated retries fit in EXECUTION_TIMEOUT before adding them.
    async def test_connectivity_and_vim_state_update_activity_failures(self):
        await self._run_workflow(
            [
                mock_test_vim_connectivity_raises,
                self.mock_update_vim_state_raises,
                self.mock_update_vim_operation_state,
            ],
            expected_error=TestException,
            with_timeouts=False,
        )
        self._check_states(
            [VimState.ERROR] * self.expected_attempts,
            [VimOperationState.FAILED],
        )

    async def test_connectivity_and_vim_op_state_update_activity_failures(self):
        await self._run_workflow(
            [
                mock_test_vim_connectivity_raises,
                self.mock_update_vim_state,
                self.mock_update_vim_operation_state_raises,
            ],
            expected_error=TestException,
            with_timeouts=False,
        )
        self._check_states(
            [VimState.ERROR],
            [VimOperationState.FAILED] * self.expected_attempts,
        )

    async def test_connectivity_vim_state_update_and_vim_op_state_update_failures(self):
        await self._run_workflow(
            [
                mock_test_vim_connectivity_raises,
                self.mock_update_vim_state_raises,
                self.mock_update_vim_operation_state_raises,
            ],
            expected_error=TestException,
            with_timeouts=False,
        )
        self._check_states(
            [VimState.ERROR] * self.expected_attempts,
            [VimOperationState.FAILED] * self.expected_attempts,
        )


class TestVimDeleteWorkflow(TestVimWorkflowsBase):
    """Tests for the VIM delete workflow."""

    async def test_vim_delete_nominal_case(self):
        activities = [self.mock_delete_vim_record]
        async with self.env, self.get_worker(activities):
            result = await self.client.execute_workflow(
                VimDeleteWorkflow.__name__,
                arg=self.vim_operation_input,
                id=self.worflow_id,
                task_queue=self.task_queue_name,
                task_timeout=TASK_TIMEOUT,
                execution_timeout=EXECUTION_TIMEOUT,
            )
            self.assertIsNone(result)
        self.mock_delete_vim_record_tracker.assert_called_once()

    async def test_vim_delete_exception(self):
        activities = [mock_delete_vim_record_raises]
        async with self.env, self.get_worker(activities):
            with validate_workflow_failure_error_type(self, TestException):
                # The run is expected to fail, so there is no result to inspect;
                # any statement after this call inside the block would not run.
                await self.client.execute_workflow(
                    VimDeleteWorkflow.__name__,
                    arg=self.vim_operation_input,
                    id=self.worflow_id,
                    task_queue=self.task_queue_name,
                    task_timeout=TASK_TIMEOUT,
                    execution_timeout=EXECUTION_TIMEOUT,
                )
