#######################################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asynctest
from copy import deepcopy
from unittest import TestCase
from unittest.mock import Mock, patch

from osm_common.dbbase import DbException
from osm_common.temporal_task_queues.task_queues_mappings import LCM_TASK_QUEUE
from osm_lcm.temporal.vnf_activities import (
    DeleteVnfRecordImpl,
    ChangeVnfInstantiationStateImpl,
    ChangeVnfStateImpl,
    GetModelNamesImpl,
    GetTaskQueueImpl,
    GetVimCloudImpl,
    GetVnfDescriptorImpl,
    GetVnfRecordImpl,
    SetVnfModelImpl,
    VnfSpecifications,
)
from osm_common.temporal.dataclasses_common import VduComputeConstraints
from osm_common.temporal.states import VnfState, VnfInstantiationState
from temporalio.testing import ActivityEnvironment


# Shared identifiers and sample database documents reused by the test cases
# in this module.
vnfr_uuid = "9f472177-95c0-4335-b357-5cdc17a79965"
vnfd_uuid = "97784f19-d254-4252-946c-cf92d85443ca"
vim_uuid = "a64f7c6c-bc27-4ec8-b664-5500a3324eca"
model_name = "my-model-name"
# Input shared by the SetVnfModelImpl success and failure tests.
set_vnf_model_input = SetVnfModelImpl.Input(vnfr_uuid=vnfr_uuid, model_name=model_name)
cloud = "microk8s"
nsr_uuid = "dcf4c922-5a73-41bf-a6ca-870c22d6336c"
# Sample VIM account document. TestGetVimCloud expects the "cloud" entry under
# "config" to be returned by GetVimCloudImpl, and TestGetTaskQueue varies
# "vim_type" to select (or fail to select) a task queue.
sample_vim_record = {
    "_id": vim_uuid,
    "name": "juju",
    "vim_type": "paas",
    "vim_url": "192.168.1.100:17070",
    "vim_user": "admin",
    "vim_password": "c16gylWEepEREN6vWw==",
    "config": {
        "paas_provider": "juju",
        "cloud": cloud,
        "cloud_credentials": "microk8s",
        "authorized_keys": "$HOME/.local/share/juju/ssh/juju_id_rsa.pub",
        "ca_cert_content": "-----BEGIN-----",
    },
}
vim_account_id = "9b0bedc3-ea8e-42fd-acc9-dd03f4dee73c"
vdu_id = "hackfest_basic-VM"
vnf_index = "vnf-profile-id"
# Sample VNF record; "namespace" is the field SetVnfModelImpl writes the model
# name into (see the set_one assertion in TestVnfDbActivity).
sample_vnfr = {
    "_id": vnfr_uuid,
    "id": vnfr_uuid,
    "nsr-id-ref": nsr_uuid,
    "vnfd-ref": "jar_vnfd_scalable",
    "vnfd-id": "f1b38eac-190c-485d-9a74-c6e169c929d8",
    "vim-account-id": vim_account_id,
    "namespace": model_name,
    "member-vnf-index-ref": vnf_index,
}
# Per-VNF instantiation parameters: configurable properties for each VDU,
# matched by VDU "id".
vnf_config = {
    "member-vnf-index": vnf_index,
    "vdu": [
        {
            "id": vdu_id,
            "configurable-properties": {
                "config::redirect-map": "https://osm.instantiation.params"
            },
        }
    ],
}
# Application config the tests expect from `config` below: only the
# "config::"-prefixed entries, with that prefix stripped.
app_config = {"domain_name1": "osm.org", "domain_name2": "osm.com"}
# VDU configurable properties in descriptor form (a list of key/value pairs).
config = [
    {"key": "config::domain_name1", "value": "osm.org"},
    {"key": "config::domain_name2", "value": "osm.com"},
    {"key": "track", "value": "latest"},
    {"key": "channel", "value": "stable"},
]
vdu = {
    "id": vdu_id,
    "name": "hackfest_basic-VM",
    "sw-image-desc": "ubuntu18.04",
    "virtual-compute-desc": "hackfest_basic-VM-compute",
    "virtual-storage-desc": ["hackfest_basic-VM-storage"],
    "configurable-properties": config,
}
# NOTE(review): same value as vnfd_uuid above; one of the two is redundant.
vnfd_id = "97784f19-d254-4252-946c-cf92d85443ca"
# Compute descriptor with both CPU and memory information.
flavor_1 = {
    "id": "compute-id-1",
    "virtual-memory": {"size": "4"},
    "virtual-cpu": {"cpu-architecture": "x86", "num-virtual-cpu": 2},
}
# Compute descriptor with CPU information only.
flavor_2 = {
    "id": "compute-id-2",
    "virtual-cpu": {"cpu-architecture": "x86", "num-virtual-cpu": 2},
}
# Compute descriptor with memory information only.
# NOTE(review): reuses flavor_2's id "compute-id-2". This is harmless while the
# tests patch _get_flavor_details. A distinct id (e.g. "compute-id-3") would
# avoid ambiguity if this flavor is ever added to sample_vnfd.
flavor_3 = {
    "id": "compute-id-2",
    "virtual-memory": {"size": "4"},
}
# Sample VNF descriptor; only flavor_1 and flavor_2 are listed as compute
# descriptors.
sample_vnfd = {
    "_id": vnfd_id,
    "id": "sol006-vnf",
    "provider": "Canonical",
    "product-name": "test-vnf",
    "software-version": "1.0",
    "vdu": [vdu],
    "virtual-compute-desc": [flavor_1, flavor_2],
}


