# Copyright 2021 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
#     Unless required by applicable law or agreed to in writing, software
#     distributed under the License is distributed on an "AS IS" BASIS,
#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#     See the License for the specific language governing permissions and
#     limitations under the License.

import abc
import asyncio
from base64 import b64decode
import re
import typing

from Crypto.Cipher import AES
from motor.motor_asyncio import AsyncIOMotorClient
from n2vc.config import EnvironConfig
from n2vc.vca.connection_data import ConnectionData
from osm_common.dbmongo import DbMongo, DbException

# Name of the MongoDB database opened by MotorStore (through AsyncIOMotorClient)
DB_NAME: str = "osm"


class Store(abc.ABC):
    """Abstract interface for reading and writing VCA data from the database."""

    @abc.abstractmethod
    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """

    @abc.abstractmethod
    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not given, the endpoints of the default VCA
                        are updated.
        """

    @abc.abstractmethod
    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not given, the endpoints of the default VCA
                        are returned.

        :returns: List of endpoints
        """

    @abc.abstractmethod
    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA id for a VIM account

        :param: vim_id: Vim account ID

        :returns: VCA ID associated with the VIM account, or None if there is
                  no such association
        """


class DbMongoStore(Store):
    """
    Store implementation backed by an ``osm_common`` DbMongo object.

    The methods are ``async`` to satisfy the Store interface, but every call
    made on ``self.db`` is synchronous (none of them is awaited).
    """

    def __init__(self, db: DbMongo):
        """
        Constructor

        :param: db: osm_common.dbmongo.DbMongo object
        """
        self.db = db

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        The "secret" and "cacert" fields are decrypted in place before the
        ConnectionData object is built.

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database
        """
        data = self.db.get_one("vca", q_filter={"_id": vca_id})
        self.db.encrypt_decrypt_fields(
            data,
            "decrypt",
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            # The record ID is used as the decryption salt
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not given, the "api_endpoints" of the
                        default VCA (document "juju" of the "admin"
                        collection) are updated, creating that document if
                        it does not exist yet.
        """
        if vca_id:
            data = self.db.get_one("vca", q_filter={"_id": vca_id})
            data["endpoints"] = endpoints
            self._update("vca", vca_id, data)
            return

        # The default VCA. Data for the endpoints is in a different place:
        # the "admin" collection, which is also where _get_juju_info reads it.
        if not self._get_juju_info():
            try:
                self.db.create(
                    "admin",
                    {"_id": "juju"},
                )
            except DbException:
                # Race condition: another N2VC worker may have created the
                # document between the check above and the create call.
                if not self._get_juju_info():
                    raise
        self.db.set_one(
            "admin",
            {"_id": "juju"},
            {"api_endpoints": endpoints},
        )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not given, the endpoints of the default VCA
                        are returned.

        :returns: List of endpoints (empty if the default VCA has none stored)
        """
        if vca_id:
            # get_vca_connection_data is a coroutine: it must be awaited before
            # its result can be used.
            connection_data = await self.get_vca_connection_data(vca_id)
            return connection_data.endpoints
        juju_info = self._get_juju_info()
        if juju_info and "api_endpoints" in juju_info:
            return juju_info["api_endpoints"]
        return []

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: The "vca" field of the VIM account, or None if vim_id is
                  empty, the account does not exist, or it has no "vca" field
        """
        if not vim_id:
            return None
        vim_account = self.db.get_one(
            "vim_accounts",
            q_filter={"_id": vim_id},
            fail_on_empty=False,
        )
        # With fail_on_empty=False a missing account yields None, not a dict
        return vim_account.get("vca") if vim_account else None

    def _update(self, collection: str, item_id: str, data: dict):
        """
        Update object in database

        :param: collection: Collection name
        :param: item_id: ID of the object
        :param: data: Object data (replaces the stored object)
        """
        self.db.replace(
            collection,
            item_id,
            data,
        )

    def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return self.db.get_one(
            "admin",
            q_filter={"_id": "juju"},
            fail_on_empty=False,
        )


