#######################################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asynctest

from osm_lcm.temporal.ns_activities import (
    GetNsRecordImpl,
    GetVnfDetailsImpl,
    DeleteNsRecordImpl,
)
from osm_common.dbbase import DbException
from temporalio.testing import ActivityEnvironment
from unittest.mock import Mock

# Placeholder UUID reused for every identifier in these fixtures.
_ZERO_UUID = "00000000-0000-0000-0000-000000000000"

ns_uuid = _ZERO_UUID
get_vnf_details_input = GetVnfDetailsImpl.Input(ns_uuid=ns_uuid)

# VNF records handed back by the mocked ``db.get_list``. Both share the
# placeholder id and differ only in their member index.
sample_vnf_details = [
    {"id": _ZERO_UUID, "member-vnf-index-ref": member_index}
    for member_index in ("vnf1", "vnf2")
]

# NS record handed back by the mocked ``db.get_one``; its state fields describe
# an idle, not-yet-instantiated NS. The three name fields carry the same value.
sample_nsr = {
    "_id": ns_uuid,
    **dict.fromkeys(("name", "name-ref", "short-name"), "sol006_juju24"),
    "admin-status": "ENABLED",
    "nsState": "NOT_INSTANTIATED",
    "currentOperation": "IDLE",
}


class TestException(Exception):
    """Arbitrary, non-DB exception used to check that errors propagate.

    The name starts with ``Test``, so pytest would try to collect it as a test
    class and warn because it has an ``__init__``. ``__test__ = False`` tells
    the collector to skip it.
    """

    __test__ = False


class TestGetVnfDetails(asynctest.TestCase):
    """Unit tests for the GetVnfDetailsImpl activity, backed by a mocked DB."""

    def setUp(self):
        self.db = Mock()
        self.env = ActivityEnvironment()
        self.get_vnf_details_impl = GetVnfDetailsImpl(self.db)

    async def test_activity__succeeded__get_expected_result(self):
        # Each VNF record returned by the DB is expected to be reduced to an
        # (id, member-vnf-index-ref) tuple, in the same order as the records.
        self.db.get_list.return_value = sample_vnf_details
        result = await self.env.run(self.get_vnf_details_impl, get_vnf_details_input)

        self.assertEqual(
            result.vnf_details,
            [
                ("00000000-0000-0000-0000-000000000000", "vnf1"),
                ("00000000-0000-0000-0000-000000000000", "vnf2"),
            ],
        )

    async def test_activity__failed__raise_db_exception(self):
        # A DB failure must surface to the caller as a DbException.
        self.db.get_list.side_effect = DbException("not found.")
        with self.assertRaises(DbException):
            await self.env.run(self.get_vnf_details_impl, get_vnf_details_input)


class TestGetNsRecord(asynctest.TestCase):
    """Unit tests for the GetNsRecordImpl activity, backed by a mocked DB."""

    async def setUp(self):
        self.env = ActivityEnvironment()
        self.db = Mock()
        self.get_ns_record_impl = GetNsRecordImpl(self.db)

    async def _run_activity(self):
        """Run the activity under test against the sample NS record's id."""
        request = GetNsRecordImpl.Input(nsr_uuid=sample_nsr["_id"])
        return await self.env.run(self.get_ns_record_impl, request)

    async def test_activity__succeeded__get_expected_result(self):
        # Whatever the DB hands back should be exposed unchanged as `nsr`.
        self.db.get_one.return_value = sample_nsr
        outcome = await self._run_activity()
        self.assertEqual(outcome.nsr, sample_nsr)

    async def test_activity__failed__raise_test_exception(self):
        # A non-DB exception raised by the DB layer must also propagate.
        self.db.get_one.side_effect = TestException("Can not connect to Database.")
        with self.assertRaises(TestException):
            await self._run_activity()


class TestDeleteNsRecord(asynctest.TestCase):
    """Unit tests for the DeleteNsRecordImpl activity, backed by a mocked DB."""

    async def setUp(self):
        self.env = ActivityEnvironment()
        self.db = Mock()
        self.delete_ns_record = DeleteNsRecordImpl(self.db)

    async def _delete(self):
        """Execute the delete activity for the module-level ``ns_uuid``."""
        await self.env.run(
            self.delete_ns_record, DeleteNsRecordImpl.Input(ns_uuid=ns_uuid)
        )

    async def test_activity__succeeded__expected_record_is_deleted(self):
        self.db.del_one.return_value = None
        await self._delete()
        # NOTE(review): fail_on_empty=False presumably lets the delete tolerate
        # an already-missing record; confirm against osm_common's dbbase.
        self.db.del_one.assert_called_with(
            "nsrs", {"_id": ns_uuid}, fail_on_empty=False
        )

    async def test_activity__failed__raise_db_exception(self):
        self.db.del_one.side_effect = DbException("Can not connect to Database.")
        with self.assertRaises(DbException):
            await self._delete()