class TestException(Exception):
    """Generic exception type available for tests in this module to raise."""


class TestVnfDbActivity(asynctest.TestCase):
    """Tests for the activities that update fields of a VNF record.

    Each activity runs inside a temporal ``ActivityEnvironment`` against a
    mocked ``db``. The success tests check the resulting ``set_one`` call, and
    the failure tests check that a ``DbException`` raised by the database
    propagates out of the activity.
    """

    def setUp(self):
        self.db = Mock()
        self.env = ActivityEnvironment()
        # All activities under test share the same mocked database.
        self.change_vnf_state = ChangeVnfStateImpl(self.db)
        self.change_vnf_instantiation_state = ChangeVnfInstantiationStateImpl(self.db)
        self.set_vnf_model = SetVnfModelImpl(self.db)

    async def test_change_vnf_state__successful__db_updated_as_expected(self):
        new_state = VnfState.STOPPED
        await self.env.run(
            self.change_vnf_state, ChangeVnfStateImpl.Input(vnfr_uuid, new_state)
        )
        expected_update = {"vnfState": new_state.name}
        self.db.set_one.assert_called_with(
            "vnfrs", {"_id": vnfr_uuid}, expected_update
        )

    async def test_change_vnf_state__failed__raises_db_exception(self):
        self.db.set_one.side_effect = DbException("not found")
        activity_input = ChangeVnfStateImpl.Input(vnfr_uuid, VnfState.STARTED)
        with self.assertRaises(DbException):
            await self.env.run(self.change_vnf_state, activity_input)

    async def test_change_vnf_instantiation_state__successful__db_updated_as_expected(
        self,
    ):
        new_state = VnfInstantiationState.NOT_INSTANTIATED
        activity_input = ChangeVnfInstantiationStateImpl.Input(vnfr_uuid, new_state)
        await self.env.run(self.change_vnf_instantiation_state, activity_input)
        expected_update = {"instantiationState": new_state.name}
        self.db.set_one.assert_called_with(
            "vnfrs", {"_id": vnfr_uuid}, expected_update
        )

    async def test_change_vnf_instantiation_state__failed__raises_db_exception(self):
        self.db.set_one.side_effect = DbException("not found")
        activity_input = ChangeVnfInstantiationStateImpl.Input(
            vnfr_uuid, VnfInstantiationState.INSTANTIATED
        )
        with self.assertRaises(DbException):
            await self.env.run(self.change_vnf_instantiation_state, activity_input)

    async def test_set_vnf_model__successful__db_updated_as_expected(self):
        await self.env.run(self.set_vnf_model, set_vnf_model_input)
        expected_update = {"namespace": model_name}
        self.db.set_one.assert_called_with(
            "vnfrs", {"_id": vnfr_uuid}, expected_update
        )

    async def test_set_vnf_model__failed__raises_db_exception(self):
        self.db.set_one.side_effect = DbException("not found")
        with self.assertRaises(DbException):
            await self.env.run(self.set_vnf_model, set_vnf_model_input)