class MotorStore(Store):
    """
    Store implementation that accesses the database through Motor
    (AsyncIOMotorClient); every database call is awaited.
    """

    def __init__(self, uri: str, loop=None):
        """
        Constructor

        :param: uri: Connection string to connect to the database.
        :param: loop: Asyncio Loop
        """
        self._client = AsyncIOMotorClient(uri)
        self.loop = loop or asyncio.get_event_loop()
        # Populated lazily by get_secret_key()
        self._secret_key = None
        self._config = EnvironConfig(prefixes=["OSMLCM_", "OSMMON_"])

    @property
    def _database(self):
        return self._client[DB_NAME]

    @property
    def _vca_collection(self):
        return self._database["vca"]

    @property
    def _admin_collection(self):
        return self._database["admin"]

    @property
    def _vim_accounts_collection(self):
        return self._database["vim_accounts"]

    async def get_vca_connection_data(self, vca_id: str) -> ConnectionData:
        """
        Get VCA connection data

        The "secret" and "cacert" fields are decrypted in place before the
        ConnectionData object is built.

        :param: vca_id: VCA ID

        :returns: ConnectionData with the information of the database

        :raises: Exception if no VCA with that ID exists
        """
        data = await self._vca_collection.find_one({"_id": vca_id})
        if not data:
            raise Exception("vca with id {} not found".format(vca_id))
        await self.decrypt_fields(
            data,
            ["secret", "cacert"],
            schema_version=data["schema_version"],
            salt=data["_id"],
        )
        return ConnectionData(**data)

    async def update_vca_endpoints(
        self, endpoints: typing.List[str], vca_id: str = None
    ):
        """
        Update VCA endpoints

        :param: endpoints: List of endpoints to write in the database
        :param: vca_id: VCA ID. If not given, the "api_endpoints" of the
                        default VCA (document "juju" of the "admin"
                        collection) are updated, creating that document if
                        it does not exist yet.

        :raises: Exception if vca_id is given and no such VCA exists
        """
        if vca_id:
            data = await self._vca_collection.find_one({"_id": vca_id})
            if not data:
                # Without this check the subscript below fails with an opaque
                # TypeError on None.
                raise Exception("vca with id {} not found".format(vca_id))
            data["endpoints"] = endpoints
            await self._vca_collection.replace_one({"_id": vca_id}, data)
            return

        # The default VCA. Data for the endpoints is in a different place
        if not await self._get_juju_info():
            try:
                await self._admin_collection.insert_one({"_id": "juju"})
            except Exception:
                # Race condition: another N2VC worker may have created the
                # document between the check above and the insert.
                if not await self._get_juju_info():
                    raise

        # $set only touches "api_endpoints"; a replace_one would drop any
        # other field stored in the document.
        await self._admin_collection.update_one(
            {"_id": "juju"}, {"$set": {"api_endpoints": endpoints}}
        )

    async def get_vca_endpoints(self, vca_id: str = None) -> typing.List[str]:
        """
        Get list of VCA endpoints

        :param: vca_id: VCA ID. If not given, the endpoints of the default VCA
                        are returned.

        :returns: List of endpoints (empty if the default VCA has none stored)
        """
        if vca_id:
            return (await self.get_vca_connection_data(vca_id)).endpoints
        juju_info = await self._get_juju_info()
        if juju_info and "api_endpoints" in juju_info:
            return juju_info["api_endpoints"]
        return []

    async def get_vca_id(self, vim_id: str = None) -> str:
        """
        Get VCA ID from the database for a given VIM account ID

        :param: vim_id: VIM account ID

        :returns: The "vca" field of the VIM account, or None if vim_id is
                  empty, the account does not exist, or it has no "vca" field
        """
        if not vim_id:
            return None
        vim_account = await self._vim_accounts_collection.find_one({"_id": vim_id})
        if vim_account and "vca" in vim_account:
            return vim_account["vca"]
        return None

    async def _get_juju_info(self):
        """Get Juju information (the default VCA) from the admin collection"""
        return await self._admin_collection.find_one({"_id": "juju"})

    # DECRYPT METHODS
    async def decrypt_fields(
        self,
        item: dict,
        fields: typing.List[str],
        schema_version: str = None,
        salt: str = None,
    ):
        """
        Decrypt fields

        Decrypt fields from a dictionary. Follows the same logic as in osm_common.
        The item is walked recursively (through nested dicts and lists) and
        every string value whose key matches one of the ``fields`` regexes
        (case-insensitive) is replaced in place by its decrypted value.

        :param: item: Dictionary with the keys to be decrypted
        :param: fields: List of keys to decrypt
        :param: schema version: Schema version. (i.e. 1.11)
        :param: salt: Salt for the decryption
        """
        flags = re.I

        async def process(_item):
            if isinstance(_item, list):
                for elem in _item:
                    await process(elem)
            elif isinstance(_item, dict):
                for key, val in _item.items():
                    if isinstance(val, str):
                        if any(re.search(f, key, flags) for f in fields):
                            _item[key] = await self.decrypt(val, schema_version, salt)
                    else:
                        await process(val)

        await process(item)

    async def decrypt(self, value, schema_version=None, salt=None):
        """
        Decrypt an encrypted value
        :param value: value to be decrypted. It is a base64 string
        :param schema_version: used for known encryption method used. If None or '1.0' no encryption has been done.
               If '1.1' symmetric AES encryption has been done
        :param salt: optional salt to be used
        :return: Plain content of value
        :raises: DbException if the decrypted bytes are not valid UTF-8
        """
        await self.get_secret_key()
        if not self.secret_key or not schema_version or schema_version == "1.0":
            return value
        secret_key = self._join_secret_key(salt)
        encrypted_msg = b64decode(value)
        # The mode is passed explicitly: pycryptodome has no default mode,
        # and ECB is what legacy PyCrypto used when none was given.
        cipher = AES.new(secret_key, AES.MODE_ECB)
        decrypted_msg = cipher.decrypt(encrypted_msg)
        try:
            # The plaintext is right-padded with NUL characters
            unpadded_private_msg = decrypted_msg.decode().rstrip("\0")
        except UnicodeDecodeError:
            raise DbException(
                "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
                http_code=500,
            )
        return unpadded_private_msg

    def _join_secret_key(self, update_key: typing.Any) -> bytes:
        """
        Join key with secret key

        :param: update_key: str or bytes with the to update

        :return: Joined key
        """
        return self._join_keys(update_key, self.secret_key)

    def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
        """
        Join key with secret_key

        The bytes of ``key`` are XOR-ed cyclically into a 32-byte copy of
        ``secret_key`` (or into 32 zero bytes when secret_key is empty).

        :param: key: str or bytes of the key to update
        :param: secret_key: bytes of the secret key

        :return: Joined key
        """
        if isinstance(key, str):
            update_key_bytes = key.encode()
        else:
            update_key_bytes = key
        new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
        for i, b in enumerate(update_key_bytes):
            new_secret_key[i % 32] ^= b
        return bytes(new_secret_key)

    @property
    def secret_key(self):
        return self._secret_key

    async def get_secret_key(self):
        """
        Get secret key using the database key and the serial key in the DB
        The key is populated in the property self.secret_key
        """
        if self.secret_key:
            return
        secret_key = None
        if self.database_key:
            secret_key = self._join_keys(self.database_key, None)
        version_data = await self._admin_collection.find_one({"_id": "version"})
        if version_data and version_data.get("serial"):
            secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
        self._secret_key = secret_key

    @property
    def database_key(self):
        """Value of "database_commonkey" in the EnvironConfig, or None if unset."""
        try:
            return self._config["database_commonkey"]
        except KeyError:
            # An absent common key is tolerated by get_secret_key
            return None