class TestGetTaskQueue(asynctest.TestCase):
    """Tests for ``GetTaskQueueImpl``, which picks the task queue for a VNF."""

    async def setUp(self):
        self.db = Mock()
        self.get_task_queue = GetTaskQueueImpl(self.db)
        self.env = ActivityEnvironment()
        self.get_task_queue_input = GetTaskQueueImpl.Input(vnfr_uuid=sample_vnfr["id"])

    async def _run_activity(self):
        """Run the activity under test with the shared input and return its output."""
        return await self.env.run(self.get_task_queue, self.get_task_queue_input)

    async def test_activity__succeeded__get_expected_task_queue(self):
        # Successive get_one calls return the VNF record, then the VIM record
        # (presumably the order in which the activity looks them up).
        self.db.get_one.side_effect = [sample_vnfr, sample_vim_record]
        result = await self._run_activity()
        self.assertEqual(result.task_queue, LCM_TASK_QUEUE)

    async def test_activity__failed__raises_db_exception(self):
        self.db.get_one.side_effect = DbException("not found")
        with self.assertRaises(DbException):
            await self._run_activity()

    async def test_activity__invalid_task_queue__raises_key_error(self):
        unknown_vim = deepcopy(sample_vim_record)
        unknown_vim["vim_type"] = "some-vim-type"
        self.db.get_one.side_effect = [sample_vnfr, unknown_vim]
        with self.assertRaises(KeyError):
            await self._run_activity()


class TestGetVnfDescriptor(asynctest.TestCase):
    """Tests for ``GetVnfDescriptorImpl``, which fetches a VNFD from the database."""

    async def setUp(self):
        self.db = Mock()
        self.get_vnf_descriptor = GetVnfDescriptorImpl(self.db)
        self.env = ActivityEnvironment()

    async def _run_activity(self):
        """Fetch the descriptor identified by the module-level ``vnfd_uuid``."""
        activity_input = GetVnfDescriptorImpl.Input(vnfd_uuid=vnfd_uuid)
        return await self.env.run(self.get_vnf_descriptor, activity_input)

    async def test_activity__succeeded__get_expected_vnfd(self):
        self.db.get_one.return_value = sample_vnfd
        result = await self._run_activity()
        self.assertEqual(result.vnfd, sample_vnfd)

    async def test_activity__failed__raises_db_exception(self):
        self.db.get_one.side_effect = DbException("Can not connect to Database.")
        with self.assertRaises(DbException):
            await self._run_activity()


class TestGetVnfRecord(asynctest.TestCase):
    """Tests for ``GetVnfRecordImpl``, which fetches a VNF record from the database."""

    async def setUp(self):
        self.db = Mock()
        self.get_vnf_record = GetVnfRecordImpl(self.db)
        self.env = ActivityEnvironment()

    async def _run_activity(self):
        """Fetch the record identified by the module-level ``vnfr_uuid``."""
        activity_input = GetVnfRecordImpl.Input(vnfr_uuid=vnfr_uuid)
        return await self.env.run(self.get_vnf_record, activity_input)

    async def test_activity__succeeded__get_expected_vnfr(self):
        self.db.get_one.return_value = sample_vnfr
        result = await self._run_activity()
        self.assertEqual(result.vnfr, sample_vnfr)

    async def test_activity__failed__raise_db_exception(self):
        self.db.get_one.side_effect = DbException("Can not connect to Database.")
        with self.assertRaises(DbException):
            await self._run_activity()


class TestDeleteVnfRecord(asynctest.TestCase):
    """Tests for ``DeleteVnfRecordImpl``, which removes a VNF record."""

    async def setUp(self):
        self.db = Mock()
        self.delete_vnf_record = DeleteVnfRecordImpl(self.db)
        self.env = ActivityEnvironment()

    async def _run_activity(self):
        """Delete the record identified by the module-level ``vnfr_uuid``."""
        activity_input = DeleteVnfRecordImpl.Input(vnfr_uuid=vnfr_uuid)
        return await self.env.run(self.delete_vnf_record, activity_input)

    async def test_activity__succeeded__expected_record_is_deleted(self):
        self.db.del_one.return_value = None
        await self._run_activity()
        # fail_on_empty=False presumably makes the delete tolerate a record
        # that no longer exists; confirm against osm_common.dbbase.
        self.db.del_one.assert_called_with(
            "vnfrs", {"_id": vnfr_uuid}, fail_on_empty=False
        )

    async def test_activity__failed__raise_db_exception(self):
        self.db.del_one.side_effect = DbException("Can not connect to Database.")
        with self.assertRaises(DbException):
            await self._run_activity()


class TestGetVimCloud(asynctest.TestCase):
    """Tests for ``GetVimCloudImpl``, which reads the cloud name of a VNF's VIM."""

    async def setUp(self):
        self.db = Mock()
        self.get_vim_cloud = GetVimCloudImpl(self.db)
        self.env = ActivityEnvironment()

    async def _run_activity(self):
        """Run the activity for the sample VNF record and return its output."""
        activity_input = GetVimCloudImpl.Input(vnfr_uuid=sample_vnfr["id"])
        return await self.env.run(self.get_vim_cloud, activity_input)

    async def test_activity__succeeded__get_vim_cloud(self):
        self.db.get_one.side_effect = [sample_vnfr, sample_vim_record]
        result = await self._run_activity()
        self.assertEqual(result.cloud, cloud)

    async def test_activity__vim_record_without_cloud__no_cloud_info(self):
        vim_without_cloud = deepcopy(sample_vim_record)
        del vim_without_cloud["config"]["cloud"]
        self.db.get_one.side_effect = [sample_vnfr, vim_without_cloud]
        result = await self._run_activity()
        # A missing "cloud" entry is reported as an empty string.
        self.assertEqual(result.cloud, "")

    async def test_activity__failed__raise_db_exception(self):
        self.db.get_one.side_effect = DbException("Can not connect to Database.")
        with self.assertRaises(DbException):
            await self._run_activity()

    async def test_activity__wrong_vim_record__raise_key_error(self):
        vim_without_config = deepcopy(sample_vim_record)
        del vim_without_config["config"]
        self.db.get_one.side_effect = [sample_vnfr, vim_without_config]
        with self.assertRaises(KeyError):
            await self._run_activity()


class TestGetVduInstantiateInfoMethods(TestCase):
    """Synchronous unit tests for the helpers of ``VnfSpecifications``.

    Collaborating helpers are patched where a test only targets how a public
    method combines their results.
    """

    def test_get_flavor_details__successful__get_flavor(self):
        compute_desc_id = "compute-id-1"
        result = VnfSpecifications._get_flavor_details(compute_desc_id, sample_vnfd)
        self.assertEqual(result, flavor_1)

    def test_get_flavor_details__empty_compute_desc__flavor_is_none(self):
        compute_desc_id = ""
        result = VnfSpecifications._get_flavor_details(compute_desc_id, sample_vnfd)
        self.assertIsNone(result)

    def test_get_flavor_details__compute_desc_not_found__flavor_is_none(self):
        compute_desc_id = "compute-id-5"
        result = VnfSpecifications._get_flavor_details(compute_desc_id, sample_vnfd)
        self.assertIsNone(result)

    def test_get_flavor_details__empty_vnfd__flavor_is_none(self):
        compute_desc_id = "compute-id-5"
        result = VnfSpecifications._get_flavor_details(compute_desc_id, {})
        self.assertIsNone(result)

    def test_get_flavor_details__wrong_vnfd_format__flavor_is_none(self):
        compute_desc_id = "compute-id-2"
        # The only compute descriptor has no "id" key; the test expects that
        # to yield no match. A distinct local name avoids shadowing the
        # module-level sample_vnfd fixture.
        malformed_vnfd = {
            "_id": vnfd_id,
            "vdu": [vdu],
            "virtual-compute-desc": [
                {
                    "virtual-memory": {"size": "4"},
                }
            ],
        }
        result = VnfSpecifications._get_flavor_details(compute_desc_id, malformed_vnfd)
        self.assertIsNone(result)

    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._get_flavor_details")
    def test_get_compute_constraints__succeeded__get_constraints(
        self, mock_get_flavor_details
    ):
        mock_get_flavor_details.return_value = flavor_1
        result = VnfSpecifications.get_compute_constraints(vdu, sample_vnfd)
        self.assertEqual(VduComputeConstraints(cores=2, mem=4), result)

    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._get_flavor_details")
    def test_get_compute_constraints__empty_flavor_details__no_constraints(
        self, mock_get_flavor_details
    ):
        mock_get_flavor_details.return_value = {}
        result = VnfSpecifications.get_compute_constraints(vdu, sample_vnfd)
        self.assertEqual(VduComputeConstraints(cores=0, mem=0), result)

    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._get_flavor_details")
    def test_get_compute_constraints__flavor_details_is_none__no_constraints(
        self, mock_get_flavor_details
    ):
        mock_get_flavor_details.return_value = None
        result = VnfSpecifications.get_compute_constraints(vdu, sample_vnfd)
        self.assertEqual(VduComputeConstraints(cores=0, mem=0), result)

    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._get_flavor_details")
    def test_get_compute_constraints__flavor_has_only_cpu__get_cpu_constraint(
        self, mock_get_flavor_details
    ):
        mock_get_flavor_details.return_value = flavor_2
        result = VnfSpecifications.get_compute_constraints(vdu, sample_vnfd)
        self.assertEqual(VduComputeConstraints(cores=2, mem=0), result)

    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._get_flavor_details")
    def test_get_compute_constraints__flavor_has_only_memory__get_memory_constraint(
        self, mock_get_flavor_details
    ):
        mock_get_flavor_details.return_value = flavor_3
        result = VnfSpecifications.get_compute_constraints(vdu, sample_vnfd)
        self.assertEqual(VduComputeConstraints(cores=0, mem=4), result)

    def test_list_to_dict__config_has_several_items__get_expected_dict(self):
        # Keys are kept verbatim, whether or not they carry the "config::" prefix.
        key_value_items = [
            {"key": "config::domain_name1", "value": "osm.org"},
            {"key": "domain_name2", "value": "osm.com"},
        ]
        result = VnfSpecifications._list_to_dict(key_value_items)
        self.assertEqual(
            result, {"config::domain_name1": "osm.org", "domain_name2": "osm.com"}
        )

    def test_list_to_dict__empty_input__get_empty_dict(self):
        result = VnfSpecifications._list_to_dict([])
        self.assertEqual(result, {})

    def test_get_only_config_items__with_identifier__get_expected_config(self):
        # "config::"-prefixed keys are kept, with the prefix stripped.
        properties = {"config::redirect-map": "https://osm.instantiation.params"}
        result = VnfSpecifications._get_only_config_items(properties)
        self.assertEqual(result, {"redirect-map": "https://osm.instantiation.params"})

    def test_get_only_config_items__without_identifier__no_config(self):
        properties = {"key": "domain_name1", "value": "osm.org"}
        result = VnfSpecifications._get_only_config_items(properties)
        self.assertEqual(result, {})

    def test_get_only_config_items__empty_input__no_config(self):
        result = VnfSpecifications._get_only_config_items({})
        self.assertEqual(result, {})

    # Renamed from "...__empty_vnf_config__get_params": this case passes the
    # populated vnf_config fixture; the empty-config case is the next test.
    def test_get_vdu_instantiation_params__vdu_config_present__get_params(self):
        result = VnfSpecifications.get_vdu_instantiation_params(vdu_id, vnf_config)
        self.assertEqual(
            result, {"config::redirect-map": "https://osm.instantiation.params"}
        )

    def test_get_vdu_instantiation_params__empty_vnf_config__no_params(self):
        result = VnfSpecifications.get_vdu_instantiation_params(vdu_id, {})
        self.assertEqual(result, {})

    def test_get_vdu_instantiation_params__vdu_id_mismatch__no_params(self):
        other_vdu_config = deepcopy(vnf_config)
        other_vdu_config["vdu"][0]["id"] = "other-vdu-id"
        result = VnfSpecifications.get_vdu_instantiation_params(
            vdu_id, other_vdu_config
        )
        self.assertEqual(result, {})

    def test_get_vdu_instantiation_params__empty_configurable_properties__no_params(
        self,
    ):
        no_properties_config = deepcopy(vnf_config)
        no_properties_config["vdu"][0]["configurable-properties"] = {}
        result = VnfSpecifications.get_vdu_instantiation_params(
            vdu_id, no_properties_config
        )
        self.assertEqual(result, {})

    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._get_only_config_items")
    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._list_to_dict")
    def test_get_application_config__instantiate_config_overrides_vdu_config__get_expected_application_config(
        self,
        mock_list_to_dict,
        mock_get_only_config_items,
    ):
        vdu_instantiate_config = {"config::domain_name1": "osm.public"}
        # The first result stands in for the descriptor config and the second
        # for the instantiation config. The expected output shows the second
        # one overriding the first.
        mock_get_only_config_items.side_effect = [
            {
                "domain_name1": "osm.org",
                "domain_name2": "osm.com",
            },
            {"domain_name1": "osm.public"},
        ]
        result = VnfSpecifications.get_application_config(vdu, vdu_instantiate_config)
        self.assertEqual(
            result, {"domain_name1": "osm.public", "domain_name2": "osm.com"}
        )

    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._get_only_config_items")
    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._list_to_dict")
    def test_get_application_config__empty_instantiate_config__get_descriptor_config(
        self,
        mock_list_to_dict,
        mock_get_only_config_items,
    ):
        vdu_instantiate_config = {}
        mock_get_only_config_items.side_effect = [
            {
                "domain_name1": "osm.org",
                "domain_name2": "osm.com",
            },
            {},
        ]
        result = VnfSpecifications.get_application_config(vdu, vdu_instantiate_config)
        self.assertEqual(result, app_config)

    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._get_only_config_items")
    @patch("osm_lcm.temporal.vnf_activities.VnfSpecifications._list_to_dict")
    def test_get_application_config__no_config__empty_application_config(
        self,
        mock_list_to_dict,
        mock_get_only_config_items,
    ):
        vdu_instantiate_config = {}
        vdu_copy = deepcopy(vdu)
        vdu_copy["configurable-properties"] = []
        mock_get_only_config_items.side_effect = [{}, {}]
        result = VnfSpecifications.get_application_config(
            vdu_copy, vdu_instantiate_config
        )
        self.assertEqual(result, {})


class TestGetModelNames(asynctest.TestCase):
    """Tests for ``GetModelNamesImpl``, which collects the model names of an NS."""

    async def setUp(self):
        self.db = Mock()
        self.get_model_names = GetModelNamesImpl(self.db)
        self.env = ActivityEnvironment()

    async def _run_activity(self):
        """Run the activity for the module-level ``nsr_uuid``."""
        activity_input = GetModelNamesImpl.Input(ns_uuid=nsr_uuid)
        return await self.env.run(self.get_model_names, activity_input)

    async def test_activity__success(self):
        # Three VNF records, two sharing a namespace: each name should appear once.
        namespaces = ["namespace1", "namespace1", "namespace2"]
        self.db.get_list.return_value = [
            {"nsr-id-ref": nsr_uuid, "namespace": ns} for ns in namespaces
        ]

        result = await self._run_activity()

        self.assertEqual(result.model_names, {"namespace1", "namespace2"})

    async def test_activity__raise_db_exception(self):
        self.db.get_list.side_effect = DbException("Can not connect to Database.")

        with self.assertRaises(DbException):
            await self._run_activity()


# Allow running this test module directly, in addition to via a test runner.
if __name__ == "__main__":
    asynctest.main()
