.pydevproject
deb_dist
*.tar.gz
+
+src/osm_common/_version.py
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-########################################################################################
-# This Dockerfile is intented for devops testing and deb package generation
-#
-# To run stage 2 locally:
-#
-# docker build -t stage2 .
-# docker run -ti -v `pwd`:/work -w /work --entrypoint /bin/bash stage2
-# devops-stages/stage-test.sh
-# devops-stages/stage-build.sh
-#
+#######################################################################################
+FROM 31z4/tox
-FROM ubuntu:22.04
+USER root
-ARG APT_PROXY
-RUN if [ ! -z $APT_PROXY ] ; then \
- echo "Acquire::http::Proxy \"$APT_PROXY\";" > /etc/apt/apt.conf.d/proxy.conf ;\
- echo "Acquire::https::Proxy \"$APT_PROXY\";" >> /etc/apt/apt.conf.d/proxy.conf ;\
- fi
+RUN set -eux; \
+ apt-get update; \
+ apt-get install -y --no-install-recommends patch wget libmagic-dev; \
+ rm -rf /var/lib/apt/lists/*
-RUN DEBIAN_FRONTEND=noninteractive apt-get update && \
- DEBIAN_FRONTEND=noninteractive apt-get -y install \
- debhelper \
- dh-python \
- git \
- python3 \
- python3-all \
- python3-dev \
- python3-setuptools \
- python3-pip \
- tox
+COPY entrypoint.sh /entrypoint.sh
+RUN chmod +x /entrypoint.sh
-ENV LC_ALL C.UTF-8
-ENV LANG C.UTF-8
+ENTRYPOINT ["/entrypoint.sh"]
+CMD ["devops-stages/stage-test.sh"]
# contact: bdiaz@whitestack.com or glavado@whitestack.com
##
+graft src
+graft tests
+#graft docs
+
include README.rst
-recursive-include osm_common *.py *.xml *.sh *.txt
-recursive-include devops-stages *
\ No newline at end of file
+include tox.ini
+recursive-include devops-stages *
+
# See the License for the specific language governing permissions and
# limitations under the License.
-MDG=common
-rm -rf pool
-rm -rf dists
-mkdir -p pool/$MDG
-mv deb_dist/*.deb pool/$MDG/
-
+echo "Nothing to be done"
# See the License for the specific language governing permissions and
# limitations under the License.
-rm -rf dist deb_dist osm_common-*.tar.gz osm_common.egg-info .eggs .tox
-
-tox -e dist
+echo "Nothing to be done"
--- /dev/null
+#!/usr/bin/env bash
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#######################################################################################
+
+set -euo pipefail
+
+UID=${UID:-10000}
+GID=${GID:-10000}
+
+# ensure group exists with given GID
+if ! getent group "${GID}" > /dev/null; then
+ groupmod -o -g "${GID}" tox 2>/dev/null \
+ || groupadd -o -g "${GID}" tox
+fi
+
+# ensure tox user has the given UID
+usermod -o -u "${UID}" -g "${GID}" tox \
+ || useradd -o -m -u "${UID}" -g "${GID}" tox
+
+exec su tox "$@"
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-
-version = "7.0.0.post4"
-date_version = "2019-01-21"
-
-# try to get version from installed package. Skip if fails
-try:
- from pkg_resources import get_distribution
-
- version = get_distribution("osm_common").version
-
-except Exception as init_error:
- logging.exception(f"{init_error} occured while getting the common version")
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
-
-
-class FakeLock:
- """Implements a fake lock that can be called with the "with" statement or acquire, release methods"""
-
- def __enter__(self):
- pass
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- pass
-
- def acquire(self):
- pass
-
- def release(self):
- pass
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from base64 import b64decode, b64encode
-from copy import deepcopy
-from http import HTTPStatus
-import logging
-import re
-from threading import Lock
-import typing
-
-
-from Crypto.Cipher import AES
-from motor.motor_asyncio import AsyncIOMotorClient
-from osm_common.common_utils import FakeLock
-import yaml
-
-__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
-
-
-DB_NAME = "osm"
-
-
-class DbException(Exception):
- def __init__(self, message, http_code=HTTPStatus.NOT_FOUND):
- self.http_code = http_code
- Exception.__init__(self, "database exception " + str(message))
-
-
-class DbBase(object):
- def __init__(self, encoding_type="ascii", logger_name="db", lock=False):
- """
- Constructor of dbBase
- :param logger_name: logging name
- :param lock: Used to protect simultaneous access to the same instance class by several threads:
- False, None: Do not protect, this object will only be accessed by one thread
- True: This object needs to be protected by several threads accessing.
- Lock object. Use thi Lock for the threads access protection
- """
- self.logger = logging.getLogger(logger_name)
- self.secret_key = None # 32 bytes length array used for encrypt/decrypt
- self.encrypt_mode = AES.MODE_ECB
- self.encoding_type = encoding_type
- if not lock:
- self.lock = FakeLock()
- elif lock is True:
- self.lock = Lock()
- elif isinstance(lock, Lock):
- self.lock = lock
- else:
- raise ValueError("lock parameter must be a Lock classclass or boolean")
-
- def db_connect(self, config, target_version=None):
- """
- Connect to database
- :param config: Configuration of database. Contains among others:
- host: database host (mandatory)
- port: database port (mandatory)
- name: database name (mandatory)
- user: database username
- password: database password
- commonkey: common OSM key used for sensible information encryption
- masterpassword: same as commonkey, for backward compatibility. Deprecated, to be removed in the future
- :param target_version: if provided it checks if database contains required version, raising exception otherwise.
- :return: None or raises DbException on error
- """
- raise DbException("Method 'db_connect' not implemented")
-
- def db_disconnect(self):
- """
- Disconnect from database
- :return: None
- """
- pass
-
- def get_list(self, table, q_filter=None):
- """
- Obtain a list of entries matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: a list (can be empty) with the found entries. Raises DbException on error
- """
- raise DbException("Method 'get_list' not implemented")
-
- def count(self, table, q_filter=None):
- """
- Count the number of entries matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: number of entries found (can be zero)
- :raise: DbException on error
- """
- raise DbException("Method 'count' not implemented")
-
- def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
- """
- Obtain one entry matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
- it raises a DbException
- :param fail_on_more: If more than one matches filter it returns one of then unless this flag is set tu True, so
- that it raises a DbException
- :return: The requested element, or None
- """
- raise DbException("Method 'get_one' not implemented")
-
- def del_list(self, table, q_filter=None):
- """
- Deletes all entries that match q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: Dict with the number of entries deleted
- """
- raise DbException("Method 'del_list' not implemented")
-
- def del_one(self, table, q_filter=None, fail_on_empty=True):
- """
- Deletes one entry that matches q_filter
- :param table: collection or table
- :param q_filter: Filter
- :param fail_on_empty: If nothing matches filter it returns '0' deleted unless this flag is set tu True, in
- which case it raises a DbException
- :return: Dict with the number of entries deleted
- """
- raise DbException("Method 'del_one' not implemented")
-
- def create(self, table, indata):
- """
- Add a new entry at database
- :param table: collection or table
- :param indata: content to be added
- :return: database '_id' of the inserted element. Raises a DbException on error
- """
- raise DbException("Method 'create' not implemented")
-
- def create_list(self, table, indata_list):
- """
- Add several entries at once
- :param table: collection or table
- :param indata_list: list of elements to insert. Each element must be a dictionary.
- An '_id' key based on random uuid is added at each element if missing
- :return: list of inserted '_id's. Exception on error
- """
- raise DbException("Method 'create_list' not implemented")
-
- def set_one(
- self,
- table,
- q_filter,
- update_dict,
- fail_on_empty=True,
- unset=None,
- pull=None,
- push=None,
- push_list=None,
- pull_list=None,
- ):
- """
- Modifies an entry at database
- :param table: collection or table
- :param q_filter: Filter
- :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
- it raises a DbException
- :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
- ignored. If not exist, it is ignored
- :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
- if exist in the array is removed. If not exist, it is ignored
- :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
- is appended to the end of the array
- :param pull_list: Same as pull but values are arrays where each item is removed from the array
- :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
- whole array
- :return: Dict with the number of entries modified. None if no matching is found.
- """
- raise DbException("Method 'set_one' not implemented")
-
- def set_list(
- self,
- table,
- q_filter,
- update_dict,
- unset=None,
- pull=None,
- push=None,
- push_list=None,
- pull_list=None,
- ):
- """
- Modifies al matching entries at database
- :param table: collection or table
- :param q_filter: Filter
- :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
- :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
- ignored. If not exist, it is ignored
- :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
- if exist in the array is removed. If not exist, it is ignored
- :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
- is appended to the end of the array
- :param pull_list: Same as pull but values are arrays where each item is removed from the array
- :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
- whole array
- :return: Dict with the number of entries modified
- """
- raise DbException("Method 'set_list' not implemented")
-
- def replace(self, table, _id, indata, fail_on_empty=True):
- """
- Replace the content of an entry
- :param table: collection or table
- :param _id: internal database id
- :param indata: content to replace
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
- it raises a DbException
- :return: Dict with the number of entries replaced
- """
- raise DbException("Method 'replace' not implemented")
-
- def _join_secret_key(self, update_key):
- """
- Returns a xor byte combination of the internal secret_key and the provided update_key.
- It does not modify the internal secret_key. Used for adding salt, join keys, etc.
- :param update_key: Can be a string, byte or None. Recommended a long one (e.g. 32 byte length)
- :return: joined key in bytes with a 32 bytes length. Can be None if both internal secret_key and update_key
- are None
- """
- if not update_key:
- return self.secret_key
- elif isinstance(update_key, str):
- update_key_bytes = update_key.encode()
- else:
- update_key_bytes = update_key
-
- new_secret_key = (
- bytearray(self.secret_key) if self.secret_key else bytearray(32)
- )
- for i, b in enumerate(update_key_bytes):
- new_secret_key[i % 32] ^= b
- return bytes(new_secret_key)
-
- def set_secret_key(self, new_secret_key, replace=False):
- """
- Updates internal secret_key used for encryption, with a byte xor
- :param new_secret_key: string or byte array. It is recommended a 32 byte length
- :param replace: if True, old value of internal secret_key is ignored and replaced. If false, a byte xor is used
- :return: None
- """
- if replace:
- self.secret_key = None
- self.secret_key = self._join_secret_key(new_secret_key)
-
- def get_secret_key(self):
- """
- Get the database secret key in case it is not done when "connect" is called. It can happens when database is
- empty after an initial install. It should skip if secret is already obtained.
- """
- pass
-
- @staticmethod
- def pad_data(value: str) -> str:
- if not isinstance(value, str):
- raise DbException(
- f"Incorrect data type: type({value}), string is expected."
- )
- return value + ("\0" * ((16 - len(value)) % 16))
-
- @staticmethod
- def unpad_data(value: str) -> str:
- if not isinstance(value, str):
- raise DbException(
- f"Incorrect data type: type({value}), string is expected."
- )
- return value.rstrip("\0")
-
- def _encrypt_value(self, value: str, schema_version: str, salt: str):
- """Encrypt a value.
-
- Args:
- value (str): value to be encrypted. It is string/unicode
- schema_version (str): used for version control. If None or '1.0' no encryption is done.
- If '1.1' symmetric AES encryption is done
- salt (str): optional salt to be used. Must be str
-
- Returns:
- Encrypted content of value (str)
-
- """
- if not self.secret_key or not schema_version or schema_version == "1.0":
- return value
-
- else:
- # Secret key as bytes
- secret_key = self._join_secret_key(salt)
- cipher = AES.new(secret_key, self.encrypt_mode)
- # Padded data as string
- padded_private_msg = self.pad_data(value)
- # Padded data as bytes
- padded_private_msg_bytes = padded_private_msg.encode(self.encoding_type)
- # Encrypt padded data
- encrypted_msg = cipher.encrypt(padded_private_msg_bytes)
- # Base64 encoded encrypted data
- encoded_encrypted_msg = b64encode(encrypted_msg)
- # Converting to string
- return encoded_encrypted_msg.decode(self.encoding_type)
-
- def encrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
- """Encrypt a value.
-
- Args:
- value (str): value to be encrypted. It is string/unicode
- schema_version (str): used for version control. If None or '1.0' no encryption is done.
- If '1.1' symmetric AES encryption is done
- salt (str): optional salt to be used. Must be str
-
- Returns:
- Encrypted content of value (str)
-
- """
- self.get_secret_key()
- return self._encrypt_value(value, schema_version, salt)
-
- def _decrypt_value(self, value: str, schema_version: str, salt: str) -> str:
- """Decrypt an encrypted value.
- Args:
-
- value (str): value to be decrypted. It is a base64 string
- schema_version (str): used for known encryption method used.
- If None or '1.0' no encryption has been done.
- If '1.1' symmetric AES encryption has been done
- salt (str): optional salt to be used
-
- Returns:
- Plain content of value (str)
-
- """
- if not self.secret_key or not schema_version or schema_version == "1.0":
- return value
-
- else:
- secret_key = self._join_secret_key(salt)
- # Decoding encrypted data, output bytes
- encrypted_msg = b64decode(value)
- cipher = AES.new(secret_key, self.encrypt_mode)
- # Decrypted data, output bytes
- decrypted_msg = cipher.decrypt(encrypted_msg)
- try:
- # Converting to string
- private_msg = decrypted_msg.decode(self.encoding_type)
- except UnicodeDecodeError:
- raise DbException(
- "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
- # Unpadded data as string
- return self.unpad_data(private_msg)
-
- def decrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
- """Decrypt an encrypted value.
- Args:
-
- value (str): value to be decrypted. It is a base64 string
- schema_version (str): used for known encryption method used.
- If None or '1.0' no encryption has been done.
- If '1.1' symmetric AES encryption has been done
- salt (str): optional salt to be used
-
- Returns:
- Plain content of value (str)
-
- """
- self.get_secret_key()
- return self._decrypt_value(value, schema_version, salt)
-
- def encrypt_decrypt_fields(
- self, item, action, fields=None, flags=None, schema_version=None, salt=None
- ):
- if not fields:
- return
- self.get_secret_key()
- actions = ["encrypt", "decrypt"]
- if action.lower() not in actions:
- raise DbException(
- "Unknown action ({}): Must be one of {}".format(action, actions),
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
- method = self.encrypt if action.lower() == "encrypt" else self.decrypt
- if flags is None:
- flags = re.I
-
- def process(_item):
- if isinstance(_item, list):
- for elem in _item:
- process(elem)
- elif isinstance(_item, dict):
- for key, val in _item.items():
- if isinstance(val, str):
- if any(re.search(f, key, flags) for f in fields):
- _item[key] = method(val, schema_version, salt)
- else:
- process(val)
-
- process(item)
-
-
-def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
- """
- Modifies one dictionary with the information of the other following https://tools.ietf.org/html/rfc7396
- Basically is a recursive python 'dict_to_change.update(dict_reference)', but a value of None is used to delete.
- It implements an extra feature that allows modifying an array. RFC7396 only allows replacing the entire array.
- For that, dict_reference should contains a dict with keys starting by "$" with the following meaning:
- $[index] <index> is an integer for targeting a concrete index from dict_to_change array. If the value is None
- the element of the array is deleted, otherwise it is edited.
- $+[index] The value is inserted at this <index>. A value of None has not sense and an exception is raised.
- $+ The value is appended at the end. A value of None has not sense and an exception is raised.
- $val It looks for all the items in the array dict_to_change equal to <val>. <val> is evaluated as yaml,
- that is, numbers are taken as type int, true/false as boolean, etc. Use quotes to force string.
- Nothing happens if no match is found. If the value is None the matched elements are deleted.
- $key: val In case a dictionary is passed in yaml format, if looks for all items in the array dict_to_change
- that are dictionaries and contains this <key> equal to <val>. Several keys can be used by yaml
- format '{key: val, key: val, ...}'; and all of them must match. Nothing happens if no match is
- found. If value is None the matched items are deleted, otherwise they are edited.
- $+val If no match if found (see '$val'), the value is appended to the array. If any match is found nothing
- is changed. A value of None has not sense.
- $+key: val If no match if found (see '$key: val'), the value is appended to the array. If any match is found
- nothing is changed. A value of None has not sense.
- If there are several editions, insertions and deletions; editions and deletions are done first in reverse index
- order; then insertions also in reverse index order; and finally appends in any order. So indexes used at
- insertions must take into account the deleted items.
- :param dict_to_change: Target dictionary to be changed.
- :param dict_reference: Dictionary that contains changes to be applied.
- :param key_list: This is used internally for recursive calls. Do not fill this parameter.
- :return: none or raises and exception only at array modification when there is a bad format or conflict.
- """
-
- def _deep_update_array(array_to_change, _dict_reference, _key_list):
- to_append = {}
- to_insert_at_index = {}
- values_to_edit_delete = {}
- indexes_to_edit_delete = []
- array_edition = None
- _key_list.append("")
- for k in _dict_reference:
- _key_list[-1] = str(k)
- if not isinstance(k, str) or not k.startswith("$"):
- if array_edition is True:
- raise DbException(
- "Found array edition (keys starting with '$') and pure dictionary edition in the"
- " same dict at '{}'".format(":".join(_key_list[:-1]))
- )
- array_edition = False
- continue
- else:
- if array_edition is False:
- raise DbException(
- "Found array edition (keys starting with '$') and pure dictionary edition in the"
- " same dict at '{}'".format(":".join(_key_list[:-1]))
- )
- array_edition = True
- insert = False
- indexes = [] # indexes to edit or insert
- kitem = k[1:]
- if kitem.startswith("+"):
- insert = True
- kitem = kitem[1:]
- if _dict_reference[k] is None:
- raise DbException(
- "A value of None has not sense for insertions at '{}'".format(
- ":".join(_key_list)
- )
- )
-
- if kitem.startswith("[") and kitem.endswith("]"):
- try:
- index = int(kitem[1:-1])
- if index < 0:
- index += len(array_to_change)
- if index < 0:
- index = 0 # skip outside index edition
- indexes.append(index)
- except Exception:
- raise DbException(
- "Wrong format at '{}'. Expecting integer index inside quotes".format(
- ":".join(_key_list)
- )
- )
- elif kitem:
- # match_found_skip = False
- try:
- filter_in = yaml.safe_load(kitem)
- except Exception:
- raise DbException(
- "Wrong format at '{}'. Expecting '$<yaml-format>'".format(
- ":".join(_key_list)
- )
- )
- if isinstance(filter_in, dict):
- for index, item in enumerate(array_to_change):
- for filter_k, filter_v in filter_in.items():
- if (
- not isinstance(item, dict)
- or filter_k not in item
- or item[filter_k] != filter_v
- ):
- break
- else: # match found
- if insert:
- # match_found_skip = True
- insert = False
- break
- else:
- indexes.append(index)
- else:
- index = 0
- try:
- while True: # if not match a ValueError exception will be raise
- index = array_to_change.index(filter_in, index)
- if insert:
- # match_found_skip = True
- insert = False
- break
- indexes.append(index)
- index += 1
- except ValueError:
- pass
-
- # if match_found_skip:
- # continue
- elif not insert:
- raise DbException(
- "Wrong format at '{}'. Expecting '$+', '$[<index]' or '$[<filter>]'".format(
- ":".join(_key_list)
- )
- )
- for index in indexes:
- if insert:
- if (
- index in to_insert_at_index
- and to_insert_at_index[index] != _dict_reference[k]
- ):
- # Several different insertions on the same item of the array
- raise DbException(
- "Conflict at '{}'. Several insertions on same array index {}".format(
- ":".join(_key_list), index
- )
- )
- to_insert_at_index[index] = _dict_reference[k]
- else:
- if (
- index in indexes_to_edit_delete
- and values_to_edit_delete[index] != _dict_reference[k]
- ):
- # Several different editions on the same item of the array
- raise DbException(
- "Conflict at '{}'. Several editions on array index {}".format(
- ":".join(_key_list), index
- )
- )
- indexes_to_edit_delete.append(index)
- values_to_edit_delete[index] = _dict_reference[k]
- if not indexes:
- if insert:
- to_append[k] = _dict_reference[k]
- # elif _dict_reference[k] is not None:
- # raise DbException("Not found any match to edit in the array, or wrong format at '{}'".format(
- # ":".join(_key_list)))
-
- # edition/deletion is done before insertion
- indexes_to_edit_delete.sort(reverse=True)
- for index in indexes_to_edit_delete:
- _key_list[-1] = str(index)
- try:
- if values_to_edit_delete[index] is None: # None->Anything
- try:
- del array_to_change[index]
- except IndexError:
- pass # it is not consider an error if this index does not exist
- elif not isinstance(
- values_to_edit_delete[index], dict
- ): # NotDict->Anything
- array_to_change[index] = deepcopy(values_to_edit_delete[index])
- elif isinstance(array_to_change[index], dict): # Dict->Dict
- deep_update_rfc7396(
- array_to_change[index], values_to_edit_delete[index], _key_list
- )
- else: # Dict->NotDict
- if isinstance(
- array_to_change[index], list
- ): # Dict->List. Check extra array edition
- if _deep_update_array(
- array_to_change[index],
- values_to_edit_delete[index],
- _key_list,
- ):
- continue
- array_to_change[index] = deepcopy(values_to_edit_delete[index])
- # calling deep_update_rfc7396 to delete the None values
- deep_update_rfc7396(
- array_to_change[index], values_to_edit_delete[index], _key_list
- )
- except IndexError:
- raise DbException(
- "Array edition index out of range at '{}'".format(
- ":".join(_key_list)
- )
- )
-
- # insertion with indexes
- to_insert_indexes = list(to_insert_at_index.keys())
- to_insert_indexes.sort(reverse=True)
- for index in to_insert_indexes:
- array_to_change.insert(index, to_insert_at_index[index])
-
- # append
- for k, insert_value in to_append.items():
- _key_list[-1] = str(k)
- insert_value_copy = deepcopy(insert_value)
- if isinstance(insert_value_copy, dict):
- # calling deep_update_rfc7396 to delete the None values
- deep_update_rfc7396(insert_value_copy, insert_value, _key_list)
- array_to_change.append(insert_value_copy)
-
- _key_list.pop()
- if array_edition:
- return True
- return False
-
- if key_list is None:
- key_list = []
- key_list.append("")
- for k in dict_reference:
- key_list[-1] = str(k)
- if dict_reference[k] is None: # None->Anything
- if k in dict_to_change:
- del dict_to_change[k]
- elif not isinstance(dict_reference[k], dict): # NotDict->Anything
- dict_to_change[k] = deepcopy(dict_reference[k])
- elif k not in dict_to_change: # Dict->Empty
- dict_to_change[k] = deepcopy(dict_reference[k])
- # calling deep_update_rfc7396 to delete the None values
- deep_update_rfc7396(dict_to_change[k], dict_reference[k], key_list)
- elif isinstance(dict_to_change[k], dict): # Dict->Dict
- deep_update_rfc7396(dict_to_change[k], dict_reference[k], key_list)
- else: # Dict->NotDict
- if isinstance(
- dict_to_change[k], list
- ): # Dict->List. Check extra array edition
- if _deep_update_array(dict_to_change[k], dict_reference[k], key_list):
- continue
- dict_to_change[k] = deepcopy(dict_reference[k])
- # calling deep_update_rfc7396 to delete the None values
- deep_update_rfc7396(dict_to_change[k], dict_reference[k], key_list)
- key_list.pop()
-
-
-def deep_update(dict_to_change, dict_reference):
- """Maintained for backward compatibility. Use deep_update_rfc7396 instead"""
- return deep_update_rfc7396(dict_to_change, dict_reference)
-
-
-class Encryption(DbBase):
- def __init__(self, uri, config, encoding_type="ascii", logger_name="db"):
- """Constructor.
-
- Args:
- uri (str): Connection string to connect to the database.
- config (dict): Additional database info
- encoding_type (str): ascii, utf-8 etc.
- logger_name (str): Logger name
-
- """
- self._secret_key = None # 32 bytes length array used for encrypt/decrypt
- self.encrypt_mode = AES.MODE_ECB
- super(Encryption, self).__init__(
- encoding_type=encoding_type, logger_name=logger_name
- )
- self._client = AsyncIOMotorClient(uri)
- self._config = config
-
- @property
- def secret_key(self):
- return self._secret_key
-
- @secret_key.setter
- def secret_key(self, value):
- self._secret_key = value
-
- @property
- def _database(self):
- return self._client[DB_NAME]
-
- @property
- def _admin_collection(self):
- return self._database["admin"]
-
- @property
- def database_key(self):
- return self._config.get("database_commonkey")
-
- async def decrypt_fields(
- self,
- item: dict,
- fields: typing.List[str],
- schema_version: str = None,
- salt: str = None,
- ) -> None:
- """Decrypt fields from a dictionary. Follows the same logic as in osm_common.
-
- Args:
-
- item (dict): Dictionary with the keys to be decrypted
- fields (list): List of keys to decrypt
- schema version (str): Schema version. (i.e. 1.11)
- salt (str): Salt for the decryption
-
- """
- flags = re.I
-
- async def process(_item):
- if isinstance(_item, list):
- for elem in _item:
- await process(elem)
- elif isinstance(_item, dict):
- for key, val in _item.items():
- if isinstance(val, str):
- if any(re.search(f, key, flags) for f in fields):
- _item[key] = await self.decrypt(val, schema_version, salt)
- else:
- await process(val)
-
- await process(item)
-
- async def encrypt(
- self, value: str, schema_version: str = None, salt: str = None
- ) -> str:
- """Encrypt a value.
-
- Args:
- value (str): value to be encrypted. It is string/unicode
- schema_version (str): used for version control. If None or '1.0' no encryption is done.
- If '1.1' symmetric AES encryption is done
- salt (str): optional salt to be used. Must be str
-
- Returns:
- Encrypted content of value (str)
-
- """
- await self.get_secret_key()
- return self._encrypt_value(value, schema_version, salt)
-
- async def decrypt(
- self, value: str, schema_version: str = None, salt: str = None
- ) -> str:
- """Decrypt an encrypted value.
- Args:
-
- value (str): value to be decrypted. It is a base64 string
- schema_version (str): used for known encryption method used.
- If None or '1.0' no encryption has been done.
- If '1.1' symmetric AES encryption has been done
- salt (str): optional salt to be used
-
- Returns:
- Plain content of value (str)
-
- """
- await self.get_secret_key()
- return self._decrypt_value(value, schema_version, salt)
-
- def _join_secret_key(self, update_key: typing.Any) -> bytes:
- """Join key with secret key.
-
- Args:
-
- update_key (str or bytes): str or bytes with the to update
-
- Returns:
-
- Joined key (bytes)
- """
- return self._join_keys(update_key, self.secret_key)
-
- def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
- """Join key with secret_key.
-
- Args:
-
- key (str or bytes): str or bytes of the key to update
- secret_key (bytes): bytes of the secret key
-
- Returns:
-
- Joined key (bytes)
- """
- if isinstance(key, str):
- update_key_bytes = key.encode(self.encoding_type)
- else:
- update_key_bytes = key
- new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
- for i, b in enumerate(update_key_bytes):
- new_secret_key[i % 32] ^= b
- return bytes(new_secret_key)
-
- async def get_secret_key(self):
- """Get secret key using the database key and the serial key in the DB.
- The key is populated in the property self.secret_key.
- """
- if self.secret_key:
- return
- secret_key = None
- if self.database_key:
- secret_key = self._join_keys(self.database_key, None)
- version_data = await self._admin_collection.find_one({"_id": "version"})
- if version_data and version_data.get("serial"):
- secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
- self._secret_key = secret_key
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from copy import deepcopy
-from http import HTTPStatus
-import logging
-from uuid import uuid4
-
-from osm_common.dbbase import DbBase, DbException
-from osm_common.dbmongo import deep_update
-
-
-__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
-
-
-class DbMemory(DbBase):
- def __init__(self, logger_name="db", lock=False):
- super().__init__(logger_name=logger_name, lock=lock)
- self.db = {}
-
- def db_connect(self, config):
- """
- Connect to database
- :param config: Configuration of database
- :return: None or raises DbException on error
- """
- if "logger_name" in config:
- self.logger = logging.getLogger(config["logger_name"])
- master_key = config.get("commonkey") or config.get("masterpassword")
- if master_key:
- self.set_secret_key(master_key)
-
- @staticmethod
- def _format_filter(q_filter):
- db_filter = {}
- # split keys with ANYINDEX in this way:
- # {"A.B.ANYINDEX.C.D.ANYINDEX.E": v } -> {"A.B.ANYINDEX": {"C.D.ANYINDEX": {"E": v}}}
- if q_filter:
- for k, v in q_filter.items():
- db_v = v
- kleft, _, kright = k.rpartition(".ANYINDEX.")
- while kleft:
- k = kleft + ".ANYINDEX"
- db_v = {kright: db_v}
- kleft, _, kright = k.rpartition(".ANYINDEX.")
- deep_update(db_filter, {k: db_v})
-
- return db_filter
-
- def _find(self, table, q_filter):
- def recursive_find(key_list, key_next_index, content, oper, target):
- if key_next_index == len(key_list) or content is None:
- try:
- if oper in ("eq", "cont"):
- if isinstance(target, list):
- if isinstance(content, list):
- return any(
- content_item in target for content_item in content
- )
- return content in target
- elif isinstance(content, list):
- return target in content
- else:
- return content == target
- elif oper in ("neq", "ne", "ncont"):
- if isinstance(target, list):
- if isinstance(content, list):
- return all(
- content_item not in target
- for content_item in content
- )
- return content not in target
- elif isinstance(content, list):
- return target not in content
- else:
- return content != target
- if oper == "gt":
- return content > target
- elif oper == "gte":
- return content >= target
- elif oper == "lt":
- return content < target
- elif oper == "lte":
- return content <= target
- else:
- raise DbException(
- "Unknown filter operator '{}' in key '{}'".format(
- oper, ".".join(key_list)
- ),
- http_code=HTTPStatus.BAD_REQUEST,
- )
- except TypeError:
- return False
-
- elif isinstance(content, dict):
- return recursive_find(
- key_list,
- key_next_index + 1,
- content.get(key_list[key_next_index]),
- oper,
- target,
- )
- elif isinstance(content, list):
- look_for_match = True # when there is a match return immediately
- if (target is None) != (
- oper in ("neq", "ne", "ncont")
- ): # one True and other False (Xor)
- look_for_match = (
- False # when there is not a match return immediately
- )
-
- for content_item in content:
- if key_list[key_next_index] == "ANYINDEX" and isinstance(v, dict):
- matches = True
- if target:
- for k2, v2 in target.items():
- k_new_list = k2.split(".")
- new_operator = "eq"
- if k_new_list[-1] in (
- "eq",
- "ne",
- "gt",
- "gte",
- "lt",
- "lte",
- "cont",
- "ncont",
- "neq",
- ):
- new_operator = k_new_list.pop()
- if not recursive_find(
- k_new_list, 0, content_item, new_operator, v2
- ):
- matches = False
- break
-
- else:
- matches = recursive_find(
- key_list, key_next_index, content_item, oper, target
- )
- if matches == look_for_match:
- return matches
- if key_list[key_next_index].isdecimal() and int(
- key_list[key_next_index]
- ) < len(content):
- matches = recursive_find(
- key_list,
- key_next_index + 1,
- content[int(key_list[key_next_index])],
- oper,
- target,
- )
- if matches == look_for_match:
- return matches
- return not look_for_match
- else: # content is not dict, nor list neither None, so not found
- if oper in ("neq", "ne", "ncont"):
- return target is not None
- else:
- return target is None
-
- for i, row in enumerate(self.db.get(table, ())):
- q_filter = q_filter or {}
- for k, v in q_filter.items():
- k_list = k.split(".")
- operator = "eq"
- if k_list[-1] in (
- "eq",
- "ne",
- "gt",
- "gte",
- "lt",
- "lte",
- "cont",
- "ncont",
- "neq",
- ):
- operator = k_list.pop()
- matches = recursive_find(k_list, 0, row, operator, v)
- if not matches:
- break
- else:
- # match
- yield i, row
-
- def get_list(self, table, q_filter=None):
- """
- Obtain a list of entries matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: a list (can be empty) with the found entries. Raises DbException on error
- """
- try:
- result = []
- with self.lock:
- for _, row in self._find(table, self._format_filter(q_filter)):
- result.append(deepcopy(row))
- return result
- except DbException:
- raise
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
- def count(self, table, q_filter=None):
- """
- Count the number of entries matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: number of entries found (can be zero)
- :raise: DbException on error
- """
- try:
- with self.lock:
- return sum(1 for x in self._find(table, self._format_filter(q_filter)))
- except DbException:
- raise
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
- def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
- """
- Obtain one entry matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
- it raises a DbException
- :param fail_on_more: If more than one matches filter it returns one of then unless this flag is set tu True, so
- that it raises a DbException
- :return: The requested element, or None
- """
- try:
- result = None
- with self.lock:
- for _, row in self._find(table, self._format_filter(q_filter)):
- if not fail_on_more:
- return deepcopy(row)
- if result:
- raise DbException(
- "Found more than one entry with filter='{}'".format(
- q_filter
- ),
- HTTPStatus.CONFLICT.value,
- )
- result = row
- if not result and fail_on_empty:
- raise DbException(
- "Not found entry with filter='{}'".format(q_filter),
- HTTPStatus.NOT_FOUND,
- )
- return deepcopy(result)
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
- def del_list(self, table, q_filter=None):
- """
- Deletes all entries that match q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: Dict with the number of entries deleted
- """
- try:
- id_list = []
- with self.lock:
- for i, _ in self._find(table, self._format_filter(q_filter)):
- id_list.append(i)
- deleted = len(id_list)
- for i in reversed(id_list):
- del self.db[table][i]
- return {"deleted": deleted}
- except DbException:
- raise
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
- def del_one(self, table, q_filter=None, fail_on_empty=True):
- """
- Deletes one entry that matches q_filter
- :param table: collection or table
- :param q_filter: Filter
- :param fail_on_empty: If nothing matches filter it returns '0' deleted unless this flag is set tu True, in
- which case it raises a DbException
- :return: Dict with the number of entries deleted
- """
- try:
- with self.lock:
- for i, _ in self._find(table, self._format_filter(q_filter)):
- break
- else:
- if fail_on_empty:
- raise DbException(
- "Not found entry with filter='{}'".format(q_filter),
- HTTPStatus.NOT_FOUND,
- )
- return None
- del self.db[table][i]
- return {"deleted": 1}
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
- def _update(
- self,
- db_item,
- update_dict,
- unset=None,
- pull=None,
- push=None,
- push_list=None,
- pull_list=None,
- ):
- """
- Modifies an entry at database
- :param db_item: entry of the table to update
- :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
- :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
- ignored. If not exist, it is ignored
- :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
- if exist in the array is removed. If not exist, it is ignored
- :param pull_list: Same as pull but values are arrays where each item is removed from the array
- :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
- is appended to the end of the array
- :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
- whole array
- :return: True if database has been changed, False if not; Exception on error
- """
-
- def _iterate_keys(k, db_nested, populate=True):
- k_list = k.split(".")
- k_item_prev = k_list[0]
- populated = False
- if k_item_prev not in db_nested and populate:
- populated = True
- db_nested[k_item_prev] = None
- for k_item in k_list[1:]:
- if isinstance(db_nested[k_item_prev], dict):
- if k_item not in db_nested[k_item_prev]:
- if not populate:
- raise DbException(
- "Cannot set '{}', not existing '{}'".format(k, k_item)
- )
- populated = True
- db_nested[k_item_prev][k_item] = None
- elif isinstance(db_nested[k_item_prev], list) and k_item.isdigit():
- # extend list with Nones if index greater than list
- k_item = int(k_item)
- if k_item >= len(db_nested[k_item_prev]):
- if not populate:
- raise DbException(
- "Cannot set '{}', index too large '{}'".format(
- k, k_item
- )
- )
- populated = True
- db_nested[k_item_prev] += [None] * (
- k_item - len(db_nested[k_item_prev]) + 1
- )
- elif db_nested[k_item_prev] is None:
- if not populate:
- raise DbException(
- "Cannot set '{}', not existing '{}'".format(k, k_item)
- )
- populated = True
- db_nested[k_item_prev] = {k_item: None}
- else: # number, string, boolean, ... or list but with not integer key
- raise DbException(
- "Cannot set '{}' on existing '{}={}'".format(
- k, k_item_prev, db_nested[k_item_prev]
- )
- )
- db_nested = db_nested[k_item_prev]
- k_item_prev = k_item
- return db_nested, k_item_prev, populated
-
- updated = False
- try:
- if update_dict:
- for dot_k, v in update_dict.items():
- dict_to_update, key_to_update, _ = _iterate_keys(dot_k, db_item)
- dict_to_update[key_to_update] = v
- updated = True
- if unset:
- for dot_k in unset:
- try:
- dict_to_update, key_to_update, _ = _iterate_keys(
- dot_k, db_item, populate=False
- )
- del dict_to_update[key_to_update]
- updated = True
- except Exception as unset_error:
- self.logger.error(f"{unset_error} occured while updating DB.")
- if pull:
- for dot_k, v in pull.items():
- try:
- dict_to_update, key_to_update, _ = _iterate_keys(
- dot_k, db_item, populate=False
- )
- except Exception as pull_error:
- self.logger.error(f"{pull_error} occured while updating DB.")
- continue
-
- if key_to_update not in dict_to_update:
- continue
- if not isinstance(dict_to_update[key_to_update], list):
- raise DbException(
- "Cannot pull '{}'. Target is not a list".format(dot_k)
- )
- while v in dict_to_update[key_to_update]:
- dict_to_update[key_to_update].remove(v)
- updated = True
- if pull_list:
- for dot_k, v in pull_list.items():
- if not isinstance(v, list):
- raise DbException(
- "Invalid content at pull_list, '{}' must be an array".format(
- dot_k
- ),
- http_code=HTTPStatus.BAD_REQUEST,
- )
- try:
- dict_to_update, key_to_update, _ = _iterate_keys(
- dot_k, db_item, populate=False
- )
- except Exception as iterate_error:
- self.logger.error(
- f"{iterate_error} occured while iterating keys in db update."
- )
- continue
-
- if key_to_update not in dict_to_update:
- continue
- if not isinstance(dict_to_update[key_to_update], list):
- raise DbException(
- "Cannot pull_list '{}'. Target is not a list".format(dot_k)
- )
- for single_v in v:
- while single_v in dict_to_update[key_to_update]:
- dict_to_update[key_to_update].remove(single_v)
- updated = True
- if push:
- for dot_k, v in push.items():
- dict_to_update, key_to_update, populated = _iterate_keys(
- dot_k, db_item
- )
- if (
- isinstance(dict_to_update, dict)
- and key_to_update not in dict_to_update
- ):
- dict_to_update[key_to_update] = [v]
- updated = True
- elif populated and dict_to_update[key_to_update] is None:
- dict_to_update[key_to_update] = [v]
- updated = True
- elif not isinstance(dict_to_update[key_to_update], list):
- raise DbException(
- "Cannot push '{}'. Target is not a list".format(dot_k)
- )
- else:
- dict_to_update[key_to_update].append(v)
- updated = True
- if push_list:
- for dot_k, v in push_list.items():
- if not isinstance(v, list):
- raise DbException(
- "Invalid content at push_list, '{}' must be an array".format(
- dot_k
- ),
- http_code=HTTPStatus.BAD_REQUEST,
- )
- dict_to_update, key_to_update, populated = _iterate_keys(
- dot_k, db_item
- )
- if (
- isinstance(dict_to_update, dict)
- and key_to_update not in dict_to_update
- ):
- dict_to_update[key_to_update] = v.copy()
- updated = True
- elif populated and dict_to_update[key_to_update] is None:
- dict_to_update[key_to_update] = v.copy()
- updated = True
- elif not isinstance(dict_to_update[key_to_update], list):
- raise DbException(
- "Cannot push '{}'. Target is not a list".format(dot_k),
- http_code=HTTPStatus.CONFLICT,
- )
- else:
- dict_to_update[key_to_update] += v
- updated = True
-
- return updated
- except DbException:
- raise
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
- def set_one(
- self,
- table,
- q_filter,
- update_dict,
- fail_on_empty=True,
- unset=None,
- pull=None,
- push=None,
- push_list=None,
- pull_list=None,
- ):
- """
- Modifies an entry at database
- :param table: collection or table
- :param q_filter: Filter
- :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
- it raises a DbException
- :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
- ignored. If not exist, it is ignored
- :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
- if exist in the array is removed. If not exist, it is ignored
- :param pull_list: Same as pull but values are arrays where each item is removed from the array
- :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
- is appended to the end of the array
- :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
- whole array
- :return: Dict with the number of entries modified. None if no matching is found.
- """
- with self.lock:
- for i, db_item in self._find(table, self._format_filter(q_filter)):
- updated = self._update(
- db_item,
- update_dict,
- unset=unset,
- pull=pull,
- push=push,
- push_list=push_list,
- pull_list=pull_list,
- )
- return {"updated": 1 if updated else 0}
- else:
- if fail_on_empty:
- raise DbException(
- "Not found entry with _id='{}'".format(q_filter),
- HTTPStatus.NOT_FOUND,
- )
- return None
-
- def set_list(
- self,
- table,
- q_filter,
- update_dict,
- unset=None,
- pull=None,
- push=None,
- push_list=None,
- pull_list=None,
- ):
- """Modifies al matching entries at database. Same as push. Do not fail if nothing matches"""
- with self.lock:
- updated = 0
- found = 0
- for _, db_item in self._find(table, self._format_filter(q_filter)):
- found += 1
- if self._update(
- db_item,
- update_dict,
- unset=unset,
- pull=pull,
- push=push,
- push_list=push_list,
- pull_list=pull_list,
- ):
- updated += 1
- # if not found and fail_on_empty:
- # raise DbException("Not found entry with '{}'".format(q_filter), HTTPStatus.NOT_FOUND)
- return {"updated": updated} if found else None
-
- def replace(self, table, _id, indata, fail_on_empty=True):
- """
- Replace the content of an entry
- :param table: collection or table
- :param _id: internal database id
- :param indata: content to replace
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
- it raises a DbException
- :return: Dict with the number of entries replaced
- """
- try:
- with self.lock:
- for i, _ in self._find(table, self._format_filter({"_id": _id})):
- break
- else:
- if fail_on_empty:
- raise DbException(
- "Not found entry with _id='{}'".format(_id),
- HTTPStatus.NOT_FOUND,
- )
- return None
- self.db[table][i] = deepcopy(indata)
- return {"updated": 1}
- except DbException:
- raise
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
- def create(self, table, indata):
- """
- Add a new entry at database
- :param table: collection or table
- :param indata: content to be added
- :return: database '_id' of the inserted element. Raises a DbException on error
- """
- try:
- id = indata.get("_id")
- if not id:
- id = str(uuid4())
- indata["_id"] = id
- with self.lock:
- if table not in self.db:
- self.db[table] = []
- self.db[table].append(deepcopy(indata))
- return id
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
- def create_list(self, table, indata_list):
- """
- Add a new entry at database
- :param table: collection or table
- :param indata_list: list content to be added
- :return: list of inserted 'id's. Raises a DbException on error
- """
- try:
- _ids = []
- with self.lock:
- for indata in indata_list:
- _id = indata.get("_id")
- if not _id:
- _id = str(uuid4())
- indata["_id"] = _id
- with self.lock:
- if table not in self.db:
- self.db[table] = []
- self.db[table].append(deepcopy(indata))
- _ids.append(_id)
- return _ids
- except Exception as e: # TODO refine
- raise DbException(str(e))
-
-
-if __name__ == "__main__":
- # some test code
- db = DbMemory()
- db.create("test", {"_id": 1, "data": 1})
- db.create("test", {"_id": 2, "data": 2})
- db.create("test", {"_id": 3, "data": 3})
- print("must be 3 items:", db.get_list("test"))
- print("must return item 2:", db.get_list("test", {"_id": 2}))
- db.del_one("test", {"_id": 2})
- print("must be emtpy:", db.get_list("test", {"_id": 2}))
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from base64 import b64decode
-from copy import deepcopy
-from http import HTTPStatus
-import logging
-from time import sleep, time
-from uuid import uuid4
-
-from osm_common.dbbase import DbBase, DbException
-from pymongo import errors, MongoClient
-
-__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
-
-# TODO consider use this decorator for database access retries
-# @retry_mongocall
-# def retry_mongocall(call):
-# def _retry_mongocall(*args, **kwargs):
-# retry = 1
-# while True:
-# try:
-# return call(*args, **kwargs)
-# except pymongo.AutoReconnect as e:
-# if retry == 4:
-# raise DbException(e)
-# sleep(retry)
-# return _retry_mongocall
-
-
-def deep_update(to_update, update_with):
- """
- Similar to deepcopy but recursively with nested dictionaries. 'to_update' dict is updated with a content copy of
- 'update_with' dict recursively
- :param to_update: must be a dictionary to be modified
- :param update_with: must be a dictionary. It is not changed
- :return: to_update
- """
- for key in update_with:
- if key in to_update:
- if isinstance(to_update[key], dict) and isinstance(update_with[key], dict):
- deep_update(to_update[key], update_with[key])
- continue
- to_update[key] = deepcopy(update_with[key])
- return to_update
-
-
-class DbMongo(DbBase):
- conn_initial_timout = 120
- conn_timout = 10
-
- def __init__(self, logger_name="db", lock=False):
- super().__init__(logger_name=logger_name, lock=lock)
- self.client = None
- self.db = None
- self.database_key = None
- self.secret_obtained = False
- # ^ This is used to know if database serial has been got. Database is inited by NBI, who generates the serial
- # In case it is not ready when connected, it should be got later on before any decrypt operation
-
- def get_secret_key(self):
- if self.secret_obtained:
- return
-
- self.secret_key = None
- if self.database_key:
- self.set_secret_key(self.database_key)
- version_data = self.get_one(
- "admin", {"_id": "version"}, fail_on_empty=False, fail_on_more=True
- )
- if version_data and version_data.get("serial"):
- self.set_secret_key(b64decode(version_data["serial"]))
- self.secret_obtained = True
-
- def db_connect(self, config, target_version=None):
- """
- Connect to database
- :param config: Configuration of database
- :param target_version: if provided it checks if database contains required version, raising exception otherwise.
- :return: None or raises DbException on error
- """
- try:
- if "logger_name" in config:
- self.logger = logging.getLogger(config["logger_name"])
- master_key = config.get("commonkey") or config.get("masterpassword")
- if master_key:
- self.database_key = master_key
- self.set_secret_key(master_key)
- if config.get("uri"):
- self.client = MongoClient(
- config["uri"], replicaSet=config.get("replicaset", None)
- )
- # when all modules are ready
- self.db = self.client[config["name"]]
- if "loglevel" in config:
- self.logger.setLevel(getattr(logging, config["loglevel"]))
- # get data to try a connection
- now = time()
- while True:
- try:
- version_data = self.get_one(
- "admin",
- {"_id": "version"},
- fail_on_empty=False,
- fail_on_more=True,
- )
- # check database status is ok
- if version_data and version_data.get("status") != "ENABLED":
- raise DbException(
- "Wrong database status '{}'".format(
- version_data.get("status")
- ),
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
- # check version
- db_version = (
- None if not version_data else version_data.get("version")
- )
- if target_version and target_version != db_version:
- raise DbException(
- "Invalid database version {}. Expected {}".format(
- db_version, target_version
- )
- )
- # get serial
- if version_data and version_data.get("serial"):
- self.secret_obtained = True
- self.set_secret_key(b64decode(version_data["serial"]))
- self.logger.info(
- "Connected to database {} version {}".format(
- config["name"], db_version
- )
- )
- return
- except errors.ConnectionFailure as e:
- if time() - now >= self.conn_initial_timout:
- raise
- self.logger.info("Waiting to database up {}".format(e))
- sleep(2)
- except errors.PyMongoError as e:
- raise DbException(e)
-
- @staticmethod
- def _format_filter(q_filter):
- """
- Translate query string q_filter into mongo database filter
- :param q_filter: Query string content. Follows SOL005 section 4.3.2 guidelines, with the follow extensions and
- differences:
- It accept ".nq" (not equal) in addition to ".neq".
- For arrays you can specify index (concrete index must match), nothing (any index may match) or 'ANYINDEX'
- (two or more matches applies for the same array element). Examples:
- with database register: {A: [{B: 1, C: 2}, {B: 6, C: 9}]}
- query 'A.B=6' matches because array A contains one element with B equal to 6
- query 'A.0.B=6' does no match because index 0 of array A contains B with value 1, but not 6
- query 'A.B=6&A.C=2' matches because one element of array matches B=6 and other matchesC=2
- query 'A.ANYINDEX.B=6&A.ANYINDEX.C=2' does not match because it is needed the same element of the
- array matching both
-
- Examples of translations from SOL005 to >> mongo # comment
- A=B; A.eq=B >> A: B # must contain key A and equal to B or be a list that contains B
- A.cont=B >> A: B
- A=B&A=C; A=B,C >> A: {$in: [B, C]} # must contain key A and equal to B or C or be a list that contains
- # B or C
- A.cont=B&A.cont=C; A.cont=B,C >> A: {$in: [B, C]}
- A.ncont=B >> A: {$nin: B} # must not contain key A or if present not equal to B or if a list,
- # it must not not contain B
- A.ncont=B,C; A.ncont=B&A.ncont=C >> A: {$nin: [B,C]} # must not contain key A or if present not equal
- # neither B nor C; or if a list, it must not contain neither B nor C
- A.ne=B&A.ne=C; A.ne=B,C >> A: {$nin: [B, C]}
- A.gt=B >> A: {$gt: B} # must contain key A and greater than B
- A.ne=B; A.neq=B >> A: {$ne: B} # must not contain key A or if present not equal to B, or if
- # an array not contain B
- A.ANYINDEX.B=C >> A: {$elemMatch: {B=C}
- :return: database mongo filter
- """
- try:
- db_filter = {}
- if not q_filter:
- return db_filter
- for query_k, query_v in q_filter.items():
- dot_index = query_k.rfind(".")
- if dot_index > 1 and query_k[dot_index + 1 :] in (
- "eq",
- "ne",
- "gt",
- "gte",
- "lt",
- "lte",
- "cont",
- "ncont",
- "neq",
- ):
- operator = "$" + query_k[dot_index + 1 :]
- if operator == "$neq":
- operator = "$ne"
- k = query_k[:dot_index]
- else:
- operator = "$eq"
- k = query_k
-
- v = query_v
- if isinstance(v, list):
- if operator in ("$eq", "$cont"):
- operator = "$in"
- v = query_v
- elif operator in ("$ne", "$ncont"):
- operator = "$nin"
- v = query_v
- else:
- v = query_v.join(",")
-
- if operator in ("$eq", "$cont"):
- # v cannot be a comma separated list, because operator would have been changed to $in
- db_v = v
- elif operator == "$ncount":
- # v cannot be a comma separated list, because operator would have been changed to $nin
- db_v = {"$ne": v}
- else:
- db_v = {operator: v}
-
- # process the ANYINDEX word at k.
- kleft, _, kright = k.rpartition(".ANYINDEX.")
- while kleft:
- k = kleft
- db_v = {"$elemMatch": {kright: db_v}}
- kleft, _, kright = k.rpartition(".ANYINDEX.")
-
- # insert in db_filter
- # maybe db_filter[k] exist. e.g. in the query string for values between 5 and 8: "a.gt=5&a.lt=8"
- deep_update(db_filter, {k: db_v})
-
- return db_filter
- except Exception as e:
- raise DbException(
- "Invalid query string filter at {}:{}. Error: {}".format(query_k, v, e),
- http_code=HTTPStatus.BAD_REQUEST,
- )
-
- def get_list(self, table, q_filter=None):
- """
- Obtain a list of entries matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: a list (can be empty) with the found entries. Raises DbException on error
- """
- try:
- result = []
- with self.lock:
- collection = self.db[table]
- db_filter = self._format_filter(q_filter)
- rows = collection.find(db_filter)
- for row in rows:
- result.append(row)
- return result
- except DbException:
- raise
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def count(self, table, q_filter=None):
- """
- Count the number of entries matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: number of entries found (can be zero)
- :raise: DbException on error
- """
- try:
- with self.lock:
- collection = self.db[table]
- db_filter = self._format_filter(q_filter)
- count = collection.count_documents(db_filter)
- return count
- except DbException:
- raise
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
- """
- Obtain one entry matching q_filter
- :param table: collection or table
- :param q_filter: Filter
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
- it raises a DbException
- :param fail_on_more: If more than one matches filter it returns one of then unless this flag is set tu True, so
- that it raises a DbException
- :return: The requested element, or None
- """
- try:
- db_filter = self._format_filter(q_filter)
- with self.lock:
- collection = self.db[table]
- if not (fail_on_empty and fail_on_more):
- return collection.find_one(db_filter)
- rows = list(collection.find(db_filter))
- if len(rows) == 0:
- if fail_on_empty:
- raise DbException(
- "Not found any {} with filter='{}'".format(
- table[:-1], q_filter
- ),
- HTTPStatus.NOT_FOUND,
- )
-
- return None
- elif len(rows) > 1:
- if fail_on_more:
- raise DbException(
- "Found more than one {} with filter='{}'".format(
- table[:-1], q_filter
- ),
- HTTPStatus.CONFLICT,
- )
- return rows[0]
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def del_list(self, table, q_filter=None):
- """
- Deletes all entries that match q_filter
- :param table: collection or table
- :param q_filter: Filter
- :return: Dict with the number of entries deleted
- """
- try:
- with self.lock:
- collection = self.db[table]
- rows = collection.delete_many(self._format_filter(q_filter))
- return {"deleted": rows.deleted_count}
- except DbException:
- raise
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def del_one(self, table, q_filter=None, fail_on_empty=True):
- """
- Deletes one entry that matches q_filter
- :param table: collection or table
- :param q_filter: Filter
- :param fail_on_empty: If nothing matches filter it returns '0' deleted unless this flag is set tu True, in
- which case it raises a DbException
- :return: Dict with the number of entries deleted
- """
- try:
- with self.lock:
- collection = self.db[table]
- rows = collection.delete_one(self._format_filter(q_filter))
- if rows.deleted_count == 0:
- if fail_on_empty:
- raise DbException(
- "Not found any {} with filter='{}'".format(
- table[:-1], q_filter
- ),
- HTTPStatus.NOT_FOUND,
- )
- return None
- return {"deleted": rows.deleted_count}
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def create(self, table, indata):
- """
- Add a new entry at database
- :param table: collection or table
- :param indata: content to be added
- :return: database id of the inserted element. Raises a DbException on error
- """
- try:
- with self.lock:
- collection = self.db[table]
- data = collection.insert_one(indata)
- return data.inserted_id
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def create_list(self, table, indata_list):
- """
- Add several entries at once
- :param table: collection or table
- :param indata_list: content list to be added.
- :return: the list of inserted '_id's. Exception on error
- """
- try:
- for item in indata_list:
- if item.get("_id") is None:
- item["_id"] = str(uuid4())
- with self.lock:
- collection = self.db[table]
- data = collection.insert_many(indata_list)
- return data.inserted_ids
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def set_one(
- self,
- table,
- q_filter,
- update_dict,
- fail_on_empty=True,
- unset=None,
- pull=None,
- push=None,
- push_list=None,
- pull_list=None,
- upsert=False,
- ):
- """
- Modifies an entry at database
- :param table: collection or table
- :param q_filter: Filter
- :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set to True, in which case
- it raises a DbException
- :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
- ignored. If not exist, it is ignored
- :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
- if exist in the array is removed. If not exist, it is ignored
- :param pull_list: Same as pull but values are arrays where each item is removed from the array
- :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
- is appended to the end of the array
- :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
- whole array
- :param upsert: If this parameter is set to True and no document is found using 'q_filter' it will be created.
- By default this is false.
- :return: Dict with the number of entries modified. None if no matching is found.
- """
- try:
- db_oper = {}
- if update_dict:
- db_oper["$set"] = update_dict
- if unset:
- db_oper["$unset"] = unset
- if pull or pull_list:
- db_oper["$pull"] = pull or {}
- if pull_list:
- db_oper["$pull"].update(
- {k: {"$in": v} for k, v in pull_list.items()}
- )
- if push or push_list:
- db_oper["$push"] = push or {}
- if push_list:
- db_oper["$push"].update(
- {k: {"$each": v} for k, v in push_list.items()}
- )
-
- with self.lock:
- collection = self.db[table]
- rows = collection.update_one(
- self._format_filter(q_filter), db_oper, upsert=upsert
- )
- if rows.matched_count == 0:
- if fail_on_empty:
- raise DbException(
- "Not found any {} with filter='{}'".format(
- table[:-1], q_filter
- ),
- HTTPStatus.NOT_FOUND,
- )
- return None
- return {"modified": rows.modified_count}
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def set_list(
- self,
- table,
- q_filter,
- update_dict,
- unset=None,
- pull=None,
- push=None,
- push_list=None,
- pull_list=None,
- ):
- """
- Modifies al matching entries at database
- :param table: collection or table
- :param q_filter: Filter
- :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
- :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
- ignored. If not exist, it is ignored
- :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
- if exist in the array is removed. If not exist, it is ignored
- :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys, the
- single value is appended to the end of the array
- :param pull_list: Same as pull but values are arrays where each item is removed from the array
- :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
- whole array
- :return: Dict with the number of entries modified
- """
- try:
- db_oper = {}
- if update_dict:
- db_oper["$set"] = update_dict
- if unset:
- db_oper["$unset"] = unset
- if pull or pull_list:
- db_oper["$pull"] = pull or {}
- if pull_list:
- db_oper["$pull"].update(
- {k: {"$in": v} for k, v in pull_list.items()}
- )
- if push or push_list:
- db_oper["$push"] = push or {}
- if push_list:
- db_oper["$push"].update(
- {k: {"$each": v} for k, v in push_list.items()}
- )
- with self.lock:
- collection = self.db[table]
- rows = collection.update_many(self._format_filter(q_filter), db_oper)
- return {"modified": rows.modified_count}
- except Exception as e: # TODO refine
- raise DbException(e)
-
- def replace(self, table, _id, indata, fail_on_empty=True):
- """
- Replace the content of an entry
- :param table: collection or table
- :param _id: internal database id
- :param indata: content to replace
- :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
- it raises a DbException
- :return: Dict with the number of entries replaced
- """
- try:
- db_filter = {"_id": _id}
- with self.lock:
- collection = self.db[table]
- rows = collection.replace_one(db_filter, indata)
- if rows.matched_count == 0:
- if fail_on_empty:
- raise DbException(
- "Not found any {} with _id='{}'".format(table[:-1], _id),
- HTTPStatus.NOT_FOUND,
- )
- return None
- return {"replaced": rows.modified_count}
- except Exception as e: # TODO refine
- raise DbException(e)
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from http import HTTPStatus
-import logging
-from threading import Lock
-
-from osm_common.common_utils import FakeLock
-
-
-__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
-
-
-class FsException(Exception):
- def __init__(self, message, http_code=HTTPStatus.INTERNAL_SERVER_ERROR):
- self.http_code = http_code
- Exception.__init__(self, "storage exception " + message)
-
-
-class FsBase(object):
- def __init__(self, logger_name="fs", lock=False):
- """
- Constructor of FsBase
- :param logger_name: logging name
- :param lock: Used to protect simultaneous access to the same instance class by several threads:
- False, None: Do not protect, this object will only be accessed by one thread
- True: This object needs to be protected by several threads accessing.
- Lock object. Use thi Lock for the threads access protection
- """
- self.logger = logging.getLogger(logger_name)
- if not lock:
- self.lock = FakeLock()
- elif lock is True:
- self.lock = Lock()
- elif isinstance(lock, Lock):
- self.lock = lock
- else:
- raise ValueError("lock parameter must be a Lock class or boolean")
-
- def get_params(self):
- return {}
-
- def fs_connect(self, config):
- pass
-
- def fs_disconnect(self):
- pass
-
- def mkdir(self, folder):
- raise FsException("Method 'mkdir' not implemented")
-
- def dir_rename(self, src, dst):
- raise FsException("Method 'dir_rename' not implemented")
-
- def dir_ls(self, storage):
- raise FsException("Method 'dir_ls' not implemented")
-
- def file_exists(self, storage):
- raise FsException("Method 'file_exists' not implemented")
-
- def file_size(self, storage):
- raise FsException("Method 'file_size' not implemented")
-
- def file_extract(self, tar_object, path):
- raise FsException("Method 'file_extract' not implemented")
-
- def file_open(self, storage, mode):
- raise FsException("Method 'file_open' not implemented")
-
- def file_delete(self, storage, ignore_non_exist=False):
- raise FsException("Method 'file_delete' not implemented")
-
- def sync(self, from_path=None):
- raise FsException("Method 'sync' not implemented")
-
- def reverse_sync(self, from_path):
- raise FsException("Method 'reverse_sync' not implemented")
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from http import HTTPStatus
-import logging
-import os
-from shutil import rmtree
-import tarfile
-import zipfile
-
-from osm_common.fsbase import FsBase, FsException
-
-__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
-
-
-class FsLocal(FsBase):
- def __init__(self, logger_name="fs", lock=False):
- super().__init__(logger_name, lock)
- self.path = None
-
- def get_params(self):
- return {"fs": "local", "path": self.path}
-
- def fs_connect(self, config):
- try:
- if "logger_name" in config:
- self.logger = logging.getLogger(config["logger_name"])
- self.path = config["path"]
- if not self.path.endswith("/"):
- self.path += "/"
- if not os.path.exists(self.path):
- raise FsException(
- "Invalid configuration param at '[storage]': path '{}' does not exist".format(
- config["path"]
- )
- )
- except FsException:
- raise
- except Exception as e: # TODO refine
- raise FsException(str(e))
-
- def fs_disconnect(self):
- pass # TODO
-
- def mkdir(self, folder):
- """
- Creates a folder or parent object location
- :param folder:
- :return: None or raises and exception
- """
- try:
- os.mkdir(self.path + folder)
- except FileExistsError: # make it idempotent
- pass
- except Exception as e:
- raise FsException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
- def dir_rename(self, src, dst):
- """
- Rename one directory name. If dst exist, it replaces (deletes) existing directory
- :param src: source directory
- :param dst: destination directory
- :return: None or raises and exception
- """
- try:
- if os.path.exists(self.path + dst):
- rmtree(self.path + dst)
-
- os.rename(self.path + src, self.path + dst)
-
- except Exception as e:
- raise FsException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
- def file_exists(self, storage, mode=None):
- """
- Indicates if "storage" file exist
- :param storage: can be a str or a str list
- :param mode: can be 'file' exist as a regular file; 'dir' exists as a directory or; 'None' just exists
- :return: True, False
- """
- if isinstance(storage, str):
- f = storage
- else:
- f = "/".join(storage)
- if os.path.exists(self.path + f):
- if not mode:
- return True
- if mode == "file" and os.path.isfile(self.path + f):
- return True
- if mode == "dir" and os.path.isdir(self.path + f):
- return True
- return False
-
- def file_size(self, storage):
- """
- return file size
- :param storage: can be a str or a str list
- :return: file size
- """
- if isinstance(storage, str):
- f = storage
- else:
- f = "/".join(storage)
- return os.path.getsize(self.path + f)
-
- def file_extract(self, compressed_object, path):
- """
- extract a tar file
- :param compressed_object: object of type tar or zip
- :param path: can be a str or a str list, or a tar object where to extract the tar_object
- :return: None
- """
- if isinstance(path, str):
- f = self.path + path
- else:
- f = self.path + "/".join(path)
-
- if type(compressed_object) is tarfile.TarFile:
- compressed_object.extractall(path=f)
- elif (
- type(compressed_object) is zipfile.ZipFile
- ): # Just a check to know if this works with both tar and zip
- compressed_object.extractall(path=f)
-
- def file_open(self, storage, mode):
- """
- Open a file
- :param storage: can be a str or list of str
- :param mode: file mode
- :return: file object
- """
- try:
- if isinstance(storage, str):
- f = storage
- else:
- f = "/".join(storage)
- return open(self.path + f, mode)
- except FileNotFoundError:
- raise FsException(
- "File {} does not exist".format(f), http_code=HTTPStatus.NOT_FOUND
- )
- except IOError:
- raise FsException(
- "File {} cannot be opened".format(f), http_code=HTTPStatus.BAD_REQUEST
- )
-
- def dir_ls(self, storage):
- """
- return folder content
- :param storage: can be a str or list of str
- :return: folder content
- """
- try:
- if isinstance(storage, str):
- f = storage
- else:
- f = "/".join(storage)
- return os.listdir(self.path + f)
- except NotADirectoryError:
- raise FsException(
- "File {} does not exist".format(f), http_code=HTTPStatus.NOT_FOUND
- )
- except IOError:
- raise FsException(
- "File {} cannot be opened".format(f), http_code=HTTPStatus.BAD_REQUEST
- )
-
- def file_delete(self, storage, ignore_non_exist=False):
- """
- Delete storage content recursively
- :param storage: can be a str or list of str
- :param ignore_non_exist: not raise exception if storage does not exist
- :return: None
- """
- try:
- if isinstance(storage, str):
- f = self.path + storage
- else:
- f = self.path + "/".join(storage)
- if os.path.exists(f):
- rmtree(f)
- elif not ignore_non_exist:
- raise FsException(
- "File {} does not exist".format(storage),
- http_code=HTTPStatus.NOT_FOUND,
- )
- except (IOError, PermissionError) as e:
- raise FsException(
- "File {} cannot be deleted: {}".format(f, e),
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
-
- def sync(self, from_path=None):
- pass # Not needed in fslocal
-
- def reverse_sync(self, from_path):
- pass # Not needed in fslocal
+++ /dev/null
-# Copyright 2019 Canonical
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: eduardo.sousa@canonical.com
-##
-import datetime
-import errno
-from http import HTTPStatus
-from io import BytesIO, StringIO
-import logging
-import os
-import tarfile
-import zipfile
-
-from gridfs import errors, GridFSBucket
-from osm_common.fsbase import FsBase, FsException
-from pymongo import MongoClient
-
-
-__author__ = "Eduardo Sousa <eduardo.sousa@canonical.com>"
-
-
-class GridByteStream(BytesIO):
- def __init__(self, filename, fs, mode):
- BytesIO.__init__(self)
- self._id = None
- self.filename = filename
- self.fs = fs
- self.mode = mode
- self.file_type = "file" # Set "file" as default file_type
-
- self.__initialize__()
-
- def __initialize__(self):
- grid_file = None
-
- cursor = self.fs.find({"filename": self.filename})
-
- for requested_file in cursor:
- exception_file = next(cursor, None)
-
- if exception_file:
- raise FsException(
- "Multiple files found", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
- )
-
- if requested_file.metadata["type"] in ("file", "sym"):
- grid_file = requested_file
- self.file_type = requested_file.metadata["type"]
- else:
- raise FsException(
- "Type isn't file", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
- )
-
- if grid_file:
- self._id = grid_file._id
- self.fs.download_to_stream(self._id, self)
-
- if "r" in self.mode:
- self.seek(0, 0)
-
- def close(self):
- if "r" in self.mode:
- super(GridByteStream, self).close()
- return
-
- if self._id:
- self.fs.delete(self._id)
-
- cursor = self.fs.find(
- {"filename": self.filename.split("/")[0], "metadata": {"type": "dir"}}
- )
-
- parent_dir = next(cursor, None)
-
- if not parent_dir:
- parent_dir_name = self.filename.split("/")[0]
- self.filename = self.filename.replace(
- parent_dir_name, parent_dir_name[:-1], 1
- )
-
- self.seek(0, 0)
- if self._id:
- self.fs.upload_from_stream_with_id(
- self._id, self.filename, self, metadata={"type": self.file_type}
- )
- else:
- self.fs.upload_from_stream(
- self.filename, self, metadata={"type": self.file_type}
- )
- super(GridByteStream, self).close()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.close()
-
-
-class GridStringStream(StringIO):
- def __init__(self, filename, fs, mode):
- StringIO.__init__(self)
- self._id = None
- self.filename = filename
- self.fs = fs
- self.mode = mode
- self.file_type = "file" # Set "file" as default file_type
-
- self.__initialize__()
-
- def __initialize__(self):
- grid_file = None
-
- cursor = self.fs.find({"filename": self.filename})
-
- for requested_file in cursor:
- exception_file = next(cursor, None)
-
- if exception_file:
- raise FsException(
- "Multiple files found", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
- )
-
- if requested_file.metadata["type"] in ("file", "dir"):
- grid_file = requested_file
- self.file_type = requested_file.metadata["type"]
- else:
- raise FsException(
- "File type isn't file", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
- )
-
- if grid_file:
- stream = BytesIO()
- self._id = grid_file._id
- self.fs.download_to_stream(self._id, stream)
- stream.seek(0)
- self.write(stream.read().decode("utf-8"))
- stream.close()
-
- if "r" in self.mode:
- self.seek(0, 0)
-
- def close(self):
- if "r" in self.mode:
- super(GridStringStream, self).close()
- return
-
- if self._id:
- self.fs.delete(self._id)
-
- cursor = self.fs.find(
- {"filename": self.filename.split("/")[0], "metadata": {"type": "dir"}}
- )
-
- parent_dir = next(cursor, None)
-
- if not parent_dir:
- parent_dir_name = self.filename.split("/")[0]
- self.filename = self.filename.replace(
- parent_dir_name, parent_dir_name[:-1], 1
- )
-
- self.seek(0, 0)
- stream = BytesIO()
- stream.write(self.read().encode("utf-8"))
- stream.seek(0, 0)
- if self._id:
- self.fs.upload_from_stream_with_id(
- self._id, self.filename, stream, metadata={"type": self.file_type}
- )
- else:
- self.fs.upload_from_stream(
- self.filename, stream, metadata={"type": self.file_type}
- )
- stream.close()
- super(GridStringStream, self).close()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.close()
-
-
-class FsMongo(FsBase):
- def __init__(self, logger_name="fs", lock=False):
- super().__init__(logger_name, lock)
- self.path = None
- self.client = None
- self.fs = None
-
- def __update_local_fs(self, from_path=None):
- dir_cursor = self.fs.find({"metadata.type": "dir"}, no_cursor_timeout=True)
-
- valid_paths = []
-
- for directory in dir_cursor:
- if from_path and not directory.filename.startswith(from_path):
- continue
- self.logger.debug("Making dir {}".format(self.path + directory.filename))
- os.makedirs(self.path + directory.filename, exist_ok=True)
- valid_paths.append(self.path + directory.filename)
-
- file_cursor = self.fs.find(
- {"metadata.type": {"$in": ["file", "sym"]}}, no_cursor_timeout=True
- )
-
- for writing_file in file_cursor:
- if from_path and not writing_file.filename.startswith(from_path):
- continue
- file_path = self.path + writing_file.filename
-
- if writing_file.metadata["type"] == "sym":
- with BytesIO() as b:
- self.fs.download_to_stream(writing_file._id, b)
- b.seek(0)
- link = b.read().decode("utf-8")
-
- try:
- self.logger.debug("Sync removing {}".format(file_path))
- os.remove(file_path)
- except OSError as e:
- if e.errno != errno.ENOENT:
- # This is probably permission denied or worse
- raise
- os.symlink(
- link, os.path.realpath(os.path.normpath(os.path.abspath(file_path)))
- )
- else:
- folder = os.path.dirname(file_path)
- if folder not in valid_paths:
- self.logger.debug("Sync local directory {}".format(file_path))
- os.makedirs(folder, exist_ok=True)
- with open(file_path, "wb+") as file_stream:
- self.logger.debug("Sync download {}".format(file_path))
- self.fs.download_to_stream(writing_file._id, file_stream)
- if "permissions" in writing_file.metadata:
- os.chmod(file_path, writing_file.metadata["permissions"])
-
- def get_params(self):
- return {"fs": "mongo", "path": self.path}
-
- def fs_connect(self, config):
- try:
- if "logger_name" in config:
- self.logger = logging.getLogger(config["logger_name"])
- if "path" in config:
- self.path = config["path"]
- else:
- raise FsException('Missing parameter "path"')
- if not self.path.endswith("/"):
- self.path += "/"
- if not os.path.exists(self.path):
- raise FsException(
- "Invalid configuration param at '[storage]': path '{}' does not exist".format(
- config["path"]
- )
- )
- elif not os.access(self.path, os.W_OK):
- raise FsException(
- "Invalid configuration param at '[storage]': path '{}' is not writable".format(
- config["path"]
- )
- )
- if all(key in config.keys() for key in ["uri", "collection"]):
- self.client = MongoClient(config["uri"])
- self.fs = GridFSBucket(self.client[config["collection"]])
- else:
- if "collection" not in config.keys():
- raise FsException('Missing parameter "collection"')
- else:
- raise FsException('Missing parameters: "uri"')
- except FsException:
- raise
- except Exception as e: # TODO refine
- raise FsException(str(e))
-
- def fs_disconnect(self):
- pass # TODO
-
- def mkdir(self, folder):
- """
- Creates a folder or parent object location
- :param folder:
- :return: None or raises an exception
- """
- folder = folder.rstrip("/")
- try:
- self.fs.upload_from_stream(folder, BytesIO(), metadata={"type": "dir"})
- except errors.FileExists: # make it idempotent
- pass
- except Exception as e:
- raise FsException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
- def dir_rename(self, src, dst):
- """
- Rename one directory name. If dst exist, it replaces (deletes) existing directory
- :param src: source directory
- :param dst: destination directory
- :return: None or raises and exception
- """
- dst = dst.rstrip("/")
- src = src.rstrip("/")
-
- try:
- dst_cursor = self.fs.find(
- {"filename": {"$regex": "^{}(/|$)".format(dst)}}, no_cursor_timeout=True
- )
-
- for dst_file in dst_cursor:
- self.fs.delete(dst_file._id)
-
- src_cursor = self.fs.find(
- {"filename": {"$regex": "^{}(/|$)".format(src)}}, no_cursor_timeout=True
- )
-
- for src_file in src_cursor:
- self.fs.rename(src_file._id, src_file.filename.replace(src, dst, 1))
- except Exception as e:
- raise FsException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
- def file_exists(self, storage, mode=None):
- """
- Indicates if "storage" file exist
- :param storage: can be a str or a str list
- :param mode: can be 'file' exist as a regular file; 'dir' exists as a directory or; 'None' just exists
- :return: True, False
- """
- f = storage if isinstance(storage, str) else "/".join(storage)
- f = f.rstrip("/")
-
- cursor = self.fs.find({"filename": f})
-
- for requested_file in cursor:
- exception_file = next(cursor, None)
-
- if exception_file:
- raise FsException(
- "Multiple files found", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
- )
-
- self.logger.debug("Entry {} metadata {}".format(f, requested_file.metadata))
-
- # if no special mode is required just check it does exists
- if not mode:
- return True
-
- if requested_file.metadata["type"] == mode:
- return True
-
- if requested_file.metadata["type"] == "sym" and mode == "file":
- return True
-
- return False
-
- def file_size(self, storage):
- """
- return file size
- :param storage: can be a str or a str list
- :return: file size
- """
- f = storage if isinstance(storage, str) else "/".join(storage)
- f = f.rstrip("/")
-
- cursor = self.fs.find({"filename": f})
-
- for requested_file in cursor:
- exception_file = next(cursor, None)
-
- if exception_file:
- raise FsException(
- "Multiple files found", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
- )
-
- return requested_file.length
-
- def file_extract(self, compressed_object, path):
- """
- extract a tar file
- :param compressed_object: object of type tar or zip
- :param path: can be a str or a str list, or a tar object where to extract the tar_object
- :return: None
- """
- f = path if isinstance(path, str) else "/".join(path)
- f = f.rstrip("/")
-
- if type(compressed_object) is tarfile.TarFile:
- for member in compressed_object.getmembers():
- if member.isfile():
- stream = compressed_object.extractfile(member)
- elif member.issym():
- stream = BytesIO(member.linkname.encode("utf-8"))
- else:
- stream = BytesIO()
-
- if member.isfile():
- file_type = "file"
- elif member.issym():
- file_type = "sym"
- else:
- file_type = "dir"
-
- metadata = {"type": file_type, "permissions": member.mode}
- member.name = member.name.rstrip("/")
-
- self.logger.debug("Uploading {}/{}".format(f, member.name))
- self.fs.upload_from_stream(
- f + "/" + member.name, stream, metadata=metadata
- )
-
- stream.close()
- elif type(compressed_object) is zipfile.ZipFile:
- for member in compressed_object.infolist():
- if member.is_dir():
- stream = BytesIO()
- else:
- stream = compressed_object.read(member)
-
- if member.is_dir():
- file_type = "dir"
- else:
- file_type = "file"
-
- metadata = {"type": file_type}
- member.filename = member.filename.rstrip("/")
-
- self.logger.debug("Uploading {}/{}".format(f, member.filename))
- self.fs.upload_from_stream(
- f + "/" + member.filename, stream, metadata=metadata
- )
-
- if member.is_dir():
- stream.close()
-
- def file_open(self, storage, mode):
- """
- Open a file
- :param storage: can be a str or list of str
- :param mode: file mode
- :return: file object
- """
- try:
- f = storage if isinstance(storage, str) else "/".join(storage)
- f = f.rstrip("/")
-
- if "b" in mode:
- return GridByteStream(f, self.fs, mode)
- else:
- return GridStringStream(f, self.fs, mode)
- except errors.NoFile:
- raise FsException(
- "File {} does not exist".format(f), http_code=HTTPStatus.NOT_FOUND
- )
- except IOError:
- raise FsException(
- "File {} cannot be opened".format(f), http_code=HTTPStatus.BAD_REQUEST
- )
-
- def dir_ls(self, storage):
- """
- return folder content
- :param storage: can be a str or list of str
- :return: folder content
- """
- try:
- f = storage if isinstance(storage, str) else "/".join(storage)
- f = f.rstrip("/")
-
- files = []
- dir_cursor = self.fs.find({"filename": f})
- for requested_dir in dir_cursor:
- exception_dir = next(dir_cursor, None)
-
- if exception_dir:
- raise FsException(
- "Multiple directories found",
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
-
- if requested_dir.metadata["type"] != "dir":
- raise FsException(
- "File {} does not exist".format(f),
- http_code=HTTPStatus.NOT_FOUND,
- )
-
- if f.endswith("/"):
- f = f[:-1]
-
- files_cursor = self.fs.find(
- {"filename": {"$regex": "^{}/([^/])*".format(f)}}
- )
- for children_file in files_cursor:
- files += [children_file.filename.replace(f + "/", "", 1)]
-
- return files
- except IOError:
- raise FsException(
- "File {} cannot be opened".format(f), http_code=HTTPStatus.BAD_REQUEST
- )
-
- def file_delete(self, storage, ignore_non_exist=False):
- """
- Delete storage content recursively
- :param storage: can be a str or list of str
- :param ignore_non_exist: not raise exception if storage does not exist
- :return: None
- """
- try:
- f = storage if isinstance(storage, str) else "/".join(storage)
- f = f.rstrip("/")
-
- file_cursor = self.fs.find({"filename": f})
- found = False
- for requested_file in file_cursor:
- found = True
- exception_file = next(file_cursor, None)
-
- if exception_file:
- self.logger.error(
- "Cannot delete duplicate file: {} and {}".format(
- requested_file.filename, exception_file.filename
- )
- )
- raise FsException(
- "Multiple files found",
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
-
- if requested_file.metadata["type"] == "dir":
- dir_cursor = self.fs.find(
- {"filename": {"$regex": "^{}/".format(f)}}
- )
-
- for tmp in dir_cursor:
- self.logger.debug("Deleting {}".format(tmp.filename))
- self.fs.delete(tmp._id)
-
- self.logger.debug("Deleting {}".format(requested_file.filename))
- self.fs.delete(requested_file._id)
- if not found and not ignore_non_exist:
- raise FsException(
- "File {} does not exist".format(storage),
- http_code=HTTPStatus.NOT_FOUND,
- )
- except IOError as e:
- raise FsException(
- "File {} cannot be deleted: {}".format(f, e),
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
-
- def sync(self, from_path=None):
- """
- Sync from FSMongo to local storage
- :param from_path: if supplied, only copy content from this path, not all
- :return: None
- """
- if from_path:
- if os.path.isabs(from_path):
- from_path = os.path.relpath(from_path, self.path)
- self.__update_local_fs(from_path=from_path)
-
- def _update_mongo_fs(self, from_path):
- os_path = self.path + from_path
- # Obtain list of files and dirs in filesystem
- members = []
- for root, dirs, files in os.walk(os_path):
- for folder in dirs:
- member = {"filename": os.path.join(root, folder), "type": "dir"}
- if os.path.islink(member["filename"]):
- member["type"] = "sym"
- members.append(member)
- for file in files:
- filename = os.path.join(root, file)
- if os.path.islink(filename):
- file_type = "sym"
- else:
- file_type = "file"
- member = {"filename": os.path.join(root, file), "type": file_type}
- members.append(member)
-
- # Obtain files in mongo dict
- remote_files = self._get_mongo_files(from_path)
-
- # Upload members if they do not exists or have been modified
- # We will do this for performance (avoid updating unmodified files) and to avoid
- # updating a file with an older one in case there are two sources for synchronization
- # in high availability scenarios
- for member in members:
- # obtain permission
- mask = int(oct(os.stat(member["filename"]).st_mode)[-3:], 8)
-
- # convert to relative path
- rel_filename = os.path.relpath(member["filename"], self.path)
- # get timestamp in UTC because mongo stores upload date in UTC:
- # https://www.mongodb.com/docs/v4.0/tutorial/model-time-data/#overview
- last_modified_date = datetime.datetime.utcfromtimestamp(
- os.path.getmtime(member["filename"])
- )
-
- remote_file = remote_files.get(rel_filename)
- upload_date = (
- remote_file[0].uploadDate if remote_file else datetime.datetime.min
- )
- # remove processed files from dict
- remote_files.pop(rel_filename, None)
-
- if last_modified_date >= upload_date:
- stream = None
- fh = None
- try:
- file_type = member["type"]
- if file_type == "dir":
- stream = BytesIO()
- elif file_type == "sym":
- stream = BytesIO(
- os.readlink(member["filename"]).encode("utf-8")
- )
- else:
- fh = open(member["filename"], "rb")
- stream = BytesIO(fh.read())
-
- metadata = {"type": file_type, "permissions": mask}
-
- self.logger.debug("Sync upload {}".format(rel_filename))
- self.fs.upload_from_stream(rel_filename, stream, metadata=metadata)
-
- # delete old files
- if remote_file:
- for file in remote_file:
- self.logger.debug("Sync deleting {}".format(file.filename))
- self.fs.delete(file._id)
- finally:
- if fh:
- fh.close()
- if stream:
- stream.close()
-
- # delete files that are not anymore in local fs
- for remote_file in remote_files.values():
- for file in remote_file:
- self.fs.delete(file._id)
-
- def _get_mongo_files(self, from_path=None):
- file_dict = {}
- file_cursor = self.fs.find(no_cursor_timeout=True, sort=[("uploadDate", -1)])
- for file in file_cursor:
- if from_path and not file.filename.startswith(from_path):
- continue
- if file.filename in file_dict:
- file_dict[file.filename].append(file)
- else:
- file_dict[file.filename] = [file]
- return file_dict
-
- def reverse_sync(self, from_path: str):
- """
- Sync from local storage to FSMongo
- :param from_path: base directory to upload content to mongo fs
- :return: None
- """
- if os.path.isabs(from_path):
- from_path = os.path.relpath(from_path, self.path)
- self._update_mongo_fs(from_path=from_path)
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from http import HTTPStatus
-import logging
-from threading import Lock
-
-from osm_common.common_utils import FakeLock
-
-__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
-
-
-class MsgException(Exception):
- """
- Base Exception class for all msgXXXX exceptions
- """
-
- def __init__(self, message, http_code=HTTPStatus.SERVICE_UNAVAILABLE):
- """
- General exception
- :param message: descriptive text
- :param http_code: <http.HTTPStatus> type. It contains ".value" (http error code) and ".name" (http error name
- """
- self.http_code = http_code
- Exception.__init__(self, "messaging exception " + message)
-
-
-class MsgBase(object):
- """
- Base class for all msgXXXX classes
- """
-
- def __init__(self, logger_name="msg", lock=False):
- """
- Constructor of FsBase
- :param logger_name: logging name
- :param lock: Used to protect simultaneous access to the same instance class by several threads:
- False, None: Do not protect, this object will only be accessed by one thread
- True: This object needs to be protected by several threads accessing.
- Lock object. Use thi Lock for the threads access protection
- """
- self.logger = logging.getLogger(logger_name)
- if not lock:
- self.lock = FakeLock()
- elif lock is True:
- self.lock = Lock()
- elif isinstance(lock, Lock):
- self.lock = lock
- else:
- raise ValueError("lock parameter must be a Lock class or boolean")
-
- def connect(self, config):
- pass
-
- def disconnect(self):
- pass
-
- def write(self, topic, key, msg):
- raise MsgException(
- "Method 'write' not implemented", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
- )
-
- def read(self, topic):
- raise MsgException(
- "Method 'read' not implemented", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
- )
-
- async def aiowrite(self, topic, key, msg):
- raise MsgException(
- "Method 'aiowrite' not implemented",
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
-
- async def aioread(
- self, topic, callback=None, aiocallback=None, group_id=None, **kwargs
- ):
- raise MsgException(
- "Method 'aioread' not implemented",
- http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
- )
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-import logging
-
-from aiokafka import AIOKafkaConsumer
-from aiokafka import AIOKafkaProducer
-from aiokafka.errors import KafkaError
-from osm_common.msgbase import MsgBase, MsgException
-import yaml
-
-__author__ = (
- "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>, "
- "Guillermo Calvino <guillermo.calvinosanchez@altran.com>"
-)
-
-
-class MsgKafka(MsgBase):
- def __init__(self, logger_name="msg", lock=False):
- super().__init__(logger_name, lock)
- self.host = None
- self.port = None
- self.consumer = None
- self.producer = None
- self.broker = None
- self.group_id = None
-
- def connect(self, config):
- try:
- if "logger_name" in config:
- self.logger = logging.getLogger(config["logger_name"])
- self.host = config["host"]
- self.port = config["port"]
- self.broker = str(self.host) + ":" + str(self.port)
- self.group_id = config.get("group_id")
-
- except Exception as e: # TODO refine
- raise MsgException(str(e))
-
- def disconnect(self):
- try:
- pass
- except Exception as e: # TODO refine
- raise MsgException(str(e))
-
- def write(self, topic, key, msg):
- """
- Write a message at kafka bus
- :param topic: message topic, must be string
- :param key: message key, must be string
- :param msg: message content, can be string or dictionary
- :return: None or raises MsgException on failing
- """
- retry = 2 # Try two times
- while retry:
- try:
- asyncio.run(self.aiowrite(topic=topic, key=key, msg=msg))
- break
- except Exception as e:
- retry -= 1
- if retry == 0:
- raise MsgException(
- "Error writing {} topic: {}".format(topic, str(e))
- )
-
- def read(self, topic):
- """
- Read from one or several topics.
- :param topic: can be str: single topic; or str list: several topics
- :return: topic, key, message; or None
- """
- try:
- return asyncio.run(self.aioread(topic))
- except MsgException:
- raise
- except Exception as e:
- raise MsgException("Error reading {} topic: {}".format(topic, str(e)))
-
- async def aiowrite(self, topic, key, msg):
- """
- Asyncio write
- :param topic: str kafka topic
- :param key: str kafka key
- :param msg: str or dictionary kafka message
- :return: None
- """
- try:
- self.producer = AIOKafkaProducer(
- key_serializer=str.encode,
- value_serializer=str.encode,
- bootstrap_servers=self.broker,
- )
- await self.producer.start()
- await self.producer.send(
- topic=topic, key=key, value=yaml.safe_dump(msg, default_flow_style=True)
- )
- except Exception as e:
- raise MsgException(
- "Error publishing topic '{}', key '{}': {}".format(topic, key, e)
- )
- finally:
- await self.producer.stop()
-
- async def aioread(
- self,
- topic,
- callback=None,
- aiocallback=None,
- group_id=None,
- from_beginning=None,
- **kwargs
- ):
- """
- Asyncio read from one or several topics.
- :param topic: can be str: single topic; or str list: several topics
- :param callback: synchronous callback function that will handle the message in kafka bus
- :param aiocallback: async callback function that will handle the message in kafka bus
- :param group_id: kafka group_id to use. Can be False (set group_id to None), None (use general group_id provided
- at connect inside config), or a group_id string
- :param from_beginning: if True, messages will be obtained from beginning instead of only new ones.
- If group_id is supplied, only the not processed messages by other worker are obtained.
- If group_id is None, all messages stored at kafka are obtained.
- :param kwargs: optional keyword arguments for callback function
- :return: If no callback defined, it returns (topic, key, message)
- """
- if group_id is False:
- group_id = None
- elif group_id is None:
- group_id = self.group_id
- try:
- if isinstance(topic, (list, tuple)):
- topic_list = topic
- else:
- topic_list = (topic,)
- self.consumer = AIOKafkaConsumer(
- bootstrap_servers=self.broker,
- group_id=group_id,
- auto_offset_reset="earliest" if from_beginning else "latest",
- )
- await self.consumer.start()
- self.consumer.subscribe(topic_list)
-
- async for message in self.consumer:
- if callback:
- callback(
- message.topic,
- yaml.safe_load(message.key),
- yaml.safe_load(message.value),
- **kwargs
- )
- elif aiocallback:
- await aiocallback(
- message.topic,
- yaml.safe_load(message.key),
- yaml.safe_load(message.value),
- **kwargs
- )
- else:
- return (
- message.topic,
- yaml.safe_load(message.key),
- yaml.safe_load(message.value),
- )
- except KafkaError as e:
- raise MsgException(str(e))
- finally:
- await self.consumer.stop()
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import asyncio
-from http import HTTPStatus
-import logging
-import os
-from time import sleep
-
-from osm_common.msgbase import MsgBase, MsgException
-import yaml
-
-__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
-"""
-This emulated kafka bus by just using a shared file system. Useful for testing or devops.
-One file is used per topic. Only one producer and one consumer is allowed per topic. Both consumer and producer
-access to the same file. e.g. same volume if running with docker.
-One text line per message is used in yaml format.
-"""
-
-
-class MsgLocal(MsgBase):
- def __init__(self, logger_name="msg", lock=False):
- super().__init__(logger_name, lock)
- self.path = None
- # create a different file for each topic
- self.files_read = {}
- self.files_write = {}
- self.buffer = {}
-
- def connect(self, config):
- try:
- if "logger_name" in config:
- self.logger = logging.getLogger(config["logger_name"])
- self.path = config["path"]
- if not self.path.endswith("/"):
- self.path += "/"
- if not os.path.exists(self.path):
- os.mkdir(self.path)
-
- except MsgException:
- raise
- except Exception as e: # TODO refine
- raise MsgException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
- def disconnect(self):
- for topic, f in self.files_read.items():
- try:
- f.close()
- self.files_read[topic] = None
- except Exception as read_topic_error:
- if isinstance(read_topic_error, (IOError, FileNotFoundError)):
- self.logger.exception(
- f"{read_topic_error} occured while closing read topic files."
- )
- elif isinstance(read_topic_error, KeyError):
- self.logger.exception(
- f"{read_topic_error} occured while reading from files_read dictionary."
- )
- else:
- self.logger.exception(
- f"{read_topic_error} occured while closing read topics."
- )
-
- for topic, f in self.files_write.items():
- try:
- f.close()
- self.files_write[topic] = None
- except Exception as write_topic_error:
- if isinstance(write_topic_error, (IOError, FileNotFoundError)):
- self.logger.exception(
- f"{write_topic_error} occured while closing write topic files."
- )
- elif isinstance(write_topic_error, KeyError):
- self.logger.exception(
- f"{write_topic_error} occured while reading from files_write dictionary."
- )
- else:
- self.logger.exception(
- f"{write_topic_error} occured while closing write topics."
- )
-
- def write(self, topic, key, msg):
- """
- Insert a message into topic
- :param topic: topic
- :param key: key text to be inserted
- :param msg: value object to be inserted, can be str, object ...
- :return: None or raises and exception
- """
- try:
- with self.lock:
- if topic not in self.files_write:
- self.files_write[topic] = open(self.path + topic, "a+")
- yaml.safe_dump(
- {key: msg},
- self.files_write[topic],
- default_flow_style=True,
- width=20000,
- )
- self.files_write[topic].flush()
- except Exception as e: # TODO refine
- raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
-
- def read(self, topic, blocks=True):
- """
- Read from one or several topics. it is non blocking returning None if nothing is available
- :param topic: can be str: single topic; or str list: several topics
- :param blocks: indicates if it should wait and block until a message is present or returns None
- :return: topic, key, message; or None if blocks==True
- """
- try:
- if isinstance(topic, (list, tuple)):
- topic_list = topic
- else:
- topic_list = (topic,)
- while True:
- for single_topic in topic_list:
- with self.lock:
- if single_topic not in self.files_read:
- self.files_read[single_topic] = open(
- self.path + single_topic, "a+"
- )
- self.buffer[single_topic] = ""
- self.buffer[single_topic] += self.files_read[
- single_topic
- ].readline()
- if not self.buffer[single_topic].endswith("\n"):
- continue
- msg_dict = yaml.safe_load(self.buffer[single_topic])
- self.buffer[single_topic] = ""
- if len(msg_dict) != 1:
- raise ValueError(
- "Length of message dictionary is not equal to 1"
- )
- for k, v in msg_dict.items():
- return single_topic, k, v
- if not blocks:
- return None
- sleep(2)
- except Exception as e: # TODO refine
- raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
-
- async def aioread(
- self, topic, callback=None, aiocallback=None, group_id=None, **kwargs
- ):
- """
- Asyncio read from one or several topics. It blocks
- :param topic: can be str: single topic; or str list: several topics
- :param callback: synchronous callback function that will handle the message
- :param aiocallback: async callback function that will handle the message
- :param group_id: group_id to use for load balancing. Can be False (set group_id to None), None (use general
- group_id provided at connect inside config), or a group_id string
- :param kwargs: optional keyword arguments for callback function
- :return: If no callback defined, it returns (topic, key, message)
- """
- try:
- while True:
- msg = self.read(topic, blocks=False)
- if msg:
- if callback:
- callback(*msg, **kwargs)
- elif aiocallback:
- await aiocallback(*msg, **kwargs)
- else:
- return msg
- await asyncio.sleep(2)
- except MsgException:
- raise
- except Exception as e: # TODO refine
- raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)
-
- async def aiowrite(self, topic, key, msg):
- """
- Asyncio write. It blocks
- :param topic: str
- :param key: str
- :param msg: message, can be str or yaml
- :return: nothing if ok or raises an exception
- """
- return self.write(topic, key, msg)
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2021 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: agarcia@whitestack.com or fbravo@whitestack.com
-##
-
-"""Python module for interacting with ETSI GS NFV-SOL004 compliant packages
-
-This module provides a SOL004Package class for validating and interacting with
-ETSI SOL004 packages. A valid SOL004 package may have its files arranged according
-to one of the following two structures:
-
-SOL004 with metadata directory SOL004 without metadata directory
-
-native_charm_vnf/ native_charm_vnf/
-├── TOSCA-Metadata ├── native_charm_vnfd.mf
-│ └── TOSCA.meta ├── native_charm_vnfd.yaml
-├── manifest.mf ├── ChangeLog.txt
-├── Definitions ├── Licenses
-│ └── native_charm_vnfd.yaml │ └── license.lic
-├── Files ├── Files
-│ ├── icons │ └── icons
-│ │ └── osm.png │ └── osm.png
-│ ├── Licenses └── Scripts
-│ │ └── license.lic ├── cloud_init
-│ └── changelog.txt │ └── cloud-config.txt
-└── Scripts └── charms
- ├── cloud_init └── simple
- │ └── cloud-config.txt ├── config.yaml
- └── charms ├── hooks
- └── simple │ ├── install
- ├── config.yaml ...
- ├── hooks │
- │ ├── install └── src
- ... └── charm.py
- └── src
- └── charm.py
-"""
-
-import datetime
-import os
-
-import yaml
-
-from .sol_package import SOLPackage
-
-
-class SOL004PackageException(Exception):
- pass
-
-
-class SOL004Package(SOLPackage):
- _MANIFEST_VNFD_ID = "vnfd_id"
- _MANIFEST_VNFD_PRODUCT_NAME = "vnfd_product_name"
- _MANIFEST_VNFD_PROVIDER_ID = "vnfd_provider_id"
- _MANIFEST_VNFD_SOFTWARE_VERSION = "vnfd_software_version"
- _MANIFEST_VNFD_PACKAGE_VERSION = "vnfd_package_version"
- _MANIFEST_VNFD_RELEASE_DATE_TIME = "vnfd_release_date_time"
- _MANIFEST_VNFD_COMPATIBLE_SPECIFICATION_VERSIONS = (
- "compatible_specification_versions"
- )
- _MANIFEST_VNFM_INFO = "vnfm_info"
-
- _MANIFEST_ALL_FIELDS = [
- _MANIFEST_VNFD_ID,
- _MANIFEST_VNFD_PRODUCT_NAME,
- _MANIFEST_VNFD_PROVIDER_ID,
- _MANIFEST_VNFD_SOFTWARE_VERSION,
- _MANIFEST_VNFD_PACKAGE_VERSION,
- _MANIFEST_VNFD_RELEASE_DATE_TIME,
- _MANIFEST_VNFD_COMPATIBLE_SPECIFICATION_VERSIONS,
- _MANIFEST_VNFM_INFO,
- ]
-
- def __init__(self, package_path=""):
- super().__init__(package_path)
-
- def generate_manifest_data_from_descriptor(self):
- descriptor_path = os.path.join(
- self._package_path, self.get_descriptor_location()
- )
- with open(descriptor_path, "r") as descriptor:
- try:
- vnfd_data = yaml.safe_load(descriptor)["vnfd"]
- except yaml.YAMLError as e:
- print("Error reading descriptor {}: {}".format(descriptor_path, e))
- return
-
- self._manifest_metadata = {}
- self._manifest_metadata[self._MANIFEST_VNFD_ID] = vnfd_data.get(
- "id", "default-id"
- )
- self._manifest_metadata[self._MANIFEST_VNFD_PRODUCT_NAME] = vnfd_data.get(
- "product-name", "default-product-name"
- )
- self._manifest_metadata[self._MANIFEST_VNFD_PROVIDER_ID] = vnfd_data.get(
- "provider", "OSM"
- )
- self._manifest_metadata[
- self._MANIFEST_VNFD_SOFTWARE_VERSION
- ] = vnfd_data.get("version", "1.0")
- self._manifest_metadata[self._MANIFEST_VNFD_PACKAGE_VERSION] = "1.0.0"
- self._manifest_metadata[self._MANIFEST_VNFD_RELEASE_DATE_TIME] = (
- datetime.datetime.now().astimezone().isoformat()
- )
- self._manifest_metadata[
- self._MANIFEST_VNFD_COMPATIBLE_SPECIFICATION_VERSIONS
- ] = "2.7.1"
- self._manifest_metadata[self._MANIFEST_VNFM_INFO] = "OSM"
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2021 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: fbravo@whitestack.com
-##
-
-"""Python module for interacting with ETSI GS NFV-SOL007 compliant packages
-
-This module provides a SOL007Package class for validating and interacting with
-ETSI SOL007 packages. A valid SOL007 package may have its files arranged according
-to one of the following two structures:
-
-SOL007 with metadata directory SOL007 without metadata directory
-
-native_charm_vnf/ native_charm_vnf/
-├── TOSCA-Metadata ├── native_charm_nsd.mf
-│ └── TOSCA.meta ├── native_charm_nsd.yaml
-├── manifest.mf ├── ChangeLog.txt
-├── Definitions ├── Licenses
-│ └── native_charm_nsd.yaml │ └── license.lic
-├── Files ├── Files
-│ ├── icons │ └── icons
-│ │ └── osm.png │ └── osm.png
-│ ├── Licenses └── Scripts
-│ │ └── license.lic ├── cloud_init
-│ └── changelog.txt │ └── cloud-config.txt
-└── Scripts └── charms
- ├── cloud_init └── simple
- │ └── cloud-config.txt ├── config.yaml
- └── charms ├── hooks
- └── simple │ ├── install
- ├── config.yaml ...
- ├── hooks │
- │ ├── install └── src
- ... └── charm.py
- └── src
- └── charm.py
-"""
-
-import datetime
-import os
-
-import yaml
-
-from .sol_package import SOLPackage
-
-
-class SOL007PackageException(Exception):
- pass
-
-
-class SOL007Package(SOLPackage):
- _MANIFEST_NSD_INVARIANT_ID = "nsd_invariant_id"
- _MANIFEST_NSD_NAME = "nsd_name"
- _MANIFEST_NSD_DESIGNER = "nsd_designer"
- _MANIFEST_NSD_FILE_STRUCTURE_VERSION = "nsd_file_structure_version"
- _MANIFEST_NSD_RELEASE_DATE_TIME = "nsd_release_date_time"
- _MANIFEST_NSD_COMPATIBLE_SPECIFICATION_VERSIONS = (
- "compatible_specification_versions"
- )
-
- _MANIFEST_ALL_FIELDS = [
- _MANIFEST_NSD_INVARIANT_ID,
- _MANIFEST_NSD_NAME,
- _MANIFEST_NSD_DESIGNER,
- _MANIFEST_NSD_FILE_STRUCTURE_VERSION,
- _MANIFEST_NSD_RELEASE_DATE_TIME,
- _MANIFEST_NSD_COMPATIBLE_SPECIFICATION_VERSIONS,
- ]
-
- def __init__(self, package_path=""):
- super().__init__(package_path)
-
- def generate_manifest_data_from_descriptor(self):
- descriptor_path = os.path.join(
- self._package_path, self.get_descriptor_location()
- )
- with open(descriptor_path, "r") as descriptor:
- try:
- nsd_data = yaml.safe_load(descriptor)["nsd"]
- except yaml.YAMLError as e:
- print("Error reading descriptor {}: {}".format(descriptor_path, e))
- return
-
- self._manifest_metadata = {}
- self._manifest_metadata[self._MANIFEST_NSD_INVARIANT_ID] = nsd_data.get(
- "id", "default-id"
- )
- self._manifest_metadata[self._MANIFEST_NSD_NAME] = nsd_data.get(
- "name", "default-name"
- )
- self._manifest_metadata[self._MANIFEST_NSD_DESIGNER] = nsd_data.get(
- "designer", "OSM"
- )
- self._manifest_metadata[
- self._MANIFEST_NSD_FILE_STRUCTURE_VERSION
- ] = nsd_data.get("version", "1.0")
- self._manifest_metadata[self._MANIFEST_NSD_RELEASE_DATE_TIME] = (
- datetime.datetime.now().astimezone().isoformat()
- )
- self._manifest_metadata[
- self._MANIFEST_NSD_COMPATIBLE_SPECIFICATION_VERSIONS
- ] = "2.7.1"
+++ /dev/null
-# -*- coding: utf-8 -*-
-
-# Copyright 2021 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: fbravo@whitestack.com or agarcia@whitestack.com
-##
-import hashlib
-import os
-
-import yaml
-
-
-class SOLPackageException(Exception):
- pass
-
-
-class SOLPackage:
- _METADATA_FILE_PATH = "TOSCA-Metadata/TOSCA.meta"
- _METADATA_DESCRIPTOR_FIELD = "Entry-Definitions"
- _METADATA_MANIFEST_FIELD = "ETSI-Entry-Manifest"
- _METADATA_CHANGELOG_FIELD = "ETSI-Entry-Change-Log"
- _METADATA_LICENSES_FIELD = "ETSI-Entry-Licenses"
- _METADATA_DEFAULT_CHANGELOG_PATH = "ChangeLog.txt"
- _METADATA_DEFAULT_LICENSES_PATH = "Licenses"
- _MANIFEST_FILE_PATH_FIELD = "Source"
- _MANIFEST_FILE_HASH_ALGORITHM_FIELD = "Algorithm"
- _MANIFEST_FILE_HASH_DIGEST_FIELD = "Hash"
-
- _MANIFEST_ALL_FIELDS = []
-
- def __init__(self, package_path=""):
- self._package_path = package_path
-
- self._package_metadata = self._parse_package_metadata()
-
- try:
- self._manifest_data = self._parse_manifest_data()
- except Exception:
- self._manifest_data = None
-
- try:
- self._manifest_metadata = self._parse_manifest_metadata()
- except Exception:
- self._manifest_metadata = None
-
- def _parse_package_metadata(self):
- try:
- return self._parse_package_metadata_with_metadata_dir()
- except FileNotFoundError:
- return self._parse_package_metadata_without_metadata_dir()
-
- def _parse_package_metadata_with_metadata_dir(self):
- try:
- return self._parse_file_in_blocks(self._METADATA_FILE_PATH)
- except FileNotFoundError as e:
- raise e
- except (Exception, OSError) as e:
- raise SOLPackageException(
- "Error parsing {}: {}".format(self._METADATA_FILE_PATH, e)
- )
-
- def _parse_package_metadata_without_metadata_dir(self):
- package_root_files = {f for f in os.listdir(self._package_path)}
- package_root_yamls = [
- f for f in package_root_files if f.endswith(".yml") or f.endswith(".yaml")
- ]
- if len(package_root_yamls) != 1:
- error_msg = "Error parsing package metadata: there should be exactly 1 descriptor YAML, found {}"
- raise SOLPackageException(error_msg.format(len(package_root_yamls)))
-
- base_manifest = [
- {
- SOLPackage._METADATA_DESCRIPTOR_FIELD: package_root_yamls[0],
- SOLPackage._METADATA_MANIFEST_FIELD: "{}.mf".format(
- os.path.splitext(package_root_yamls[0])[0]
- ),
- SOLPackage._METADATA_CHANGELOG_FIELD: SOLPackage._METADATA_DEFAULT_CHANGELOG_PATH,
- SOLPackage._METADATA_LICENSES_FIELD: SOLPackage._METADATA_DEFAULT_LICENSES_PATH,
- }
- ]
-
- return base_manifest
-
- def _parse_manifest_data(self):
- manifest_path = None
- for tosca_meta in self._package_metadata:
- if SOLPackage._METADATA_MANIFEST_FIELD in tosca_meta:
- manifest_path = tosca_meta[SOLPackage._METADATA_MANIFEST_FIELD]
- break
- else:
- error_msg = "Error parsing {}: no {} field on path".format(
- self._METADATA_FILE_PATH, self._METADATA_MANIFEST_FIELD
- )
- raise SOLPackageException(error_msg)
-
- try:
- return self._parse_file_in_blocks(manifest_path)
-
- except (Exception, OSError) as e:
- raise SOLPackageException("Error parsing {}: {}".format(manifest_path, e))
-
- def _parse_manifest_metadata(self):
- try:
- base_manifest = {}
- manifest_file = os.open(
- os.path.join(
- self._package_path, base_manifest[self._METADATA_MANIFEST_FIELD]
- ),
- "rw",
- )
- for line in manifest_file:
- fields_in_line = line.split(":", maxsplit=1)
- fields_in_line[0] = fields_in_line[0].strip()
- fields_in_line[1] = fields_in_line[1].strip()
- if fields_in_line[0] in self._MANIFEST_ALL_FIELDS:
- base_manifest[fields_in_line[0]] = fields_in_line[1]
- return base_manifest
- except (Exception, OSError) as e:
- raise SOLPackageException(
- "Error parsing {}: {}".format(
- base_manifest[SOLPackage._METADATA_MANIFEST_FIELD], e
- )
- )
-
- def _get_package_file_full_path(self, file_relative_path):
- return os.path.join(self._package_path, file_relative_path)
-
- def _parse_file_in_blocks(self, file_relative_path):
- file_path = self._get_package_file_full_path(file_relative_path)
- with open(file_path) as f:
- blocks = f.read().split("\n\n")
- parsed_blocks = map(yaml.safe_load, blocks)
- return [block for block in parsed_blocks if block is not None]
-
- def _get_package_file_manifest_data(self, file_relative_path):
- for file_data in self._manifest_data:
- if (
- file_data.get(SOLPackage._MANIFEST_FILE_PATH_FIELD, "")
- == file_relative_path
- ):
- return file_data
-
- error_msg = (
- "Error parsing {} manifest data: file not found on manifest file".format(
- file_relative_path
- )
- )
- raise SOLPackageException(error_msg)
-
- def get_package_file_hash_digest_from_manifest(self, file_relative_path):
- """Returns the hash digest of a file inside this package as specified on the manifest file."""
- file_manifest_data = self._get_package_file_manifest_data(file_relative_path)
- try:
- return file_manifest_data[SOLPackage._MANIFEST_FILE_HASH_DIGEST_FIELD]
- except Exception as e:
- raise SOLPackageException(
- "Error parsing {} hash digest: {}".format(file_relative_path, e)
- )
-
- def get_package_file_hash_algorithm_from_manifest(self, file_relative_path):
- """Returns the hash algorithm of a file inside this package as specified on the manifest file."""
- file_manifest_data = self._get_package_file_manifest_data(file_relative_path)
- try:
- return file_manifest_data[SOLPackage._MANIFEST_FILE_HASH_ALGORITHM_FIELD]
- except Exception as e:
- raise SOLPackageException(
- "Error parsing {} hash digest: {}".format(file_relative_path, e)
- )
-
- @staticmethod
- def _get_hash_function_from_hash_algorithm(hash_algorithm):
- function_to_algorithm = {"SHA-256": hashlib.sha256, "SHA-512": hashlib.sha512}
- if hash_algorithm not in function_to_algorithm:
- error_msg = (
- "Error checking hash function: hash algorithm {} not supported".format(
- hash_algorithm
- )
- )
- raise SOLPackageException(error_msg)
- return function_to_algorithm[hash_algorithm]
-
- def _calculate_file_hash(self, file_relative_path, hash_algorithm):
- file_path = self._get_package_file_full_path(file_relative_path)
- hash_function = self._get_hash_function_from_hash_algorithm(hash_algorithm)
- try:
- with open(file_path, "rb") as f:
- return hash_function(f.read()).hexdigest()
- except Exception as e:
- raise SOLPackageException(
- "Error hashing {}: {}".format(file_relative_path, e)
- )
-
- def validate_package_file_hash(self, file_relative_path):
- """Validates the integrity of a file using the hash algorithm and digest on the package manifest."""
- hash_algorithm = self.get_package_file_hash_algorithm_from_manifest(
- file_relative_path
- )
- file_hash = self._calculate_file_hash(file_relative_path, hash_algorithm)
- expected_file_hash = self.get_package_file_hash_digest_from_manifest(
- file_relative_path
- )
- if file_hash != expected_file_hash:
- error_msg = "Error validating {} hash: calculated hash {} is different than manifest hash {}"
- raise SOLPackageException(
- error_msg.format(file_relative_path, file_hash, expected_file_hash)
- )
-
- def validate_package_hashes(self):
- """Validates the integrity of all files listed on the package manifest."""
- for file_data in self._manifest_data:
- if SOLPackage._MANIFEST_FILE_PATH_FIELD in file_data:
- file_relative_path = file_data[SOLPackage._MANIFEST_FILE_PATH_FIELD]
- self.validate_package_file_hash(file_relative_path)
-
- def create_or_update_metadata_file(self):
- """
- Creates or updates the metadata file with the hashes calculated for each one of the package's files
- """
- if not self._manifest_metadata:
- self.generate_manifest_data_from_descriptor()
-
- self.write_manifest_data_into_file()
-
- def generate_manifest_data_from_descriptor(self):
- pass
-
- def write_manifest_data_into_file(self):
- with open(self.get_manifest_location(), "w") as metadata_file:
- # Write manifest metadata
- for metadata_entry in self._manifest_metadata:
- metadata_file.write(
- "{}: {}\n".format(
- metadata_entry, self._manifest_metadata[metadata_entry]
- )
- )
-
- # Write package's files hashes
- file_hashes = {}
- for root, dirs, files in os.walk(self._package_path):
- for a_file in files:
- file_path = os.path.join(root, a_file)
- file_relative_path = file_path[len(self._package_path) :]
- if file_relative_path.startswith("/"):
- file_relative_path = file_relative_path[1:]
- file_hashes[file_relative_path] = self._calculate_file_hash(
- file_relative_path, "SHA-512"
- )
-
- for file, hash in file_hashes.items():
- file_block = "Source: {}\nAlgorithm: SHA-512\nHash: {}\n\n".format(
- file, hash
- )
- metadata_file.write(file_block)
-
- def get_descriptor_location(self):
- """Returns this package descriptor location as a relative path from the package root."""
- for tosca_meta in self._package_metadata:
- if SOLPackage._METADATA_DESCRIPTOR_FIELD in tosca_meta:
- return tosca_meta[SOLPackage._METADATA_DESCRIPTOR_FIELD]
-
- error_msg = "Error: no {} entry found on {}".format(
- SOLPackage._METADATA_DESCRIPTOR_FIELD, SOLPackage._METADATA_FILE_PATH
- )
- raise SOLPackageException(error_msg)
-
- def get_manifest_location(self):
- """Return the VNF/NS manifest location as a relative path from the package root."""
- for tosca_meta in self._package_metadata:
- if SOLPackage._METADATA_MANIFEST_FIELD in tosca_meta:
- return tosca_meta[SOLPackage._METADATA_MANIFEST_FIELD]
-
- raise SOLPackageException("No manifest file defined for this package")
+++ /dev/null
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-
-vnfd:
- description: A VNF consisting of 1 VDU connected to two external VL, and one for
- data and another one for management
- df:
- - id: default-df
- instantiation-level:
- - id: default-instantiation-level
- vdu-level:
- - number-of-instances: 1
- vdu-id: mgmtVM
- vdu-profile:
- - id: mgmtVM
- min-number-of-instances: 1
- vdu-configuration-id: mgmtVM-vdu-configuration
- ext-cpd:
- - id: vnf-mgmt-ext
- int-cpd:
- cpd: mgmtVM-eth0-int
- vdu-id: mgmtVM
- - id: vnf-data-ext
- int-cpd:
- cpd: dataVM-xe0-int
- vdu-id: mgmtVM
- id: native_charm-vnf
- mgmt-cp: vnf-mgmt-ext
- product-name: native_charm-vnf
- sw-image-desc:
- - id: ubuntu18.04
- image: ubuntu18.04
- name: ubuntu18.04
- vdu:
- - cloud-init-file: cloud-config.txt
- id: mgmtVM
- int-cpd:
- - id: mgmtVM-eth0-int
- virtual-network-interface-requirement:
- - name: mgmtVM-eth0
- position: 1
- virtual-interface:
- type: PARAVIRT
- - id: dataVM-xe0-int
- virtual-network-interface-requirement:
- - name: dataVM-xe0
- position: 2
- virtual-interface:
- type: PARAVIRT
- name: mgmtVM
- sw-image-desc: ubuntu18.04
- virtual-compute-desc: mgmtVM-compute
- virtual-storage-desc:
- - mgmtVM-storage
- vdu-configuration:
- - config-access:
- ssh-access:
- default-user: ubuntu
- required: true
- config-primitive:
- - name: touch
- parameter:
- - data-type: STRING
- default-value: /home/ubuntu/touched
- name: filename
- id: mgmtVM-vdu-configuration
- initial-config-primitive:
- - name: touch
- parameter:
- - data-type: STRING
- name: filename
- value: /home/ubuntu/first-touch
- seq: 1
- juju:
- charm: simple
- proxy: false
- version: 1.0
- virtual-compute-desc:
- - id: mgmtVM-compute
- virtual-cpu:
- num-virtual-cpu: 1
- virtual-memory:
- size: 1.0
- virtual-storage-desc:
- - id: mgmtVM-storage
- size-of-storage: 10
+++ /dev/null
-#!/usr/bin/env python3
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-import subprocess
-import sys
-
-from ops.charm import CharmBase
-from ops.main import main
-from ops.model import ActiveStatus
-
-sys.path.append("lib")
-
-
-class MyNativeCharm(CharmBase):
- def __init__(self, framework, key):
- super().__init__(framework, key)
-
- # Listen to charm events
- self.framework.observe(self.on.config_changed, self.on_config_changed)
- self.framework.observe(self.on.install, self.on_install)
- self.framework.observe(self.on.start, self.on_start)
-
- # Listen to the touch action event
- self.framework.observe(self.on.touch_action, self.on_touch_action)
-
- def on_config_changed(self, event):
- """Handle changes in configuration"""
- self.model.unit.status = ActiveStatus()
-
- def on_install(self, event):
- """Called when the charm is being installed"""
- self.model.unit.status = ActiveStatus()
-
- def on_start(self, event):
- """Called when the charm is being started"""
- self.model.unit.status = ActiveStatus()
-
- def on_touch_action(self, event):
- """Touch a file."""
-
- filename = event.params["filename"]
- try:
- subprocess.run(["touch", filename], check=True)
- event.set_results({"created": True, "filename": filename})
- except subprocess.CalledProcessError as e:
- event.fail("Action failed: {}".format(e))
- self.model.unit.status = ActiveStatus()
-
-
-if __name__ == "__main__":
- main(MyNativeCharm)
+++ /dev/null
-# \r
-# Copyright 2020 Whitestack, LLC\r
-# *************************************************************\r
-#\r
-# This file is part of OSM common repository.\r
-# All Rights Reserved to Whitestack, LLC\r
-#\r
-# Licensed under the Apache License, Version 2.0 (the "License"); you may\r
-# not use this file except in compliance with the License. You may obtain\r
-# a copy of the License at\r
-#\r
-# http://www.apache.org/licenses/LICENSE-2.0\r
-#\r
-# Unless required by applicable law or agreed to in writing, software\r
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT\r
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\r
-# License for the specific language governing permissions and limitations\r
-# under the License.\r
-#\r
-\r
-#cloud-config\r
-chpasswd: { expire: False }\r
-ssh_pwauth: True\r
-\r
-write_files:\r
-- content: |\r
- # My new helloworld file\r
-\r
- owner: root:root\r
- permissions: '0644'\r
- path: /root/helloworld.txt\r
+++ /dev/null
-#
-# Copyright 2020 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: agarcia@whitestack.com
-##
-
-TOSCA-Meta-Version: 1.0
-CSAR-Version: 1.0
-Created-By: Diego Armando Maradona
-Entry-Definitions: Definitions/native_charm_vnfd.yaml # Points to the main descriptor of the package
-ETSI-Entry-Manifest: manifest.mf # Points to the ETSI manifest file
-ETSI-Entry-Change-Log: Files/Changelog.txt # Points to package changelog
-ETSI-Entry-Licenses: Files/Licenses # Points to package licenses folder
-
-# In principle, we could add one block per package file to specify MIME types
-Name: Definitions/native_charm_vnfd.yaml # path to file within package
-Content-Type: application/yaml # MIME type of file
-
-Name: Scripts/cloud_init/cloud-config.txt
-Content-Type: application/yaml
-
+++ /dev/null
-#
-# Copyright 2020 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: agarcia@whitestack.com
-##
-
-# General definitions of the package
-vnfd_id: native_charm-vnf
-vnf_product_name: native_charm-vnf
-vnf_provider_id: AFA
-vnf_software_version: 1.0
-vnf_package_version: 1.0
-vnf_release_date_time: 2021.12.01T11:36-03:00
-compatible_specification_versions: 3.3.1
-vnfm_info: OSM
-
-Source: Definitions/native_charm_vnfd.yaml
-Algorithm: SHA-256
-Hash: ede8daf9748ac4849e1a1aac955d6c84cafef9ea34067eaef76ee4e5996974c2
-
-Source: Scripts/cloud_init/cloud-config.txt
-Algorithm: SHA-256
-Hash: 7455ca868843cc5da1f0a2255cdedb64a69df3b618c344b83b82848a94540eda
-
-
-# The below sections are all wrong on purpose as they are intended for testing
-
-# Invalid hash algorithm
-Source: Scripts/charms/simple/src/charm.py
-Algorithm: SHA-733
-Hash: ea72f897a966e6174ed9164fabc3c500df5a2f712eb6b22ab2408afb07d04d14
-
-# Wrong hash
-Source: Scripts/charms/simple/hooks/start
-Algorithm: SHA-256
-Hash: 123456aaaaaa123456aaaaaae2bb9d0197f41619165dde6cf205c974f9aa86ae
-
-# Unspecified hash
-Source: Scripts/charms/simple/hooks/upgrade-charm
-Algorithm: SHA-256
+++ /dev/null
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-
-vnfd:
- description: A VNF consisting of 1 VDU connected to two external VL, and one for
- data and another one for management
- df:
- - id: default-df
- instantiation-level:
- - id: default-instantiation-level
- vdu-level:
- - number-of-instances: 1
- vdu-id: mgmtVM
- vdu-profile:
- - id: mgmtVM
- min-number-of-instances: 1
- vdu-configuration-id: mgmtVM-vdu-configuration
- ext-cpd:
- - id: vnf-mgmt-ext
- int-cpd:
- cpd: mgmtVM-eth0-int
- vdu-id: mgmtVM
- - id: vnf-data-ext
- int-cpd:
- cpd: dataVM-xe0-int
- vdu-id: mgmtVM
- id: native_charm-vnf
- mgmt-cp: vnf-mgmt-ext
- product-name: native_charm-vnf
- sw-image-desc:
- - id: ubuntu18.04
- image: ubuntu18.04
- name: ubuntu18.04
- vdu:
- - cloud-init-file: cloud-config.txt
- id: mgmtVM
- int-cpd:
- - id: mgmtVM-eth0-int
- virtual-network-interface-requirement:
- - name: mgmtVM-eth0
- position: 1
- virtual-interface:
- type: PARAVIRT
- - id: dataVM-xe0-int
- virtual-network-interface-requirement:
- - name: dataVM-xe0
- position: 2
- virtual-interface:
- type: PARAVIRT
- name: mgmtVM
- sw-image-desc: ubuntu18.04
- virtual-compute-desc: mgmtVM-compute
- virtual-storage-desc:
- - mgmtVM-storage
- vdu-configuration:
- - config-access:
- ssh-access:
- default-user: ubuntu
- required: true
- config-primitive:
- - name: touch
- parameter:
- - data-type: STRING
- default-value: /home/ubuntu/touched
- name: filename
- id: mgmtVM-vdu-configuration
- initial-config-primitive:
- - name: touch
- parameter:
- - data-type: STRING
- name: filename
- value: /home/ubuntu/first-touch
- seq: 1
- juju:
- charm: simple
- proxy: false
- version: 1.0
- virtual-compute-desc:
- - id: mgmtVM-compute
- virtual-cpu:
- num-virtual-cpu: 1
- virtual-memory:
- size: 1.0
- virtual-storage-desc:
- - id: mgmtVM-storage
- size-of-storage: 10
+++ /dev/null
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-
-1.0.0: First version
+++ /dev/null
-#
-# Copyright 2020 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
+++ /dev/null
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-touch:
- description: "Touch a file on the VNF."
- params:
- filename:
- description: "The name of the file to touch."
- type: string
- default: ""
- required:
- - filename
+++ /dev/null
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-options: {}
\ No newline at end of file
+++ /dev/null
-#!/usr/bin/env python3
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-import sys
-import subprocess
-
-sys.path.append("lib")
-
-from ops.charm import CharmBase
-from ops.main import main
-from ops.model import ActiveStatus
-
-class MyNativeCharm(CharmBase):
-
- def __init__(self, framework, key):
- super().__init__(framework, key)
-
- # Listen to charm events
- self.framework.observe(self.on.config_changed, self.on_config_changed)
- self.framework.observe(self.on.install, self.on_install)
- self.framework.observe(self.on.start, self.on_start)
-
- # Listen to the touch action event
- self.framework.observe(self.on.touch_action, self.on_touch_action)
-
- def on_config_changed(self, event):
- """Handle changes in configuration"""
- self.model.unit.status = ActiveStatus()
-
- def on_install(self, event):
- """Called when the charm is being installed"""
- self.model.unit.status = ActiveStatus()
-
- def on_start(self, event):
- """Called when the charm is being started"""
- self.model.unit.status = ActiveStatus()
-
- def on_touch_action(self, event):
- """Touch a file."""
-
- filename = event.params["filename"]
- try:
- subprocess.run(["touch", filename], check=True)
- event.set_results({"created": True, "filename": filename})
- except subprocess.CalledProcessError as e:
- event.fail("Action failed: {}".format(e))
- self.model.unit.status = ActiveStatus()
-
-
-if __name__ == "__main__":
- main(MyNativeCharm)
-
+++ /dev/null
-#!/usr/bin/env python3
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-import sys
-import subprocess
-
-sys.path.append("lib")
-
-from ops.charm import CharmBase
-from ops.main import main
-from ops.model import ActiveStatus
-
-class MyNativeCharm(CharmBase):
-
- def __init__(self, framework, key):
- super().__init__(framework, key)
-
- # Listen to charm events
- self.framework.observe(self.on.config_changed, self.on_config_changed)
- self.framework.observe(self.on.install, self.on_install)
- self.framework.observe(self.on.start, self.on_start)
-
- # Listen to the touch action event
- self.framework.observe(self.on.touch_action, self.on_touch_action)
-
- def on_config_changed(self, event):
- """Handle changes in configuration"""
- self.model.unit.status = ActiveStatus()
-
- def on_install(self, event):
- """Called when the charm is being installed"""
- self.model.unit.status = ActiveStatus()
-
- def on_start(self, event):
- """Called when the charm is being started"""
- self.model.unit.status = ActiveStatus()
-
- def on_touch_action(self, event):
- """Touch a file."""
-
- filename = event.params["filename"]
- try:
- subprocess.run(["touch", filename], check=True)
- event.set_results({"created": True, "filename": filename})
- except subprocess.CalledProcessError as e:
- event.fail("Action failed: {}".format(e))
- self.model.unit.status = ActiveStatus()
-
-
-if __name__ == "__main__":
- main(MyNativeCharm)
-
+++ /dev/null
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-name: simple-native
-summary: A simple native charm
-description: |
- Simple native charm
-series:
- - bionic
- - xenial
- - focal
\ No newline at end of file
+++ /dev/null
-#!/usr/bin/env python3
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-import subprocess
-import sys
-
-from ops.charm import CharmBase
-from ops.main import main
-from ops.model import ActiveStatus
-
-sys.path.append("lib")
-
-
-class MyNativeCharm(CharmBase):
- def __init__(self, framework, key):
- super().__init__(framework, key)
-
- # Listen to charm events
- self.framework.observe(self.on.config_changed, self.on_config_changed)
- self.framework.observe(self.on.install, self.on_install)
- self.framework.observe(self.on.start, self.on_start)
-
- # Listen to the touch action event
- self.framework.observe(self.on.touch_action, self.on_touch_action)
-
- def on_config_changed(self, event):
- """Handle changes in configuration"""
- self.model.unit.status = ActiveStatus()
-
- def on_install(self, event):
- """Called when the charm is being installed"""
- self.model.unit.status = ActiveStatus()
-
- def on_start(self, event):
- """Called when the charm is being started"""
- self.model.unit.status = ActiveStatus()
-
- def on_touch_action(self, event):
- """Touch a file."""
-
- filename = event.params["filename"]
- try:
- subprocess.run(["touch", filename], check=True)
- event.set_results({"created": True, "filename": filename})
- except subprocess.CalledProcessError as e:
- event.fail("Action failed: {}".format(e))
- self.model.unit.status = ActiveStatus()
-
-
-if __name__ == "__main__":
- main(MyNativeCharm)
+++ /dev/null
-# \r
-# Copyright 2020 Whitestack, LLC\r
-# *************************************************************\r
-#\r
-# This file is part of OSM common repository.\r
-# All Rights Reserved to Whitestack, LLC\r
-#\r
-# Licensed under the Apache License, Version 2.0 (the "License"); you may\r
-# not use this file except in compliance with the License. You may obtain\r
-# a copy of the License at\r
-#\r
-# http://www.apache.org/licenses/LICENSE-2.0\r
-#\r
-# Unless required by applicable law or agreed to in writing, software\r
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT\r
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\r
-# License for the specific language governing permissions and limitations\r
-# under the License.\r
-#\r
-\r
-\r
-#cloud-config\r
-chpasswd: { expire: False }\r
-ssh_pwauth: True\r
-\r
-write_files:\r
-- content: |\r
- # My new helloworld file\r
-\r
- owner: root:root\r
- permissions: '0644'\r
- path: /root/helloworld.txt\r
+++ /dev/null
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: agarcia@whitestack.com
-##
-
-TOSCA-Meta-Version: 1.0
-CSAR-Version: 1.0
-Created-By: Diego Armando Maradona
-Entry-Definitions: Definitions/native_charm_vnfd.yaml # Points to the main descriptor of the package
-ETSI-Entry-Manifest: manifest.mf # Points to the ETSI manifest file
-ETSI-Entry-Change-Log: Files/Changelog.txt # Points to package changelog
-ETSI-Entry-Licenses: Files/Licenses # Points to package licenses folder
-
-# In principle, we could add one block per package file to specify MIME types
-Name: Definitions/native_charm_vnfd.yaml # path to file within package
-Content-Type: application/yaml # MIME type of file
-
-Name: Scripts/cloud_init/cloud-config.txt
-Content-Type: application/yaml
\ No newline at end of file
+++ /dev/null
-#
-# Copyright 2020 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: agarcia@whitestack.com
-##
-
-# General definitions of the package
-vnfd_id: native_charm-vnf
-vnf_product_name: native_charm-vnf
-vnf_provider_id: AFA
-vnf_software_version: 1.0
-vnf_package_version: 1.0
-vnf_release_date_time: 2021.12.01T11:36-03:00
-compatible_specification_versions: 3.3.1
-vnfm_info: OSM
-
-# One block for every file in the package
-Source: Definitions/native_charm_vnfd.yaml
-Algorithm: SHA-256
-Hash: ede8daf9748ac4849e1a1aac955d6c84cafef9ea34067eaef76ee4e5996974c2
-
-
-
-Source: Scripts/cloud_init/cloud-config.txt
-Algorithm: SHA-256
-Hash: 0eef3f1a642339e2053af48a7e370dac1952f9cb81166e439e8f72afd6f03621
-
-# Charms files
-
-Source: Scripts/charms/simple/src/charm.py
-Algorithm: SHA-256
-Hash: ea72f897a966e6174ed9164fabc3c500df5a2f712eb6b22ab2408afb07d04d14
-
-Source: Scripts/charms/simple/hooks/start
-Algorithm: SHA-256
-Hash: 312490afd82cc86ad823e4d9e2bb9d0197f41619165dde6cf205c974f9aa86ae
-
-Source: Scripts/charms/simple/hooks/install
-Algorithm: SHA-256
-Hash: 312490afd82cc86ad823e4d9e2bb9d0197f41619165dde6cf205c974f9aa86ae
-
-Source: Scripts/charms/simple/actions.yaml
-Algorithm: SHA-256
-Hash: 988ca2653ae6a3977149faaebd664a12858e0025f226b27d2cee1fa954c9462d
-
-Source: Scripts/charms/simple/metadata.yaml
-Algorithm: SHA-256
-Hash: e00cfaf41a518aef0f486e4ae04a5ae19feffa774abfbdb68379bb5b5b102479
-
-Source: Scripts/charms/simple/config.yaml
-Algorithm: SHA-256
-Hash: f5cbf31b9c299504f3b577417b6c82bde5e3eafd74ee11fdeecf8c8bff6cf3e2
-
-
-# And on and on
+++ /dev/null
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-
-1.0.0: First version
+++ /dev/null
-#
-# Copyright 2020 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
+++ /dev/null
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-touch:
- description: "Touch a file on the VNF."
- params:
- filename:
- description: "The name of the file to touch."
- type: string
- default: ""
- required:
- - filename
+++ /dev/null
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-options: {}
\ No newline at end of file
+++ /dev/null
-#!/usr/bin/env python3
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-import sys
-import subprocess
-
-sys.path.append("lib")
-
-from ops.charm import CharmBase
-from ops.main import main
-from ops.model import ActiveStatus
-
-class MyNativeCharm(CharmBase):
-
- def __init__(self, framework, key):
- super().__init__(framework, key)
-
- # Listen to charm events
- self.framework.observe(self.on.config_changed, self.on_config_changed)
- self.framework.observe(self.on.install, self.on_install)
- self.framework.observe(self.on.start, self.on_start)
-
- # Listen to the touch action event
- self.framework.observe(self.on.touch_action, self.on_touch_action)
-
- def on_config_changed(self, event):
- """Handle changes in configuration"""
- self.model.unit.status = ActiveStatus()
-
- def on_install(self, event):
- """Called when the charm is being installed"""
- self.model.unit.status = ActiveStatus()
-
- def on_start(self, event):
- """Called when the charm is being started"""
- self.model.unit.status = ActiveStatus()
-
- def on_touch_action(self, event):
- """Touch a file."""
-
- filename = event.params["filename"]
- try:
- subprocess.run(["touch", filename], check=True)
- event.set_results({"created": True, "filename": filename})
- except subprocess.CalledProcessError as e:
- event.fail("Action failed: {}".format(e))
- self.model.unit.status = ActiveStatus()
-
-
-if __name__ == "__main__":
- main(MyNativeCharm)
-
+++ /dev/null
-#!/usr/bin/env python3
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-import sys
-import subprocess
-
-sys.path.append("lib")
-
-from ops.charm import CharmBase
-from ops.main import main
-from ops.model import ActiveStatus
-
-class MyNativeCharm(CharmBase):
-
- def __init__(self, framework, key):
- super().__init__(framework, key)
-
- # Listen to charm events
- self.framework.observe(self.on.config_changed, self.on_config_changed)
- self.framework.observe(self.on.install, self.on_install)
- self.framework.observe(self.on.start, self.on_start)
-
- # Listen to the touch action event
- self.framework.observe(self.on.touch_action, self.on_touch_action)
-
- def on_config_changed(self, event):
- """Handle changes in configuration"""
- self.model.unit.status = ActiveStatus()
-
- def on_install(self, event):
- """Called when the charm is being installed"""
- self.model.unit.status = ActiveStatus()
-
- def on_start(self, event):
- """Called when the charm is being started"""
- self.model.unit.status = ActiveStatus()
-
- def on_touch_action(self, event):
- """Touch a file."""
-
- filename = event.params["filename"]
- try:
- subprocess.run(["touch", filename], check=True)
- event.set_results({"created": True, "filename": filename})
- except subprocess.CalledProcessError as e:
- event.fail("Action failed: {}".format(e))
- self.model.unit.status = ActiveStatus()
-
-
-if __name__ == "__main__":
- main(MyNativeCharm)
-
+++ /dev/null
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-name: simple-native
-summary: A simple native charm
-description: |
- Simple native charm
-series:
- - bionic
- - xenial
- - focal
\ No newline at end of file
+++ /dev/null
-#!/usr/bin/env python3
-##
-# Copyright 2020 Canonical Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-##
-
-import subprocess
-import sys
-
-from ops.charm import CharmBase
-from ops.main import main
-from ops.model import ActiveStatus
-
-sys.path.append("lib")
-
-
-class MyNativeCharm(CharmBase):
- def __init__(self, framework, key):
- super().__init__(framework, key)
-
- # Listen to charm events
- self.framework.observe(self.on.config_changed, self.on_config_changed)
- self.framework.observe(self.on.install, self.on_install)
- self.framework.observe(self.on.start, self.on_start)
-
- # Listen to the touch action event
- self.framework.observe(self.on.touch_action, self.on_touch_action)
-
- def on_config_changed(self, event):
- """Handle changes in configuration"""
- self.model.unit.status = ActiveStatus()
-
- def on_install(self, event):
- """Called when the charm is being installed"""
- self.model.unit.status = ActiveStatus()
-
- def on_start(self, event):
- """Called when the charm is being started"""
- self.model.unit.status = ActiveStatus()
-
- def on_touch_action(self, event):
- """Touch a file."""
-
- filename = event.params["filename"]
- try:
- subprocess.run(["touch", filename], check=True)
- event.set_results({"created": True, "filename": filename})
- except subprocess.CalledProcessError as e:
- event.fail("Action failed: {}".format(e))
- self.model.unit.status = ActiveStatus()
-
-
-if __name__ == "__main__":
- main(MyNativeCharm)
+++ /dev/null
-# \r
-# Copyright 2020 Whitestack, LLC\r
-# *************************************************************\r
-#\r
-# This file is part of OSM common repository.\r
-# All Rights Reserved to Whitestack, LLC\r
-#\r
-# Licensed under the Apache License, Version 2.0 (the "License"); you may\r
-# not use this file except in compliance with the License. You may obtain\r
-# a copy of the License at\r
-#\r
-# http://www.apache.org/licenses/LICENSE-2.0\r
-#\r
-# Unless required by applicable law or agreed to in writing, software\r
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT\r
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the\r
-# License for the specific language governing permissions and limitations\r
-# under the License.\r
-#\r
-\r
-\r
-#cloud-config\r
-chpasswd: { expire: False }\r
-ssh_pwauth: True\r
-\r
-write_files:\r
-- content: |\r
- # My new helloworld file\r
-\r
- owner: root:root\r
- permissions: '0644'\r
- path: /root/helloworld.txt\r
+++ /dev/null
-#
-# Copyright 2020 Whitestack, LLC
-# *************************************************************
-#
-# This file is part of OSM common repository.
-# All Rights Reserved to Whitestack, LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: agarcia@whitestack.com
-##
-
-# General definitions of the package
-vnfd_id: native_charm-vnf
-vnf_product_name: native_charm-vnf
-vnf_provider_id: AFA
-vnf_software_version: 1.0
-vnf_package_version: 1.0
-vnf_release_date_time: 2021.12.01T11:36-03:00
-compatible_specification_versions: 3.3.1
-vnfm_info: OSM
-
-# One block for every file in the package
-Source: native_charm_vnfd.yaml
-Algorithm: SHA-256
-Hash: ae06780c082041676df4ca4130ef223548eee6389007ba259416f59044450a7c
-
-
-
-Source: Scripts/cloud_init/cloud-config.txt
-Algorithm: SHA-256
-Hash: 0eef3f1a642339e2053af48a7e370dac1952f9cb81166e439e8f72afd6f03621
-
-# Charms files
-
-Source: Scripts/charms/simple/src/charm.py
-Algorithm: SHA-256
-Hash: ea72f897a966e6174ed9164fabc3c500df5a2f712eb6b22ab2408afb07d04d14
-
-Source: Scripts/charms/simple/hooks/start
-Algorithm: SHA-256
-Hash: 312490afd82cc86ad823e4d9e2bb9d0197f41619165dde6cf205c974f9aa86ae
-
-Source: Scripts/charms/simple/hooks/install
-Algorithm: SHA-256
-Hash: 312490afd82cc86ad823e4d9e2bb9d0197f41619165dde6cf205c974f9aa86ae
-
-Source: Scripts/charms/simple/actions.yaml
-Algorithm: SHA-256
-Hash: 988ca2653ae6a3977149faaebd664a12858e0025f226b27d2cee1fa954c9462d
-
-Source: Scripts/charms/simple/metadata.yaml
-Algorithm: SHA-256
-Hash: e00cfaf41a518aef0f486e4ae04a5ae19feffa774abfbdb68379bb5b5b102479
-
-Source: Scripts/charms/simple/config.yaml
-Algorithm: SHA-256
-Hash: f5cbf31b9c299504f3b577417b6c82bde5e3eafd74ee11fdeecf8c8bff6cf3e2
-
-
-# And on and on
+++ /dev/null
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-
-metadata:
- template_name: native_charm-vnf
- template_author: AFA
- template_version: 1.1
-
-vnfd:
- description: A VNF consisting of 1 VDU connected to two external VL, and one for
- data and another one for management
- df:
- - id: default-df
- instantiation-level:
- - id: default-instantiation-level
- vdu-level:
- - number-of-instances: 1
- vdu-id: mgmtVM
- vdu-profile:
- - id: mgmtVM
- min-number-of-instances: 1
- vdu-configuration-id: mgmtVM-vdu-configuration
- ext-cpd:
- - id: vnf-mgmt-ext
- int-cpd:
- cpd: mgmtVM-eth0-int
- vdu-id: mgmtVM
- - id: vnf-data-ext
- int-cpd:
- cpd: dataVM-xe0-int
- vdu-id: mgmtVM
- id: native_charm-vnf
- mgmt-cp: vnf-mgmt-ext
- product-name: native_charm-vnf
- provider: AFA
- sw-image-desc:
- - id: ubuntu18.04
- image: ubuntu18.04
- name: ubuntu18.04
- vdu:
- - cloud-init-file: cloud-config.txt
- id: mgmtVM
- int-cpd:
- - id: mgmtVM-eth0-int
- virtual-network-interface-requirement:
- - name: mgmtVM-eth0
- position: 1
- virtual-interface:
- type: PARAVIRT
- - id: dataVM-xe0-int
- virtual-network-interface-requirement:
- - name: dataVM-xe0
- position: 2
- virtual-interface:
- type: PARAVIRT
- name: mgmtVM
- sw-image-desc: ubuntu18.04
- virtual-compute-desc: mgmtVM-compute
- virtual-storage-desc:
- - mgmtVM-storage
- vdu-configuration:
- - config-access:
- ssh-access:
- default-user: ubuntu
- required: true
- config-primitive:
- - name: touch
- parameter:
- - data-type: STRING
- default-value: /home/ubuntu/touched
- name: filename
- id: mgmtVM-vdu-configuration
- initial-config-primitive:
- - name: touch
- parameter:
- - data-type: STRING
- name: filename
- value: /home/ubuntu/first-touch
- seq: 1
- juju:
- charm: simple
- proxy: false
- version: 1.0
- virtual-compute-desc:
- - id: mgmtVM-compute
- virtual-cpu:
- num-virtual-cpu: 1
- virtual-memory:
- size: 1.0
- virtual-storage-desc:
- - id: mgmtVM-storage
- size-of-storage: 10
+++ /dev/null
-# Copyright 2018 Whitestack, LLC
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
-##
-import asyncio
-import copy
-from copy import deepcopy
-import http
-from http import HTTPStatus
-import logging
-from os import urandom
-import unittest
-from unittest.mock import MagicMock, Mock, patch
-
-from Crypto.Cipher import AES
-from osm_common.dbbase import DbBase, DbException, deep_update, Encryption
-import pytest
-
-
-# Variables used in TestBaseEncryption and TestAsyncEncryption
-salt = "1afd5d1a-4a7e-4d9c-8c65-251290183106"
-value = "private key txt"
-padded_value = b"private key txt\0"
-padded_encoded_value = b"private key txt\x00"
-encoding_type = "ascii"
-encyrpt_mode = AES.MODE_ECB
-secret_key = b"\xeev\xc2\xb8\xb2#;Ek\xd0\xb5['\x04\xed\x1f\xb9?\xc5Ig\x80\xd5\x8d\x8aT\xd7\xf8Q\xe2u!"
-encyrpted_value = "ZW5jcnlwdGVkIGRhdGE="
-encyrpted_bytes = b"ZW5jcnlwdGVkIGRhdGE="
-data_to_b4_encode = b"encrypted data"
-b64_decoded = b"decrypted data"
-schema_version = "1.1"
-joined_key = b"\x9d\x17\xaf\xc8\xdeF\x1b.\x0e\xa9\xb5['\x04\xed\x1f\xb9?\xc5Ig\x80\xd5\x8d\x8aT\xd7\xf8Q\xe2u!"
-serial_bytes = b"\xf8\x96Z\x1c:}\xb5\xdf\x94\x8d\x0f\x807\xe6)\x8f\xf5!\xee}\xc2\xfa\xb3\t\xb9\xe4\r7\x19\x08\xa5b"
-base64_decoded_serial = b"g\xbe\xdb"
-decrypted_val1 = "BiV9YZEuSRAudqvz7Gs+bg=="
-decrypted_val2 = "q4LwnFdoryzbZJM5mCAnpA=="
-item = {
- "secret": "mysecret",
- "cacert": "mycacert",
- "path": "/var",
- "ip": "192.168.12.23",
-}
-
-
-def exception_message(message):
- return "database exception " + message
-
-
-@pytest.fixture
-def db_base():
- return DbBase()
-
-
-def test_constructor():
- db_base = DbBase()
- assert db_base is not None
- assert isinstance(db_base, DbBase)
-
-
-def test_db_connect(db_base):
- with pytest.raises(DbException) as excinfo:
- db_base.db_connect(None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'db_connect' not implemented")
- )
-
-
-def test_db_disconnect(db_base):
- db_base.db_disconnect()
-
-
-def test_get_list(db_base):
- with pytest.raises(DbException) as excinfo:
- db_base.get_list(None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'get_list' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-def test_get_one(db_base):
- with pytest.raises(DbException) as excinfo:
- db_base.get_one(None, None, None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'get_one' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-def test_create(db_base):
- with pytest.raises(DbException) as excinfo:
- db_base.create(None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'create' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-def test_create_list(db_base):
- with pytest.raises(DbException) as excinfo:
- db_base.create_list(None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'create_list' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-def test_del_list(db_base):
- with pytest.raises(DbException) as excinfo:
- db_base.del_list(None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'del_list' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-def test_del_one(db_base):
- with pytest.raises(DbException) as excinfo:
- db_base.del_one(None, None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'del_one' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-class TestEncryption(unittest.TestCase):
- def setUp(self):
- master_key = "Setting a long master key with numbers 123 and capitals AGHBNHD and symbols %&8)!'"
- db_base1 = DbBase()
- db_base2 = DbBase()
- db_base3 = DbBase()
- # set self.secret_key obtained when connect
- db_base1.set_secret_key(master_key, replace=True)
- db_base1.set_secret_key(urandom(32))
- db_base2.set_secret_key(None, replace=True)
- db_base2.set_secret_key(urandom(30))
- db_base3.set_secret_key(master_key)
- self.db_bases = [db_base1, db_base2, db_base3]
-
- def test_encrypt_decrypt(self):
- TEST = (
- ("plain text 1 ! ", None),
- ("plain text 2 with salt ! ", "1afd5d1a-4a7e-4d9c-8c65-251290183106"),
- )
- for db_base in self.db_bases:
- for value, salt in TEST:
- # no encryption
- encrypted = db_base.encrypt(value, schema_version="1.0", salt=salt)
- self.assertEqual(
- encrypted, value, "value '{}' has been encrypted".format(value)
- )
- decrypted = db_base.decrypt(encrypted, schema_version="1.0", salt=salt)
- self.assertEqual(
- decrypted, value, "value '{}' has been decrypted".format(value)
- )
-
- # encrypt/decrypt
- encrypted = db_base.encrypt(value, schema_version="1.1", salt=salt)
- self.assertNotEqual(
- encrypted, value, "value '{}' has not been encrypted".format(value)
- )
- self.assertIsInstance(encrypted, str, "Encrypted is not ascii text")
- decrypted = db_base.decrypt(encrypted, schema_version="1.1", salt=salt)
- self.assertEqual(
- decrypted, value, "value is not equal after encryption/decryption"
- )
-
- def test_encrypt_decrypt_salt(self):
- value = "value to be encrypted!"
- encrypted = []
- for db_base in self.db_bases:
- for salt in (None, "salt 1", "1afd5d1a-4a7e-4d9c-8c65-251290183106"):
- # encrypt/decrypt
- encrypted.append(
- db_base.encrypt(value, schema_version="1.1", salt=salt)
- )
- self.assertNotEqual(
- encrypted[-1],
- value,
- "value '{}' has not been encrypted".format(value),
- )
- self.assertIsInstance(encrypted[-1], str, "Encrypted is not ascii text")
- decrypted = db_base.decrypt(
- encrypted[-1], schema_version="1.1", salt=salt
- )
- self.assertEqual(
- decrypted, value, "value is not equal after encryption/decryption"
- )
- for i in range(0, len(encrypted)):
- for j in range(i + 1, len(encrypted)):
- self.assertNotEqual(
- encrypted[i],
- encrypted[j],
- "encryption with different salt must contain different result",
- )
- # decrypt with a different master key
- try:
- decrypted = self.db_bases[-1].decrypt(
- encrypted[0], schema_version="1.1", salt=None
- )
- self.assertNotEqual(
- encrypted[0],
- decrypted,
- "Decryption with different KEY must generate different result",
- )
- except DbException as e:
- self.assertEqual(
- e.http_code,
- HTTPStatus.INTERNAL_SERVER_ERROR,
- "Decryption with different KEY does not provide expected http_code",
- )
-
-
-class AsyncMock(MagicMock):
- async def __call__(self, *args, **kwargs):
- args = deepcopy(args)
- kwargs = deepcopy(kwargs)
- return super(AsyncMock, self).__call__(*args, **kwargs)
-
-
-class CopyingMock(MagicMock):
- def __call__(self, *args, **kwargs):
- args = deepcopy(args)
- kwargs = deepcopy(kwargs)
- return super(CopyingMock, self).__call__(*args, **kwargs)
-
-
-def check_if_assert_not_called(mocks: list):
- for mocking in mocks:
- mocking.assert_not_called()
-
-
-class TestBaseEncryption(unittest.TestCase):
- @patch("logging.getLogger", autospec=True)
- def setUp(self, mock_logger):
- mock_logger = logging.getLogger()
- mock_logger.disabled = True
- self.db_base = DbBase()
- self.mock_cipher = CopyingMock()
- self.db_base.encoding_type = encoding_type
- self.db_base.encrypt_mode = encyrpt_mode
- self.db_base.secret_key = secret_key
- self.mock_padded_msg = CopyingMock()
-
- def test_pad_data_len_not_multiplication_of_16(self):
- data = "hello word hello hello word hello word"
- data_len = len(data)
- expected_len = 48
- padded = self.db_base.pad_data(data)
- self.assertEqual(len(padded), expected_len)
- self.assertTrue("\0" * (expected_len - data_len) in padded)
-
- def test_pad_data_len_multiplication_of_16(self):
- data = "hello word!!!!!!"
- padded = self.db_base.pad_data(data)
- self.assertEqual(padded, data)
- self.assertFalse("\0" in padded)
-
- def test_pad_data_empty_string(self):
- data = ""
- expected_len = 0
- padded = self.db_base.pad_data(data)
- self.assertEqual(len(padded), expected_len)
- self.assertFalse("\0" in padded)
-
- def test_pad_data_not_string(self):
- data = None
- with self.assertRaises(Exception) as err:
- self.db_base.pad_data(data)
- self.assertEqual(
- str(err.exception),
- "database exception Incorrect data type: type(None), string is expected.",
- )
-
- def test_unpad_data_null_char_at_right(self):
- null_padded_data = "hell0word\0\0"
- expected_length = len(null_padded_data) - 2
- unpadded = self.db_base.unpad_data(null_padded_data)
- self.assertEqual(len(unpadded), expected_length)
- self.assertFalse("\0" in unpadded)
- self.assertTrue("0" in unpadded)
-
- def test_unpad_data_null_char_is_not_rightest(self):
- null_padded_data = "hell0word\r\t\0\n"
- expected_length = len(null_padded_data)
- unpadded = self.db_base.unpad_data(null_padded_data)
- self.assertEqual(len(unpadded), expected_length)
- self.assertTrue("\0" in unpadded)
-
- def test_unpad_data_with_spaces_at_right(self):
- null_padded_data = " hell0word\0 "
- expected_length = len(null_padded_data)
- unpadded = self.db_base.unpad_data(null_padded_data)
- self.assertEqual(len(unpadded), expected_length)
- self.assertTrue("\0" in unpadded)
-
- def test_unpad_data_empty_string(self):
- data = ""
- unpadded = self.db_base.unpad_data(data)
- self.assertEqual(unpadded, "")
- self.assertFalse("\0" in unpadded)
-
- def test_unpad_data_not_string(self):
- data = None
- with self.assertRaises(Exception) as err:
- self.db_base.unpad_data(data)
- self.assertEqual(
- str(err.exception),
- "database exception Incorrect data type: type(None), string is expected.",
- )
-
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_schema_version_1_0_none_secret_key_none_salt(
- self, mock_pad_data, mock_join_secret_key
- ):
- """schema_version 1.0, secret_key is None and salt is None."""
- schema_version = "1.0"
- salt = None
- self.db_base.secret_key = None
- result = self.db_base._encrypt_value(value, schema_version, salt)
- self.assertEqual(result, value)
- check_if_assert_not_called([mock_pad_data, mock_join_secret_key])
-
- @patch("osm_common.dbbase.b64encode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_schema_version_1_1_with_secret_key_exists_with_salt(
- self,
- mock_pad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_encode,
- ):
- """schema_version 1.1, secret_key exists, salt exists."""
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.encrypt.return_value = data_to_b4_encode
- self.mock_padded_msg.return_value = padded_value
- mock_pad_data.return_value = self.mock_padded_msg
- self.mock_padded_msg.encode.return_value = padded_encoded_value
-
- mock_b64_encode.return_value = encyrpted_bytes
-
- result = self.db_base._encrypt_value(value, schema_version, salt)
-
- self.assertTrue(isinstance(result, str))
- self.assertEqual(result, encyrpted_value)
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_pad_data.assert_called_once_with(value)
- mock_b64_encode.assert_called_once_with(data_to_b4_encode)
- self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value)
- self.mock_padded_msg.encode.assert_called_with(encoding_type)
-
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_schema_version_1_0_secret_key_not_exists(
- self, mock_pad_data, mock_join_secret_key
- ):
- """schema_version 1.0, secret_key is None, salt exists."""
- schema_version = "1.0"
- self.db_base.secret_key = None
- result = self.db_base._encrypt_value(value, schema_version, salt)
- self.assertEqual(result, value)
- check_if_assert_not_called([mock_pad_data, mock_join_secret_key])
-
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_schema_version_1_1_secret_key_not_exists(
- self, mock_pad_data, mock_join_secret_key
- ):
- """schema_version 1.1, secret_key is None, salt exists."""
- self.db_base.secret_key = None
- result = self.db_base._encrypt_value(value, schema_version, salt)
- self.assertEqual(result, value)
- check_if_assert_not_called([mock_pad_data, mock_join_secret_key])
-
- @patch("osm_common.dbbase.b64encode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_schema_version_1_1_secret_key_exists_without_salt(
- self,
- mock_pad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_encode,
- ):
- """schema_version 1.1, secret_key exists, salt is None."""
- salt = None
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.encrypt.return_value = data_to_b4_encode
-
- self.mock_padded_msg.return_value = padded_value
- mock_pad_data.return_value = self.mock_padded_msg
- self.mock_padded_msg.encode.return_value = padded_encoded_value
-
- mock_b64_encode.return_value = encyrpted_bytes
-
- result = self.db_base._encrypt_value(value, schema_version, salt)
-
- self.assertEqual(result, encyrpted_value)
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_pad_data.assert_called_once_with(value)
- mock_b64_encode.assert_called_once_with(data_to_b4_encode)
- self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value)
- self.mock_padded_msg.encode.assert_called_with(encoding_type)
-
- @patch("osm_common.dbbase.b64encode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_invalid_encrpt_mode(
- self,
- mock_pad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_encode,
- ):
- """encrypt_mode is invalid."""
- mock_aes.new.side_effect = Exception("Invalid ciphering mode.")
- self.db_base.encrypt_mode = "AES.MODE_XXX"
-
- with self.assertRaises(Exception) as err:
- self.db_base._encrypt_value(value, schema_version, salt)
-
- self.assertEqual(str(err.exception), "Invalid ciphering mode.")
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], "AES.MODE_XXX")
- check_if_assert_not_called([mock_pad_data, mock_b64_encode])
-
- @patch("osm_common.dbbase.b64encode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_schema_version_1_1_secret_key_exists_value_none(
- self,
- mock_pad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_encode,
- ):
- """schema_version 1.1, secret_key exists, value is None."""
- value = None
- mock_aes.new.return_value = self.mock_cipher
- mock_pad_data.side_effect = DbException(
- "Incorrect data type: type(None), string is expected."
- )
-
- with self.assertRaises(Exception) as err:
- self.db_base._encrypt_value(value, schema_version, salt)
- self.assertEqual(
- str(err.exception),
- "database exception Incorrect data type: type(None), string is expected.",
- )
-
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_pad_data.assert_called_once_with(value)
- check_if_assert_not_called(
- [mock_b64_encode, self.mock_cipher.encrypt, mock_b64_encode]
- )
-
- @patch("osm_common.dbbase.b64encode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_join_secret_key_raises(
- self,
- mock_pad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_encode,
- ):
- """Method join_secret_key raises DbException."""
- salt = b"3434o34-3wewrwr-222424-2242dwew"
-
- mock_join_secret_key.side_effect = DbException("Unexpected type")
-
- mock_aes.new.return_value = self.mock_cipher
-
- with self.assertRaises(Exception) as err:
- self.db_base._encrypt_value(value, schema_version, salt)
-
- self.assertEqual(str(err.exception), "database exception Unexpected type")
- check_if_assert_not_called(
- [mock_pad_data, mock_aes.new, mock_b64_encode, self.mock_cipher.encrypt]
- )
- mock_join_secret_key.assert_called_once_with(salt)
-
- @patch("osm_common.dbbase.b64encode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_schema_version_1_1_secret_key_exists_b64_encode_raises(
- self,
- mock_pad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_encode,
- ):
- """schema_version 1.1, secret_key exists, b64encode raises TypeError."""
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.encrypt.return_value = "encrypted data"
-
- self.mock_padded_msg.return_value = padded_value
- mock_pad_data.return_value = self.mock_padded_msg
- self.mock_padded_msg.encode.return_value = padded_encoded_value
-
- mock_b64_encode.side_effect = TypeError(
- "A bytes-like object is required, not 'str'"
- )
-
- with self.assertRaises(Exception) as error:
- self.db_base._encrypt_value(value, schema_version, salt)
- self.assertEqual(
- str(error.exception), "A bytes-like object is required, not 'str'"
- )
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_pad_data.assert_called_once_with(value)
- self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value)
- self.mock_padded_msg.encode.assert_called_with(encoding_type)
- mock_b64_encode.assert_called_once_with("encrypted data")
-
- @patch("osm_common.dbbase.b64encode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "pad_data")
- def test__encrypt_value_cipher_encrypt_raises(
- self,
- mock_pad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_encode,
- ):
- """AES encrypt method raises Exception."""
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.encrypt.side_effect = Exception("Invalid data type.")
-
- self.mock_padded_msg.return_value = padded_value
- mock_pad_data.return_value = self.mock_padded_msg
- self.mock_padded_msg.encode.return_value = padded_encoded_value
-
- with self.assertRaises(Exception) as error:
- self.db_base._encrypt_value(value, schema_version, salt)
-
- self.assertEqual(str(error.exception), "Invalid data type.")
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_pad_data.assert_called_once_with(value)
- self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value)
- self.mock_padded_msg.encode.assert_called_with(encoding_type)
- mock_b64_encode.assert_not_called()
-
- @patch.object(DbBase, "get_secret_key")
- @patch.object(DbBase, "_encrypt_value")
- def test_encrypt_without_schema_version_without_salt(
- self, mock_encrypt_value, mock_get_secret_key
- ):
- """schema and salt is None."""
- mock_encrypt_value.return_value = encyrpted_value
- result = self.db_base.encrypt(value)
- mock_encrypt_value.assert_called_once_with(value, None, None)
- mock_get_secret_key.assert_called_once()
- self.assertEqual(result, encyrpted_value)
-
- @patch.object(DbBase, "get_secret_key")
- @patch.object(DbBase, "_encrypt_value")
- def test_encrypt_with_schema_version_with_salt(
- self, mock_encrypt_value, mock_get_secret_key
- ):
- """schema version exists, salt is None."""
- mock_encrypt_value.return_value = encyrpted_value
- result = self.db_base.encrypt(value, schema_version, salt)
- mock_encrypt_value.assert_called_once_with(value, schema_version, salt)
- mock_get_secret_key.assert_called_once()
- self.assertEqual(result, encyrpted_value)
-
- @patch.object(DbBase, "get_secret_key")
- @patch.object(DbBase, "_encrypt_value")
- def test_encrypt_get_secret_key_raises(
- self, mock_encrypt_value, mock_get_secret_key
- ):
- """get_secret_key method raises DbException."""
- mock_get_secret_key.side_effect = DbException("KeyError")
- with self.assertRaises(Exception) as error:
- self.db_base.encrypt(value)
- self.assertEqual(str(error.exception), "database exception KeyError")
- mock_encrypt_value.assert_not_called()
- mock_get_secret_key.assert_called_once()
-
- @patch.object(DbBase, "get_secret_key")
- @patch.object(DbBase, "_encrypt_value")
- def test_encrypt_encrypt_raises(self, mock_encrypt_value, mock_get_secret_key):
- """_encrypt method raises DbException."""
- mock_encrypt_value.side_effect = DbException(
- "Incorrect data type: type(None), string is expected."
- )
- with self.assertRaises(Exception) as error:
- self.db_base.encrypt(value, schema_version, salt)
- self.assertEqual(
- str(error.exception),
- "database exception Incorrect data type: type(None), string is expected.",
- )
- mock_encrypt_value.assert_called_once_with(value, schema_version, salt)
- mock_get_secret_key.assert_called_once()
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_schema_version_1_1_secret_key_exists_without_salt(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """schema_version 1.1, secret_key exists, salt is None."""
- salt = None
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.decrypt.return_value = padded_encoded_value
-
- mock_b64_decode.return_value = b64_decoded
-
- mock_unpad_data.return_value = value
-
- result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
- self.assertEqual(result, value)
-
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_unpad_data.assert_called_once_with("private key txt\0")
- mock_b64_decode.assert_called_once_with(encyrpted_value)
- self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_schema_version_1_1_secret_key_exists_with_salt(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """schema_version 1.1, secret_key exists, salt is None."""
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.decrypt.return_value = padded_encoded_value
-
- mock_b64_decode.return_value = b64_decoded
-
- mock_unpad_data.return_value = value
-
- result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
- self.assertEqual(result, value)
-
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_unpad_data.assert_called_once_with("private key txt\0")
- mock_b64_decode.assert_called_once_with(encyrpted_value)
- self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_schema_version_1_1_without_secret_key(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """schema_version 1.1, secret_key is None, salt exists."""
- self.db_base.secret_key = None
-
- result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
-
- self.assertEqual(result, encyrpted_value)
- check_if_assert_not_called(
- [
- mock_join_secret_key,
- mock_aes.new,
- mock_unpad_data,
- mock_b64_decode,
- self.mock_cipher.decrypt,
- ]
- )
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_schema_version_1_0_with_secret_key(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """schema_version 1.0, secret_key exists, salt exists."""
- schema_version = "1.0"
- result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
-
- self.assertEqual(result, encyrpted_value)
- check_if_assert_not_called(
- [
- mock_join_secret_key,
- mock_aes.new,
- mock_unpad_data,
- mock_b64_decode,
- self.mock_cipher.decrypt,
- ]
- )
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_join_secret_key_raises(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """_join_secret_key raises TypeError."""
- salt = object()
- mock_join_secret_key.side_effect = TypeError("'type' object is not iterable")
-
- with self.assertRaises(Exception) as error:
- self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
- self.assertEqual(str(error.exception), "'type' object is not iterable")
-
- mock_join_secret_key.assert_called_once_with(salt)
- check_if_assert_not_called(
- [mock_aes.new, mock_unpad_data, mock_b64_decode, self.mock_cipher.decrypt]
- )
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_b64decode_raises(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """b64decode raises TypeError."""
- mock_b64_decode.side_effect = TypeError(
- "A str-like object is required, not 'bytes'"
- )
- with self.assertRaises(Exception) as error:
- self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
- self.assertEqual(
- str(error.exception), "A str-like object is required, not 'bytes'"
- )
-
- mock_b64_decode.assert_called_once_with(encyrpted_value)
- mock_join_secret_key.assert_called_once_with(salt)
- check_if_assert_not_called(
- [mock_aes.new, self.mock_cipher.decrypt, mock_unpad_data]
- )
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_invalid_encrypt_mode(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """Invalid AES encrypt mode."""
- mock_aes.new.side_effect = Exception("Invalid ciphering mode.")
- self.db_base.encrypt_mode = "AES.MODE_XXX"
-
- mock_b64_decode.return_value = b64_decoded
- with self.assertRaises(Exception) as error:
- self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
-
- self.assertEqual(str(error.exception), "Invalid ciphering mode.")
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], "AES.MODE_XXX")
- mock_b64_decode.assert_called_once_with(encyrpted_value)
- check_if_assert_not_called([mock_unpad_data, self.mock_cipher.decrypt])
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_cipher_decrypt_raises(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """AES decrypt raises Exception."""
- mock_b64_decode.return_value = b64_decoded
-
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.decrypt.side_effect = Exception("Invalid data type.")
-
- with self.assertRaises(Exception) as error:
- self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
- self.assertEqual(str(error.exception), "Invalid data type.")
-
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_b64_decode.assert_called_once_with(encyrpted_value)
- self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
- mock_unpad_data.assert_not_called()
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_decode_raises(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """Decode raises UnicodeDecodeError."""
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.decrypt.return_value = b"\xd0\x000091"
-
- mock_b64_decode.return_value = b64_decoded
-
- mock_unpad_data.return_value = value
- with self.assertRaises(Exception) as error:
- self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
- self.assertEqual(
- str(error.exception),
- "database exception Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
- )
- self.assertEqual(type(error.exception), DbException)
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_b64_decode.assert_called_once_with(encyrpted_value)
- self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
- mock_unpad_data.assert_not_called()
-
- @patch("osm_common.dbbase.b64decode")
- @patch("osm_common.dbbase.AES")
- @patch.object(DbBase, "_join_secret_key")
- @patch.object(DbBase, "unpad_data")
- def test__decrypt_value_unpad_data_raises(
- self,
- mock_unpad_data,
- mock_join_secret_key,
- mock_aes,
- mock_b64_decode,
- ):
- """Method unpad_data raises error."""
- mock_decrypted_message = MagicMock()
- mock_decrypted_message.decode.return_value = None
- mock_aes.new.return_value = self.mock_cipher
- self.mock_cipher.decrypt.return_value = mock_decrypted_message
- mock_unpad_data.side_effect = DbException(
- "Incorrect data type: type(None), string is expected."
- )
- mock_b64_decode.return_value = b64_decoded
-
- with self.assertRaises(Exception) as error:
- self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
- self.assertEqual(
- str(error.exception),
- "database exception Incorrect data type: type(None), string is expected.",
- )
- self.assertEqual(type(error.exception), DbException)
- mock_join_secret_key.assert_called_once_with(salt)
- _call_mock_aes_new = mock_aes.new.call_args_list[0].args
- self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
- mock_b64_decode.assert_called_once_with(encyrpted_value)
- self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
- mock_decrypted_message.decode.assert_called_once_with(
- self.db_base.encoding_type
- )
- mock_unpad_data.assert_called_once_with(None)
-
- @patch.object(DbBase, "get_secret_key")
- @patch.object(DbBase, "_decrypt_value")
- def test_decrypt_without_schema_version_without_salt(
- self, mock_decrypt_value, mock_get_secret_key
- ):
- """schema_version is None, salt is None."""
- mock_decrypt_value.return_value = encyrpted_value
- result = self.db_base.decrypt(value)
- mock_decrypt_value.assert_called_once_with(value, None, None)
- mock_get_secret_key.assert_called_once()
- self.assertEqual(result, encyrpted_value)
-
- @patch.object(DbBase, "get_secret_key")
- @patch.object(DbBase, "_decrypt_value")
- def test_decrypt_with_schema_version_with_salt(
- self, mock_decrypt_value, mock_get_secret_key
- ):
- """schema_version and salt exist."""
- mock_decrypt_value.return_value = encyrpted_value
- result = self.db_base.decrypt(value, schema_version, salt)
- mock_decrypt_value.assert_called_once_with(value, schema_version, salt)
- mock_get_secret_key.assert_called_once()
- self.assertEqual(result, encyrpted_value)
-
- @patch.object(DbBase, "get_secret_key")
- @patch.object(DbBase, "_decrypt_value")
- def test_decrypt_get_secret_key_raises(
- self, mock_decrypt_value, mock_get_secret_key
- ):
- """Method get_secret_key raises KeyError."""
- mock_get_secret_key.side_effect = DbException("KeyError")
- with self.assertRaises(Exception) as error:
- self.db_base.decrypt(value)
- self.assertEqual(str(error.exception), "database exception KeyError")
- mock_decrypt_value.assert_not_called()
- mock_get_secret_key.assert_called_once()
-
- @patch.object(DbBase, "get_secret_key")
- @patch.object(DbBase, "_decrypt_value")
- def test_decrypt_decrypt_value_raises(
- self, mock_decrypt_value, mock_get_secret_key
- ):
- """Method _decrypt raises error."""
- mock_decrypt_value.side_effect = DbException(
- "Incorrect data type: type(None), string is expected."
- )
- with self.assertRaises(Exception) as error:
- self.db_base.decrypt(value, schema_version, salt)
- self.assertEqual(
- str(error.exception),
- "database exception Incorrect data type: type(None), string is expected.",
- )
- mock_decrypt_value.assert_called_once_with(value, schema_version, salt)
- mock_get_secret_key.assert_called_once()
-
- def test_encrypt_decrypt_with_schema_version_1_1_with_salt(self):
- """Encrypt and decrypt with schema version 1.1, salt exists."""
- encrypted_msg = self.db_base.encrypt(value, schema_version, salt)
- decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt)
- self.assertEqual(value, decrypted_msg)
-
- def test_encrypt_decrypt_with_schema_version_1_0_with_salt(self):
- """Encrypt and decrypt with schema version 1.0, salt exists."""
- schema_version = "1.0"
- encrypted_msg = self.db_base.encrypt(value, schema_version, salt)
- decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt)
- self.assertEqual(value, decrypted_msg)
-
- def test_encrypt_decrypt_with_schema_version_1_1_without_salt(self):
- """Encrypt and decrypt with schema version 1.1 and without salt."""
- salt = None
- encrypted_msg = self.db_base.encrypt(value, schema_version, salt)
- decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt)
- self.assertEqual(value, decrypted_msg)
-
-
-class TestAsyncEncryption(unittest.TestCase):
- @patch("logging.getLogger", autospec=True)
- def setUp(self, mock_logger):
- mock_logger = logging.getLogger()
- mock_logger.disabled = True
- self.encryption = Encryption(uri="uri", config={})
- self.encryption.encoding_type = encoding_type
- self.encryption.encrypt_mode = encyrpt_mode
- self.encryption._secret_key = secret_key
- self.admin_collection = Mock()
- self.admin_collection.find_one = AsyncMock()
- self.encryption._client = {
- "osm": {
- "admin": self.admin_collection,
- }
- }
-
- @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
- def test_decrypt_fields_with_item_with_fields(self, mock_decrypt):
- """item and fields exist."""
- mock_decrypt.side_effect = [decrypted_val1, decrypted_val2]
- input_item = copy.deepcopy(item)
- expected_item = {
- "secret": decrypted_val1,
- "cacert": decrypted_val2,
- "path": "/var",
- "ip": "192.168.12.23",
- }
- fields = ["secret", "cacert"]
-
- asyncio.run(
- self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
- )
- self.assertEqual(input_item, expected_item)
- _call_mock_decrypt = mock_decrypt.call_args_list
- self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt))
- self.assertEqual(_call_mock_decrypt[1].args, ("mycacert", "1.1", salt))
-
- @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
- def test_decrypt_fields_empty_item_with_fields(self, mock_decrypt):
- """item is empty and fields exists."""
- input_item = {}
- fields = ["secret", "cacert"]
- asyncio.run(
- self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
- )
- self.assertEqual(input_item, {})
- mock_decrypt.assert_not_called()
-
- @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
- def test_decrypt_fields_with_item_without_fields(self, mock_decrypt):
- """item exists and fields is empty."""
- input_item = copy.deepcopy(item)
- fields = []
- asyncio.run(
- self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
- )
- self.assertEqual(input_item, item)
- mock_decrypt.assert_not_called()
-
- @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
- def test_decrypt_fields_with_item_with_single_field(self, mock_decrypt):
- """item exists and field has single value."""
- mock_decrypt.return_value = decrypted_val1
- fields = ["secret"]
- input_item = copy.deepcopy(item)
- expected_item = {
- "secret": decrypted_val1,
- "cacert": "mycacert",
- "path": "/var",
- "ip": "192.168.12.23",
- }
- asyncio.run(
- self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
- )
- self.assertEqual(input_item, expected_item)
- _call_mock_decrypt = mock_decrypt.call_args_list
- self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt))
-
- @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
- def test_decrypt_fields_with_item_with_field_none_salt_1_0_schema_version(
- self, mock_decrypt
- ):
- """item exists and field has single value, salt is None, schema version is 1.0."""
- schema_version = "1.0"
- salt = None
- mock_decrypt.return_value = "mysecret"
- input_item = copy.deepcopy(item)
- fields = ["secret"]
- asyncio.run(
- self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
- )
- self.assertEqual(input_item, item)
- _call_mock_decrypt = mock_decrypt.call_args_list
- self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.0", None))
-
- @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
- def test_decrypt_fields_decrypt_raises(self, mock_decrypt):
- """Method decrypt raises error."""
- mock_decrypt.side_effect = DbException(
- "Incorrect data type: type(None), string is expected."
- )
- fields = ["secret"]
- input_item = copy.deepcopy(item)
- with self.assertRaises(Exception) as error:
- asyncio.run(
- self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
- )
- self.assertEqual(
- str(error.exception),
- "database exception Incorrect data type: type(None), string is expected.",
- )
- self.assertEqual(input_item, item)
- _call_mock_decrypt = mock_decrypt.call_args_list
- self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt))
-
- @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
- @patch.object(Encryption, "_encrypt_value")
- def test_encrypt(self, mock_encrypt_value, mock_get_secret_key):
- """Method decrypt raises error."""
- mock_encrypt_value.return_value = encyrpted_value
- result = asyncio.run(self.encryption.encrypt(value, schema_version, salt))
- self.assertEqual(result, encyrpted_value)
- mock_get_secret_key.assert_called_once()
- mock_encrypt_value.assert_called_once_with(value, schema_version, salt)
-
- @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
- @patch.object(Encryption, "_encrypt_value")
- def test_encrypt_get_secret_key_raises(
- self, mock_encrypt_value, mock_get_secret_key
- ):
- """Method get_secret_key raises error."""
- mock_get_secret_key.side_effect = DbException("Unexpected type.")
- with self.assertRaises(Exception) as error:
- asyncio.run(self.encryption.encrypt(value, schema_version, salt))
- self.assertEqual(str(error.exception), "database exception Unexpected type.")
- mock_get_secret_key.assert_called_once()
- mock_encrypt_value.assert_not_called()
-
- @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
- @patch.object(Encryption, "_encrypt_value")
- def test_encrypt_get_encrypt_raises(self, mock_encrypt_value, mock_get_secret_key):
- """Method _encrypt raises error."""
- mock_encrypt_value.side_effect = TypeError(
- "A bytes-like object is required, not 'str'"
- )
- with self.assertRaises(Exception) as error:
- asyncio.run(self.encryption.encrypt(value, schema_version, salt))
- self.assertEqual(
- str(error.exception), "A bytes-like object is required, not 'str'"
- )
- mock_get_secret_key.assert_called_once()
- mock_encrypt_value.assert_called_once_with(value, schema_version, salt)
-
- @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
- @patch.object(Encryption, "_decrypt_value")
- def test_decrypt(self, mock_decrypt_value, mock_get_secret_key):
- """Decrypted successfully."""
- mock_decrypt_value.return_value = value
- result = asyncio.run(
- self.encryption.decrypt(encyrpted_value, schema_version, salt)
- )
- self.assertEqual(result, value)
- mock_get_secret_key.assert_called_once()
- mock_decrypt_value.assert_called_once_with(
- encyrpted_value, schema_version, salt
- )
-
- @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
- @patch.object(Encryption, "_decrypt_value")
- def test_decrypt_get_secret_key_raises(
- self, mock_decrypt_value, mock_get_secret_key
- ):
- """Method get_secret_key raises error."""
- mock_get_secret_key.side_effect = DbException("Unexpected type.")
- with self.assertRaises(Exception) as error:
- asyncio.run(self.encryption.decrypt(encyrpted_value, schema_version, salt))
- self.assertEqual(str(error.exception), "database exception Unexpected type.")
- mock_get_secret_key.assert_called_once()
- mock_decrypt_value.assert_not_called()
-
- @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
- @patch.object(Encryption, "_decrypt_value")
- def test_decrypt_decrypt_value_raises(
- self, mock_decrypt_value, mock_get_secret_key
- ):
- """Method get_secret_key raises error."""
- mock_decrypt_value.side_effect = TypeError(
- "A bytes-like object is required, not 'str'"
- )
- with self.assertRaises(Exception) as error:
- asyncio.run(self.encryption.decrypt(encyrpted_value, schema_version, salt))
- self.assertEqual(
- str(error.exception), "A bytes-like object is required, not 'str'"
- )
- mock_get_secret_key.assert_called_once()
- mock_decrypt_value.assert_called_once_with(
- encyrpted_value, schema_version, salt
- )
-
- def test_join_keys_string_key(self):
- """key is string."""
- string_key = "sample key"
- result = self.encryption._join_keys(string_key, secret_key)
- self.assertEqual(result, joined_key)
- self.assertTrue(isinstance(result, bytes))
-
- def test_join_keys_bytes_key(self):
- """key is bytes."""
- bytes_key = b"sample key"
- result = self.encryption._join_keys(bytes_key, secret_key)
- self.assertEqual(result, joined_key)
- self.assertTrue(isinstance(result, bytes))
- self.assertEqual(len(result.decode("unicode_escape")), 32)
-
- def test_join_keys_int_key(self):
- """key is int."""
- int_key = 923
- with self.assertRaises(Exception) as error:
- self.encryption._join_keys(int_key, None)
- self.assertEqual(str(error.exception), "'int' object is not iterable")
-
- def test_join_keys_none_secret_key(self):
- """key is as bytes and secret key is None."""
- bytes_key = b"sample key"
- result = self.encryption._join_keys(bytes_key, None)
- self.assertEqual(
- result,
- b"sample key\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- )
- self.assertTrue(isinstance(result, bytes))
- self.assertEqual(len(result.decode("unicode_escape")), 32)
-
- def test_join_keys_none_key_none_secret_key(self):
- """key is None and secret key is None."""
- with self.assertRaises(Exception) as error:
- self.encryption._join_keys(None, None)
- self.assertEqual(str(error.exception), "'NoneType' object is not iterable")
-
- def test_join_keys_none_key(self):
- """key is None and secret key exists."""
- with self.assertRaises(Exception) as error:
- self.encryption._join_keys(None, secret_key)
- self.assertEqual(str(error.exception), "'NoneType' object is not iterable")
-
- @patch.object(Encryption, "_join_keys")
- def test_join_secret_key_string_sample_key(self, mock_join_keys):
- """key is None and secret key exists as string."""
- update_key = "sample key"
- mock_join_keys.return_value = joined_key
- result = self.encryption._join_secret_key(update_key)
- self.assertEqual(result, joined_key)
- self.assertTrue(isinstance(result, bytes))
- mock_join_keys.assert_called_once_with(update_key, secret_key)
-
- @patch.object(Encryption, "_join_keys")
- def test_join_secret_key_byte_sample_key(self, mock_join_keys):
- """key is None and secret key exists as bytes."""
- update_key = b"sample key"
- mock_join_keys.return_value = joined_key
- result = self.encryption._join_secret_key(update_key)
- self.assertEqual(result, joined_key)
- self.assertTrue(isinstance(result, bytes))
- mock_join_keys.assert_called_once_with(update_key, secret_key)
-
- @patch.object(Encryption, "_join_keys")
- def test_join_secret_key_join_keys_raises(self, mock_join_keys):
- """Method _join_secret_key raises."""
- update_key = 3434
- mock_join_keys.side_effect = TypeError("'int' object is not iterable")
- with self.assertRaises(Exception) as error:
- self.encryption._join_secret_key(update_key)
- self.assertEqual(str(error.exception), "'int' object is not iterable")
- mock_join_keys.assert_called_once_with(update_key, secret_key)
-
- @patch.object(Encryption, "_join_keys")
- def test_get_secret_key_exists(self, mock_join_keys):
- """secret_key exists."""
- self.encryption._secret_key = secret_key
- asyncio.run(self.encryption.get_secret_key())
- self.assertEqual(self.encryption.secret_key, secret_key)
- mock_join_keys.assert_not_called()
-
- @patch.object(Encryption, "_join_keys")
- @patch("osm_common.dbbase.b64decode")
- def test_get_secret_key_not_exist_database_key_exist(
- self, mock_b64decode, mock_join_keys
- ):
- """secret_key does not exist, database key exists."""
- self.encryption._secret_key = None
- self.encryption._admin_collection.find_one.return_value = None
- self.encryption._config = {"database_commonkey": "osm_new_key"}
- mock_join_keys.return_value = joined_key
- asyncio.run(self.encryption.get_secret_key())
- self.assertEqual(self.encryption.secret_key, joined_key)
- self.assertEqual(mock_join_keys.call_count, 1)
- mock_b64decode.assert_not_called()
-
- @patch.object(Encryption, "_join_keys")
- @patch("osm_common.dbbase.b64decode")
- def test_get_secret_key_not_exist_with_database_key_version_data_exist_without_serial(
- self, mock_b64decode, mock_join_keys
- ):
- """secret_key does not exist, database key exists."""
- self.encryption._secret_key = None
- self.encryption._admin_collection.find_one.return_value = {"version": "1.0"}
- self.encryption._config = {"database_commonkey": "osm_new_key"}
- mock_join_keys.return_value = joined_key
- asyncio.run(self.encryption.get_secret_key())
- self.assertEqual(self.encryption.secret_key, joined_key)
- self.assertEqual(mock_join_keys.call_count, 1)
- mock_b64decode.assert_not_called()
- self.encryption._admin_collection.find_one.assert_called_once_with(
- {"_id": "version"}
- )
- _call_mock_join_keys = mock_join_keys.call_args_list
- self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
-
- @patch.object(Encryption, "_join_keys")
- @patch("osm_common.dbbase.b64decode")
- def test_get_secret_key_not_exist_with_database_key_version_data_exist_with_serial(
- self, mock_b64decode, mock_join_keys
- ):
- """secret_key does not exist, database key exists, version and serial exist
- in admin collection."""
- self.encryption._secret_key = None
- self.encryption._admin_collection.find_one.return_value = {
- "version": "1.0",
- "serial": serial_bytes,
- }
- self.encryption._config = {"database_commonkey": "osm_new_key"}
- mock_join_keys.side_effect = [secret_key, joined_key]
- mock_b64decode.return_value = base64_decoded_serial
- asyncio.run(self.encryption.get_secret_key())
- self.assertEqual(self.encryption.secret_key, joined_key)
- self.assertEqual(mock_join_keys.call_count, 2)
- mock_b64decode.assert_called_once_with(serial_bytes)
- self.encryption._admin_collection.find_one.assert_called_once_with(
- {"_id": "version"}
- )
- _call_mock_join_keys = mock_join_keys.call_args_list
- self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
- self.assertEqual(
- _call_mock_join_keys[1].args, (base64_decoded_serial, secret_key)
- )
-
- @patch.object(Encryption, "_join_keys")
- @patch("osm_common.dbbase.b64decode")
- def test_get_secret_key_join_keys_raises(self, mock_b64decode, mock_join_keys):
- """Method _join_keys raises."""
- self.encryption._secret_key = None
- self.encryption._admin_collection.find_one.return_value = {
- "version": "1.0",
- "serial": serial_bytes,
- }
- self.encryption._config = {"database_commonkey": "osm_new_key"}
- mock_join_keys.side_effect = DbException("Invalid data type.")
- with self.assertRaises(Exception) as error:
- asyncio.run(self.encryption.get_secret_key())
- self.assertEqual(str(error.exception), "database exception Invalid data type.")
- self.assertEqual(mock_join_keys.call_count, 1)
- check_if_assert_not_called(
- [mock_b64decode, self.encryption._admin_collection.find_one]
- )
- _call_mock_join_keys = mock_join_keys.call_args_list
- self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
-
- @patch.object(Encryption, "_join_keys")
- @patch("osm_common.dbbase.b64decode")
- def test_get_secret_key_b64decode_raises(self, mock_b64decode, mock_join_keys):
- """Method b64decode raises."""
- self.encryption._secret_key = None
- self.encryption._admin_collection.find_one.return_value = {
- "version": "1.0",
- "serial": base64_decoded_serial,
- }
- self.encryption._config = {"database_commonkey": "osm_new_key"}
- mock_join_keys.return_value = secret_key
- mock_b64decode.side_effect = TypeError(
- "A bytes-like object is required, not 'str'"
- )
- with self.assertRaises(Exception) as error:
- asyncio.run(self.encryption.get_secret_key())
- self.assertEqual(
- str(error.exception), "A bytes-like object is required, not 'str'"
- )
- self.assertEqual(self.encryption.secret_key, None)
- self.assertEqual(mock_join_keys.call_count, 1)
- mock_b64decode.assert_called_once_with(base64_decoded_serial)
- self.encryption._admin_collection.find_one.assert_called_once_with(
- {"_id": "version"}
- )
- _call_mock_join_keys = mock_join_keys.call_args_list
- self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
-
- @patch.object(Encryption, "_join_keys")
- @patch("osm_common.dbbase.b64decode")
- def test_get_secret_key_admin_collection_find_one_raises(
- self, mock_b64decode, mock_join_keys
- ):
- """admin_collection find_one raises."""
- self.encryption._secret_key = None
- self.encryption._admin_collection.find_one.side_effect = DbException(
- "Connection failed."
- )
- self.encryption._config = {"database_commonkey": "osm_new_key"}
- mock_join_keys.return_value = secret_key
- with self.assertRaises(Exception) as error:
- asyncio.run(self.encryption.get_secret_key())
- self.assertEqual(str(error.exception), "database exception Connection failed.")
- self.assertEqual(self.encryption.secret_key, None)
- self.assertEqual(mock_join_keys.call_count, 1)
- mock_b64decode.assert_not_called()
- self.encryption._admin_collection.find_one.assert_called_once_with(
- {"_id": "version"}
- )
- _call_mock_join_keys = mock_join_keys.call_args_list
- self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
-
- def test_encrypt_decrypt_with_schema_version_1_1_with_salt(self):
- """Encrypt and decrypt with schema version 1.1, salt exists."""
- encrypted_msg = asyncio.run(
- self.encryption.encrypt(value, schema_version, salt)
- )
- decrypted_msg = asyncio.run(
- self.encryption.decrypt(encrypted_msg, schema_version, salt)
- )
- self.assertEqual(value, decrypted_msg)
-
- def test_encrypt_decrypt_with_schema_version_1_0_with_salt(self):
- """Encrypt and decrypt with schema version 1.0, salt exists."""
- schema_version = "1.0"
- encrypted_msg = asyncio.run(
- self.encryption.encrypt(value, schema_version, salt)
- )
- decrypted_msg = asyncio.run(
- self.encryption.decrypt(encrypted_msg, schema_version, salt)
- )
- self.assertEqual(value, decrypted_msg)
-
- def test_encrypt_decrypt_with_schema_version_1_1_without_salt(self):
- """Encrypt and decrypt with schema version 1.1, without salt."""
- salt = None
- with self.assertRaises(Exception) as error:
- asyncio.run(self.encryption.encrypt(value, schema_version, salt))
- self.assertEqual(str(error.exception), "'NoneType' object is not iterable")
-
-
-class TestDeepUpdate(unittest.TestCase):
- def test_update_dict(self):
- # Original, patch, expected result
- TEST = (
- ({"a": "b"}, {"a": "c"}, {"a": "c"}),
- ({"a": "b"}, {"b": "c"}, {"a": "b", "b": "c"}),
- ({"a": "b"}, {"a": None}, {}),
- ({"a": "b", "b": "c"}, {"a": None}, {"b": "c"}),
- ({"a": ["b"]}, {"a": "c"}, {"a": "c"}),
- ({"a": "c"}, {"a": ["b"]}, {"a": ["b"]}),
- ({"a": {"b": "c"}}, {"a": {"b": "d", "c": None}}, {"a": {"b": "d"}}),
- ({"a": [{"b": "c"}]}, {"a": [1]}, {"a": [1]}),
- ({1: ["a", "b"]}, {1: ["c", "d"]}, {1: ["c", "d"]}),
- ({1: {"a": "b"}}, {1: ["c"]}, {1: ["c"]}),
- ({1: {"a": "foo"}}, {1: None}, {}),
- ({1: {"a": "foo"}}, {1: "bar"}, {1: "bar"}),
- ({"e": None}, {"a": 1}, {"e": None, "a": 1}),
- ({1: [1, 2]}, {1: {"a": "b", "c": None}}, {1: {"a": "b"}}),
- ({}, {"a": {"bb": {"ccc": None}}}, {"a": {"bb": {}}}),
- )
- for t in TEST:
- deep_update(t[0], t[1])
- self.assertEqual(t[0], t[2])
- # test deepcopy is done. So that original dictionary does not reference the pach
- test_original = {1: {"a": "b"}}
- test_patch = {1: {"c": {"d": "e"}}}
- test_result = {1: {"a": "b", "c": {"d": "e"}}}
- deep_update(test_original, test_patch)
- self.assertEqual(test_original, test_result)
- test_patch[1]["c"]["f"] = "edition of patch, must not modify original"
- self.assertEqual(test_original, test_result)
-
- def test_update_array(self):
- # This TEST contains a list with the the Original, patch, and expected result
- TEST = (
- # delete all instances of "a"/"d"
- ({"A": ["a", "b", "a"]}, {"A": {"$a": None}}, {"A": ["b"]}),
- ({"A": ["a", "b", "a"]}, {"A": {"$d": None}}, {"A": ["a", "b", "a"]}),
- # delete and insert at 0
- (
- {"A": ["a", "b", "c"]},
- {"A": {"$b": None, "$+[0]": "b"}},
- {"A": ["b", "a", "c"]},
- ),
- # delete and edit
- (
- {"A": ["a", "b", "a"]},
- {"A": {"$a": None, "$[1]": {"c": "d"}}},
- {"A": [{"c": "d"}]},
- ),
- # insert if not exist
- ({"A": ["a", "b", "c"]}, {"A": {"$+b": "b"}}, {"A": ["a", "b", "c"]}),
- ({"A": ["a", "b", "c"]}, {"A": {"$+d": "f"}}, {"A": ["a", "b", "c", "f"]}),
- # edit by filter
- (
- {"A": ["a", "b", "a"]},
- {"A": {"$b": {"c": "d"}}},
- {"A": ["a", {"c": "d"}, "a"]},
- ),
- (
- {"A": ["a", "b", "a"]},
- {"A": {"$b": None, "$+[0]": "b", "$+": "c"}},
- {"A": ["b", "a", "a", "c"]},
- ),
- ({"A": ["a", "b", "a"]}, {"A": {"$c": None}}, {"A": ["a", "b", "a"]}),
- # index deletion out of range
- ({"A": ["a", "b", "a"]}, {"A": {"$[5]": None}}, {"A": ["a", "b", "a"]}),
- # nested array->dict
- (
- {"A": ["a", "b", {"id": "1", "c": {"d": 2}}]},
- {"A": {"$id: '1'": {"h": None, "c": {"d": "e", "f": "g"}}}},
- {"A": ["a", "b", {"id": "1", "c": {"d": "e", "f": "g"}}]},
- ),
- (
- {"A": [{"id": 1, "c": {"d": 2}}, {"id": 1, "c": {"f": []}}]},
- {"A": {"$id: 1": {"h": None, "c": {"d": "e", "f": "g"}}}},
- {
- "A": [
- {"id": 1, "c": {"d": "e", "f": "g"}},
- {"id": 1, "c": {"d": "e", "f": "g"}},
- ]
- },
- ),
- # nested array->array
- (
- {"A": ["a", "b", ["a", "b"]]},
- {"A": {"$b": None, "$[2]": {"$b": {}, "$+": "c"}}},
- {"A": ["a", ["a", {}, "c"]]},
- ),
- # types str and int different, so not found
- (
- {"A": ["a", {"id": "1", "c": "d"}]},
- {"A": {"$id: 1": {"c": "e"}}},
- {"A": ["a", {"id": "1", "c": "d"}]},
- ),
- )
- for t in TEST:
- print(t)
- deep_update(t[0], t[1])
- self.assertEqual(t[0], t[2])
-
- def test_update_badformat(self):
- # This TEST contains original, incorrect patch and #TODO text that must be present
- TEST = (
- # conflict, index 0 is edited twice
- ({"A": ["a", "b", "a"]}, {"A": {"$a": None, "$[0]": {"c": "d"}}}),
- # conflict, two insertions at same index
- ({"A": ["a", "b", "a"]}, {"A": {"$[1]": "c", "$[-2]": "d"}}),
- ({"A": ["a", "b", "a"]}, {"A": {"$[1]": "c", "$[+1]": "d"}}),
- # bad format keys with and without $
- ({"A": ["a", "b", "a"]}, {"A": {"$b": {"c": "d"}, "c": 3}}),
- # bad format empty $ and yaml incorrect
- ({"A": ["a", "b", "a"]}, {"A": {"$": 3}}),
- ({"A": ["a", "b", "a"]}, {"A": {"$a: b: c": 3}}),
- ({"A": ["a", "b", "a"]}, {"A": {"$a: b, c: d": 3}}),
- # insertion of None
- ({"A": ["a", "b", "a"]}, {"A": {"$+": None}}),
- # Not found, insertion of None
- ({"A": ["a", "b", "a"]}, {"A": {"$+c": None}}),
- # index edition out of range
- ({"A": ["a", "b", "a"]}, {"A": {"$[5]": 6}}),
- # conflict, two editions on index 2
- (
- {"A": ["a", {"id": "1", "c": "d"}]},
- {"A": {"$id: '1'": {"c": "e"}, "$c: d": {"c": "f"}}},
- ),
- )
- for t in TEST:
- print(t)
- self.assertRaises(DbException, deep_update, t[0], t[1])
- try:
- deep_update(t[0], t[1])
- except DbException as e:
- print(e)
-
-
-if __name__ == "__main__":
- unittest.main()
+++ /dev/null
-# Copyright 2018 Whitestack, LLC
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
-##
-from copy import deepcopy
-import http
-import logging
-import unittest
-from unittest.mock import MagicMock, Mock
-
-from osm_common.dbbase import DbException
-from osm_common.dbmemory import DbMemory
-import pytest
-
-__author__ = "Eduardo Sousa <eduardosousa@av.it.pt>"
-
-
-@pytest.fixture(scope="function", params=[True, False])
-def db_memory(request):
- db = DbMemory(lock=request.param)
- return db
-
-
-@pytest.fixture(scope="function", params=[True, False])
-def db_memory_with_data(request):
- db = DbMemory(lock=request.param)
-
- db.create("test", {"_id": 1, "data": 1})
- db.create("test", {"_id": 2, "data": 2})
- db.create("test", {"_id": 3, "data": 3})
-
- return db
-
-
-@pytest.fixture(scope="function")
-def db_memory_with_many_data(request):
- db = DbMemory(lock=False)
-
- db.create_list(
- "test",
- [
- {
- "_id": 1,
- "data": {"data2": {"data3": 1}},
- "list": [{"a": 1}],
- "text": "sometext",
- },
- {
- "_id": 2,
- "data": {"data2": {"data3": 2}},
- "list": [{"a": 2}],
- "list2": [1, 2, 3],
- },
- {"_id": 3, "data": {"data2": {"data3": 3}}, "list": [{"a": 3}]},
- {"_id": 4, "data": {"data2": {"data3": 4}}, "list": [{"a": 4}, {"a": 0}]},
- {"_id": 5, "data": {"data2": {"data3": 5}}, "list": [{"a": 5}]},
- {"_id": 6, "data": {"data2": {"data3": 6}}, "list": [{"0": {"a": 1}}]},
- {"_id": 7, "data": {"data2": {"data3": 7}}, "0": {"a": 0}},
- {
- "_id": 8,
- "list": [
- {"a": 3, "b": 0, "c": [{"a": 3, "b": 1}, {"a": 0, "b": "v"}]},
- {"a": 0, "b": 1},
- ],
- },
- ],
- )
- return db
-
-
-def empty_exception_message():
- return "database exception "
-
-
-def get_one_exception_message(db_filter):
- return "database exception Not found entry with filter='{}'".format(db_filter)
-
-
-def get_one_multiple_exception_message(db_filter):
- return "database exception Found more than one entry with filter='{}'".format(
- db_filter
- )
-
-
-def del_one_exception_message(db_filter):
- return "database exception Not found entry with filter='{}'".format(db_filter)
-
-
-def replace_exception_message(value):
- return "database exception Not found entry with _id='{}'".format(value)
-
-
-def test_constructor():
- db = DbMemory()
- assert db.logger == logging.getLogger("db")
- assert db.db == {}
-
-
-def test_constructor_with_logger():
- logger_name = "db_local"
- db = DbMemory(logger_name=logger_name)
- assert db.logger == logging.getLogger(logger_name)
- assert db.db == {}
-
-
-def test_db_connect():
- logger_name = "db_local"
- config = {"logger_name": logger_name}
- db = DbMemory()
- db.db_connect(config)
- assert db.logger == logging.getLogger(logger_name)
- assert db.db == {}
-
-
-def test_db_disconnect(db_memory):
- db_memory.db_disconnect()
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {}),
- ("test", {"_id": 1}),
- ("test", {"data": 1}),
- ("test", {"_id": 1, "data": 1}),
- ],
-)
-def test_get_list_with_empty_db(db_memory, table, db_filter):
- result = db_memory.get_list(table, db_filter)
- assert len(result) == 0
-
-
-@pytest.mark.parametrize(
- "table, db_filter, expected_data",
- [
- (
- "test",
- {},
- [{"_id": 1, "data": 1}, {"_id": 2, "data": 2}, {"_id": 3, "data": 3}],
- ),
- ("test", {"_id": 1}, [{"_id": 1, "data": 1}]),
- ("test", {"data": 1}, [{"_id": 1, "data": 1}]),
- ("test", {"_id": 1, "data": 1}, [{"_id": 1, "data": 1}]),
- ("test", {"_id": 2}, [{"_id": 2, "data": 2}]),
- ("test", {"data": 2}, [{"_id": 2, "data": 2}]),
- ("test", {"_id": 2, "data": 2}, [{"_id": 2, "data": 2}]),
- ("test", {"_id": 4}, []),
- ("test", {"data": 4}, []),
- ("test", {"_id": 4, "data": 4}, []),
- ("test_table", {}, []),
- ("test_table", {"_id": 1}, []),
- ("test_table", {"data": 1}, []),
- ("test_table", {"_id": 1, "data": 1}, []),
- ],
-)
-def test_get_list_with_non_empty_db(
- db_memory_with_data, table, db_filter, expected_data
-):
- result = db_memory_with_data.get_list(table, db_filter)
- assert len(result) == len(expected_data)
- for data in expected_data:
- assert data in result
-
-
-def test_get_list_exception(db_memory_with_data):
- table = "test"
- db_filter = {}
- db_memory_with_data._find = MagicMock(side_effect=Exception())
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.get_list(table, db_filter)
- assert str(excinfo.value) == empty_exception_message()
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, db_filter, expected_data",
- [
- ("test", {"_id": 1}, {"_id": 1, "data": 1}),
- ("test", {"_id": 2}, {"_id": 2, "data": 2}),
- ("test", {"_id": 3}, {"_id": 3, "data": 3}),
- ("test", {"data": 1}, {"_id": 1, "data": 1}),
- ("test", {"data": 2}, {"_id": 2, "data": 2}),
- ("test", {"data": 3}, {"_id": 3, "data": 3}),
- ("test", {"_id": 1, "data": 1}, {"_id": 1, "data": 1}),
- ("test", {"_id": 2, "data": 2}, {"_id": 2, "data": 2}),
- ("test", {"_id": 3, "data": 3}, {"_id": 3, "data": 3}),
- ],
-)
-def test_get_one(db_memory_with_data, table, db_filter, expected_data):
- result = db_memory_with_data.get_one(table, db_filter)
- assert result == expected_data
- assert len(db_memory_with_data.db) == 1
- assert table in db_memory_with_data.db
- assert len(db_memory_with_data.db[table]) == 3
- assert result in db_memory_with_data.db[table]
-
-
-@pytest.mark.parametrize(
- "db_filter, expected_ids",
- [
- ({}, [1, 2, 3, 4, 5, 6, 7, 8]),
- ({"_id": 1}, [1]),
- ({"data.data2.data3": 2}, [2]),
- ({"data.data2.data3.eq": 2}, [2]),
- ({"data.data2.data3": [2]}, [2]),
- ({"data.data2.data3.cont": [2]}, [2]),
- ({"data.data2.data3.neq": 2}, [1, 3, 4, 5, 6, 7, 8]),
- ({"data.data2.data3.neq": [2]}, [1, 3, 4, 5, 6, 7, 8]),
- ({"data.data2.data3.ncont": [2]}, [1, 3, 4, 5, 6, 7, 8]),
- ({"data.data2.data3": [2, 3]}, [2, 3]),
- ({"data.data2.data3.gt": 4}, [5, 6, 7]),
- ({"data.data2.data3.gte": 4}, [4, 5, 6, 7]),
- ({"data.data2.data3.lt": 4}, [1, 2, 3]),
- ({"data.data2.data3.lte": 4}, [1, 2, 3, 4]),
- ({"data.data2.data3.lte": 4.5}, [1, 2, 3, 4]),
- ({"data.data2.data3.gt": "text"}, []),
- ({"nonexist.nonexist": "4"}, []),
- ({"nonexist.nonexist": None}, [1, 2, 3, 4, 5, 6, 7, 8]),
- ({"nonexist.nonexist.neq": "4"}, [1, 2, 3, 4, 5, 6, 7, 8]),
- ({"nonexist.nonexist.neq": None}, []),
- ({"text.eq": "sometext"}, [1]),
- ({"text.neq": "sometext"}, [2, 3, 4, 5, 6, 7, 8]),
- ({"text.eq": "somet"}, []),
- ({"text.gte": "a"}, [1]),
- ({"text.gte": "somet"}, [1]),
- ({"text.gte": "sometext"}, [1]),
- ({"text.lt": "somet"}, []),
- ({"data.data2.data3": 2, "data.data2.data4": None}, [2]),
- ({"data.data2.data3": 2, "data.data2.data4": 5}, []),
- ({"data.data2.data3": 4}, [4]),
- ({"data.data2.data3": [3, 4, "e"]}, [3, 4]),
- ({"data.data2.data3": None}, [8]),
- ({"data.data2": "4"}, []),
- ({"list.0.a": 1}, [1, 6]),
- ({"list2": 1}, [2]),
- ({"list2": [1, 5]}, [2]),
- ({"list2": [1, 2]}, [2]),
- ({"list2": [5, 7]}, []),
- ({"list.ANYINDEX.a": 1}, [1]),
- ({"list.a": 3, "list.b": 1}, [8]),
- ({"list.ANYINDEX.a": 3, "list.ANYINDEX.b": 1}, []),
- ({"list.ANYINDEX.a": 3, "list.ANYINDEX.c.a": 3}, [8]),
- ({"list.ANYINDEX.a": 3, "list.ANYINDEX.b": 0}, [8]),
- (
- {
- "list.ANYINDEX.a": 3,
- "list.ANYINDEX.c.ANYINDEX.a": 0,
- "list.ANYINDEX.c.ANYINDEX.b": "v",
- },
- [8],
- ),
- (
- {
- "list.ANYINDEX.a": 3,
- "list.ANYINDEX.c.ANYINDEX.a": 0,
- "list.ANYINDEX.c.ANYINDEX.b": 1,
- },
- [],
- ),
- ({"list.c.b": 1}, [8]),
- ({"list.c.b": None}, [1, 2, 3, 4, 5, 6, 7]),
- # ({"data.data2.data3": 4}, []),
- # ({"data.data2.data3": 4}, []),
- ],
-)
-def test_get_list(db_memory_with_many_data, db_filter, expected_ids):
- result = db_memory_with_many_data.get_list("test", db_filter)
- assert isinstance(result, list)
- result_ids = [item["_id"] for item in result]
- assert len(result) == len(
- expected_ids
- ), "for db_filter={} result={} expected_ids={}".format(
- db_filter, result, result_ids
- )
- assert result_ids == expected_ids
- for i in range(len(result)):
- assert result[i] in db_memory_with_many_data.db["test"]
-
- assert len(db_memory_with_many_data.db) == 1
- assert "test" in db_memory_with_many_data.db
- assert len(db_memory_with_many_data.db["test"]) == 8
- result = db_memory_with_many_data.count("test", db_filter)
- assert result == len(expected_ids)
-
-
-@pytest.mark.parametrize(
- "table, db_filter, expected_data", [("test", {}, {"_id": 1, "data": 1})]
-)
-def test_get_one_with_multiple_results(
- db_memory_with_data, table, db_filter, expected_data
-):
- result = db_memory_with_data.get_one(table, db_filter, fail_on_more=False)
- assert result == expected_data
- assert len(db_memory_with_data.db) == 1
- assert table in db_memory_with_data.db
- assert len(db_memory_with_data.db[table]) == 3
- assert result in db_memory_with_data.db[table]
-
-
-def test_get_one_with_multiple_results_exception(db_memory_with_data):
- table = "test"
- db_filter = {}
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.get_one(table, db_filter)
- assert str(excinfo.value) == (
- empty_exception_message() + get_one_multiple_exception_message(db_filter)
- )
- # assert excinfo.value.http_code == http.HTTPStatus.CONFLICT
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {"_id": 4}),
- ("test", {"data": 4}),
- ("test", {"_id": 4, "data": 4}),
- ("test_table", {"_id": 4}),
- ("test_table", {"data": 4}),
- ("test_table", {"_id": 4, "data": 4}),
- ],
-)
-def test_get_one_with_non_empty_db_exception(db_memory_with_data, table, db_filter):
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.get_one(table, db_filter)
- assert str(excinfo.value) == (
- empty_exception_message() + get_one_exception_message(db_filter)
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {"_id": 4}),
- ("test", {"data": 4}),
- ("test", {"_id": 4, "data": 4}),
- ("test_table", {"_id": 4}),
- ("test_table", {"data": 4}),
- ("test_table", {"_id": 4, "data": 4}),
- ],
-)
-def test_get_one_with_non_empty_db_none(db_memory_with_data, table, db_filter):
- result = db_memory_with_data.get_one(table, db_filter, fail_on_empty=False)
- assert result is None
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {"_id": 4}),
- ("test", {"data": 4}),
- ("test", {"_id": 4, "data": 4}),
- ("test_table", {"_id": 4}),
- ("test_table", {"data": 4}),
- ("test_table", {"_id": 4, "data": 4}),
- ],
-)
-def test_get_one_with_empty_db_exception(db_memory, table, db_filter):
- with pytest.raises(DbException) as excinfo:
- db_memory.get_one(table, db_filter)
- assert str(excinfo.value) == (
- empty_exception_message() + get_one_exception_message(db_filter)
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {"_id": 4}),
- ("test", {"data": 4}),
- ("test", {"_id": 4, "data": 4}),
- ("test_table", {"_id": 4}),
- ("test_table", {"data": 4}),
- ("test_table", {"_id": 4, "data": 4}),
- ],
-)
-def test_get_one_with_empty_db_none(db_memory, table, db_filter):
- result = db_memory.get_one(table, db_filter, fail_on_empty=False)
- assert result is None
-
-
-def test_get_one_generic_exception(db_memory_with_data):
- table = "test"
- db_filter = {}
- db_memory_with_data._find = MagicMock(side_effect=Exception())
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.get_one(table, db_filter)
- assert str(excinfo.value) == empty_exception_message()
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, db_filter, expected_data",
- [
- ("test", {}, []),
- ("test", {"_id": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
- ("test", {"_id": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}]),
- ("test", {"_id": 1, "data": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
- ("test", {"_id": 2, "data": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}]),
- ],
-)
-def test_del_list_with_non_empty_db(
- db_memory_with_data, table, db_filter, expected_data
-):
- result = db_memory_with_data.del_list(table, db_filter)
- assert result["deleted"] == (3 - len(expected_data))
- assert len(db_memory_with_data.db) == 1
- assert table in db_memory_with_data.db
- assert len(db_memory_with_data.db[table]) == len(expected_data)
- for data in expected_data:
- assert data in db_memory_with_data.db[table]
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {}),
- ("test", {"_id": 1}),
- ("test", {"_id": 2}),
- ("test", {"data": 1}),
- ("test", {"data": 2}),
- ("test", {"_id": 1, "data": 1}),
- ("test", {"_id": 2, "data": 2}),
- ],
-)
-def test_del_list_with_empty_db(db_memory, table, db_filter):
- result = db_memory.del_list(table, db_filter)
- assert result["deleted"] == 0
-
-
-def test_del_list_generic_exception(db_memory_with_data):
- table = "test"
- db_filter = {}
- db_memory_with_data._find = MagicMock(side_effect=Exception())
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.del_list(table, db_filter)
- assert str(excinfo.value) == empty_exception_message()
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, db_filter, data",
- [
- ("test", {}, {"_id": 1, "data": 1}),
- ("test", {"_id": 1}, {"_id": 1, "data": 1}),
- ("test", {"data": 1}, {"_id": 1, "data": 1}),
- ("test", {"_id": 1, "data": 1}, {"_id": 1, "data": 1}),
- ("test", {"_id": 2}, {"_id": 2, "data": 2}),
- ("test", {"data": 2}, {"_id": 2, "data": 2}),
- ("test", {"_id": 2, "data": 2}, {"_id": 2, "data": 2}),
- ],
-)
-def test_del_one(db_memory_with_data, table, db_filter, data):
- result = db_memory_with_data.del_one(table, db_filter)
- assert result == {"deleted": 1}
- assert len(db_memory_with_data.db) == 1
- assert table in db_memory_with_data.db
- assert len(db_memory_with_data.db[table]) == 2
- assert data not in db_memory_with_data.db[table]
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {}),
- ("test", {"_id": 1}),
- ("test", {"_id": 2}),
- ("test", {"data": 1}),
- ("test", {"data": 2}),
- ("test", {"_id": 1, "data": 1}),
- ("test", {"_id": 2, "data": 2}),
- ("test_table", {}),
- ("test_table", {"_id": 1}),
- ("test_table", {"_id": 2}),
- ("test_table", {"data": 1}),
- ("test_table", {"data": 2}),
- ("test_table", {"_id": 1, "data": 1}),
- ("test_table", {"_id": 2, "data": 2}),
- ],
-)
-def test_del_one_with_empty_db_exception(db_memory, table, db_filter):
- with pytest.raises(DbException) as excinfo:
- db_memory.del_one(table, db_filter)
- assert str(excinfo.value) == (
- empty_exception_message() + del_one_exception_message(db_filter)
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {}),
- ("test", {"_id": 1}),
- ("test", {"_id": 2}),
- ("test", {"data": 1}),
- ("test", {"data": 2}),
- ("test", {"_id": 1, "data": 1}),
- ("test", {"_id": 2, "data": 2}),
- ("test_table", {}),
- ("test_table", {"_id": 1}),
- ("test_table", {"_id": 2}),
- ("test_table", {"data": 1}),
- ("test_table", {"data": 2}),
- ("test_table", {"_id": 1, "data": 1}),
- ("test_table", {"_id": 2, "data": 2}),
- ],
-)
-def test_del_one_with_empty_db_none(db_memory, table, db_filter):
- result = db_memory.del_one(table, db_filter, fail_on_empty=False)
- assert result is None
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {"_id": 4}),
- ("test", {"_id": 5}),
- ("test", {"data": 4}),
- ("test", {"data": 5}),
- ("test", {"_id": 1, "data": 2}),
- ("test", {"_id": 2, "data": 3}),
- ("test_table", {}),
- ("test_table", {"_id": 1}),
- ("test_table", {"_id": 2}),
- ("test_table", {"data": 1}),
- ("test_table", {"data": 2}),
- ("test_table", {"_id": 1, "data": 1}),
- ("test_table", {"_id": 2, "data": 2}),
- ],
-)
-def test_del_one_with_non_empty_db_exception(db_memory_with_data, table, db_filter):
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.del_one(table, db_filter)
- assert str(excinfo.value) == (
- empty_exception_message() + del_one_exception_message(db_filter)
- )
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, db_filter",
- [
- ("test", {"_id": 4}),
- ("test", {"_id": 5}),
- ("test", {"data": 4}),
- ("test", {"data": 5}),
- ("test", {"_id": 1, "data": 2}),
- ("test", {"_id": 2, "data": 3}),
- ("test_table", {}),
- ("test_table", {"_id": 1}),
- ("test_table", {"_id": 2}),
- ("test_table", {"data": 1}),
- ("test_table", {"data": 2}),
- ("test_table", {"_id": 1, "data": 1}),
- ("test_table", {"_id": 2, "data": 2}),
- ],
-)
-def test_del_one_with_non_empty_db_none(db_memory_with_data, table, db_filter):
- result = db_memory_with_data.del_one(table, db_filter, fail_on_empty=False)
- assert result is None
-
-
-@pytest.mark.parametrize("fail_on_empty", [(True), (False)])
-def test_del_one_generic_exception(db_memory_with_data, fail_on_empty):
- table = "test"
- db_filter = {}
- db_memory_with_data._find = MagicMock(side_effect=Exception())
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.del_one(table, db_filter, fail_on_empty=fail_on_empty)
- assert str(excinfo.value) == empty_exception_message()
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, _id, indata",
- [
- ("test", 1, {"_id": 1, "data": 42}),
- ("test", 1, {"_id": 1, "data": 42, "kk": 34}),
- ("test", 1, {"_id": 1}),
- ("test", 2, {"_id": 2, "data": 42}),
- ("test", 2, {"_id": 2, "data": 42, "kk": 34}),
- ("test", 2, {"_id": 2}),
- ("test", 3, {"_id": 3, "data": 42}),
- ("test", 3, {"_id": 3, "data": 42, "kk": 34}),
- ("test", 3, {"_id": 3}),
- ],
-)
-def test_replace(db_memory_with_data, table, _id, indata):
- result = db_memory_with_data.replace(table, _id, indata)
- assert result == {"updated": 1}
- assert len(db_memory_with_data.db) == 1
- assert table in db_memory_with_data.db
- assert len(db_memory_with_data.db[table]) == 3
- assert indata in db_memory_with_data.db[table]
-
-
-@pytest.mark.parametrize(
- "table, _id, indata",
- [
- ("test", 1, {"_id": 1, "data": 42}),
- ("test", 2, {"_id": 2}),
- ("test", 3, {"_id": 3}),
- ],
-)
-def test_replace_without_data_exception(db_memory, table, _id, indata):
- with pytest.raises(DbException) as excinfo:
- db_memory.replace(table, _id, indata, fail_on_empty=True)
- assert str(excinfo.value) == (replace_exception_message(_id))
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, _id, indata",
- [
- ("test", 1, {"_id": 1, "data": 42}),
- ("test", 2, {"_id": 2}),
- ("test", 3, {"_id": 3}),
- ],
-)
-def test_replace_without_data_none(db_memory, table, _id, indata):
- result = db_memory.replace(table, _id, indata, fail_on_empty=False)
- assert result is None
-
-
-@pytest.mark.parametrize(
- "table, _id, indata",
- [
- ("test", 11, {"_id": 11, "data": 42}),
- ("test", 12, {"_id": 12}),
- ("test", 33, {"_id": 33}),
- ],
-)
-def test_replace_with_data_exception(db_memory_with_data, table, _id, indata):
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.replace(table, _id, indata, fail_on_empty=True)
- assert str(excinfo.value) == (replace_exception_message(_id))
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, _id, indata",
- [
- ("test", 11, {"_id": 11, "data": 42}),
- ("test", 12, {"_id": 12}),
- ("test", 33, {"_id": 33}),
- ],
-)
-def test_replace_with_data_none(db_memory_with_data, table, _id, indata):
- result = db_memory_with_data.replace(table, _id, indata, fail_on_empty=False)
- assert result is None
-
-
-@pytest.mark.parametrize("fail_on_empty", [True, False])
-def test_replace_generic_exception(db_memory_with_data, fail_on_empty):
- table = "test"
- _id = {}
- indata = {"_id": 1, "data": 1}
- db_memory_with_data._find = MagicMock(side_effect=Exception())
- with pytest.raises(DbException) as excinfo:
- db_memory_with_data.replace(table, _id, indata, fail_on_empty=fail_on_empty)
- assert str(excinfo.value) == empty_exception_message()
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "table, id, data",
- [
- ("test", "1", {"data": 1}),
- ("test", "1", {"data": 2}),
- ("test", "2", {"data": 1}),
- ("test", "2", {"data": 2}),
- ("test_table", "1", {"data": 1}),
- ("test_table", "1", {"data": 2}),
- ("test_table", "2", {"data": 1}),
- ("test_table", "2", {"data": 2}),
- ("test", "1", {"data_1": 1, "data_2": 2}),
- ("test", "1", {"data_1": 2, "data_2": 1}),
- ("test", "2", {"data_1": 1, "data_2": 2}),
- ("test", "2", {"data_1": 2, "data_2": 1}),
- ("test_table", "1", {"data_1": 1, "data_2": 2}),
- ("test_table", "1", {"data_1": 2, "data_2": 1}),
- ("test_table", "2", {"data_1": 1, "data_2": 2}),
- ("test_table", "2", {"data_1": 2, "data_2": 1}),
- ],
-)
-def test_create_with_empty_db_with_id(db_memory, table, id, data):
- data_to_insert = data
- data_to_insert["_id"] = id
- returned_id = db_memory.create(table, data_to_insert)
- assert returned_id == id
- assert len(db_memory.db) == 1
- assert table in db_memory.db
- assert len(db_memory.db[table]) == 1
- assert data_to_insert in db_memory.db[table]
-
-
-@pytest.mark.parametrize(
- "table, id, data",
- [
- ("test", "4", {"data": 1}),
- ("test", "5", {"data": 2}),
- ("test", "4", {"data": 1}),
- ("test", "5", {"data": 2}),
- ("test_table", "4", {"data": 1}),
- ("test_table", "5", {"data": 2}),
- ("test_table", "4", {"data": 1}),
- ("test_table", "5", {"data": 2}),
- ("test", "4", {"data_1": 1, "data_2": 2}),
- ("test", "5", {"data_1": 2, "data_2": 1}),
- ("test", "4", {"data_1": 1, "data_2": 2}),
- ("test", "5", {"data_1": 2, "data_2": 1}),
- ("test_table", "4", {"data_1": 1, "data_2": 2}),
- ("test_table", "5", {"data_1": 2, "data_2": 1}),
- ("test_table", "4", {"data_1": 1, "data_2": 2}),
- ("test_table", "5", {"data_1": 2, "data_2": 1}),
- ],
-)
-def test_create_with_non_empty_db_with_id(db_memory_with_data, table, id, data):
- data_to_insert = data
- data_to_insert["_id"] = id
- returned_id = db_memory_with_data.create(table, data_to_insert)
- assert returned_id == id
- assert len(db_memory_with_data.db) == (1 if table == "test" else 2)
- assert table in db_memory_with_data.db
- assert len(db_memory_with_data.db[table]) == (4 if table == "test" else 1)
- assert data_to_insert in db_memory_with_data.db[table]
-
-
-@pytest.mark.parametrize(
- "table, data",
- [
- ("test", {"data": 1}),
- ("test", {"data": 2}),
- ("test", {"data": 1}),
- ("test", {"data": 2}),
- ("test_table", {"data": 1}),
- ("test_table", {"data": 2}),
- ("test_table", {"data": 1}),
- ("test_table", {"data": 2}),
- ("test", {"data_1": 1, "data_2": 2}),
- ("test", {"data_1": 2, "data_2": 1}),
- ("test", {"data_1": 1, "data_2": 2}),
- ("test", {"data_1": 2, "data_2": 1}),
- ("test_table", {"data_1": 1, "data_2": 2}),
- ("test_table", {"data_1": 2, "data_2": 1}),
- ("test_table", {"data_1": 1, "data_2": 2}),
- ("test_table", {"data_1": 2, "data_2": 1}),
- ],
-)
-def test_create_with_empty_db_without_id(db_memory, table, data):
- returned_id = db_memory.create(table, data)
- assert len(db_memory.db) == 1
- assert table in db_memory.db
- assert len(db_memory.db[table]) == 1
- data_inserted = data
- data_inserted["_id"] = returned_id
- assert data_inserted in db_memory.db[table]
-
-
-@pytest.mark.parametrize(
- "table, data",
- [
- ("test", {"data": 1}),
- ("test", {"data": 2}),
- ("test", {"data": 1}),
- ("test", {"data": 2}),
- ("test_table", {"data": 1}),
- ("test_table", {"data": 2}),
- ("test_table", {"data": 1}),
- ("test_table", {"data": 2}),
- ("test", {"data_1": 1, "data_2": 2}),
- ("test", {"data_1": 2, "data_2": 1}),
- ("test", {"data_1": 1, "data_2": 2}),
- ("test", {"data_1": 2, "data_2": 1}),
- ("test_table", {"data_1": 1, "data_2": 2}),
- ("test_table", {"data_1": 2, "data_2": 1}),
- ("test_table", {"data_1": 1, "data_2": 2}),
- ("test_table", {"data_1": 2, "data_2": 1}),
- ],
-)
-def test_create_with_non_empty_db_without_id(db_memory_with_data, table, data):
- returned_id = db_memory_with_data.create(table, data)
- assert len(db_memory_with_data.db) == (1 if table == "test" else 2)
- assert table in db_memory_with_data.db
- assert len(db_memory_with_data.db[table]) == (4 if table == "test" else 1)
- data_inserted = data
- data_inserted["_id"] = returned_id
- assert data_inserted in db_memory_with_data.db[table]
-
-
-def test_create_with_exception(db_memory):
- table = "test"
- data = {"_id": 1, "data": 1}
- db_memory.db = MagicMock()
- db_memory.db.__contains__.side_effect = Exception()
- with pytest.raises(DbException) as excinfo:
- db_memory.create(table, data)
- assert str(excinfo.value) == empty_exception_message()
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "db_content, update_dict, expected, message",
- [
- (
- {"a": {"none": None}},
- {"a.b.num": "v"},
- {"a": {"none": None, "b": {"num": "v"}}},
- "create dict",
- ),
- (
- {"a": {"none": None}},
- {"a.none.num": "v"},
- {"a": {"none": {"num": "v"}}},
- "create dict over none",
- ),
- (
- {"a": {"b": {"num": 4}}},
- {"a.b.num": "v"},
- {"a": {"b": {"num": "v"}}},
- "replace_number",
- ),
- (
- {"a": {"b": {"num": 4}}},
- {"a.b.num.c.d": "v"},
- None,
- "create dict over number should fail",
- ),
- (
- {"a": {"b": {"num": 4}}},
- {"a.b": "v"},
- {"a": {"b": "v"}},
- "replace dict with a string",
- ),
- (
- {"a": {"b": {"num": 4}}},
- {"a.b": None},
- {"a": {"b": None}},
- "replace dict with None",
- ),
- (
- {"a": [{"b": {"num": 4}}]},
- {"a.b.num": "v"},
- None,
- "create dict over list should fail",
- ),
- (
- {"a": [{"b": {"num": 4}}]},
- {"a.0.b.num": "v"},
- {"a": [{"b": {"num": "v"}}]},
- "set list",
- ),
- (
- {"a": [{"b": {"num": 4}}]},
- {"a.3.b.num": "v"},
- {"a": [{"b": {"num": 4}}, None, None, {"b": {"num": "v"}}]},
- "expand list",
- ),
- ({"a": [[4]]}, {"a.0.0": "v"}, {"a": [["v"]]}, "set nested list"),
- ({"a": [[4]]}, {"a.0.2": "v"}, {"a": [[4, None, "v"]]}, "expand nested list"),
- (
- {"a": [[4]]},
- {"a.2.2": "v"},
- {"a": [[4], None, {"2": "v"}]},
- "expand list and add number key",
- ),
- ],
-)
-def test_set_one(db_memory, db_content, update_dict, expected, message):
- db_memory._find = Mock(return_value=((0, db_content),))
- if expected is None:
- with pytest.raises(DbException) as excinfo:
- db_memory.set_one("table", {}, update_dict)
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND, message
- else:
- db_memory.set_one("table", {}, update_dict)
- assert db_content == expected, message
-
-
-class TestDbMemory(unittest.TestCase):
- # TODO to delete. This is cover with pytest test_set_one.
- def test_set_one(self):
- test_set = (
- # (database content, set-content, expected database content (None=fails), message)
- (
- {"a": {"none": None}},
- {"a.b.num": "v"},
- {"a": {"none": None, "b": {"num": "v"}}},
- "create dict",
- ),
- (
- {"a": {"none": None}},
- {"a.none.num": "v"},
- {"a": {"none": {"num": "v"}}},
- "create dict over none",
- ),
- (
- {"a": {"b": {"num": 4}}},
- {"a.b.num": "v"},
- {"a": {"b": {"num": "v"}}},
- "replace_number",
- ),
- (
- {"a": {"b": {"num": 4}}},
- {"a.b.num.c.d": "v"},
- None,
- "create dict over number should fail",
- ),
- (
- {"a": {"b": {"num": 4}}},
- {"a.b": "v"},
- {"a": {"b": "v"}},
- "replace dict with a string",
- ),
- (
- {"a": {"b": {"num": 4}}},
- {"a.b": None},
- {"a": {"b": None}},
- "replace dict with None",
- ),
- (
- {"a": [{"b": {"num": 4}}]},
- {"a.b.num": "v"},
- None,
- "create dict over list should fail",
- ),
- (
- {"a": [{"b": {"num": 4}}]},
- {"a.0.b.num": "v"},
- {"a": [{"b": {"num": "v"}}]},
- "set list",
- ),
- (
- {"a": [{"b": {"num": 4}}]},
- {"a.3.b.num": "v"},
- {"a": [{"b": {"num": 4}}, None, None, {"b": {"num": "v"}}]},
- "expand list",
- ),
- ({"a": [[4]]}, {"a.0.0": "v"}, {"a": [["v"]]}, "set nested list"),
- (
- {"a": [[4]]},
- {"a.0.2": "v"},
- {"a": [[4, None, "v"]]},
- "expand nested list",
- ),
- (
- {"a": [[4]]},
- {"a.2.2": "v"},
- {"a": [[4], None, {"2": "v"}]},
- "expand list and add number key",
- ),
- ({"a": None}, {"b.c": "v"}, {"a": None, "b": {"c": "v"}}, "expand at root"),
- )
- db_men = DbMemory()
- db_men._find = Mock()
- for db_content, update_dict, expected, message in test_set:
- db_men._find.return_value = ((0, db_content),)
- if expected is None:
- self.assertRaises(DbException, db_men.set_one, "table", {}, update_dict)
- else:
- db_men.set_one("table", {}, update_dict)
- self.assertEqual(db_content, expected, message)
-
- def test_set_one_pull(self):
- example = {"a": [1, "1", 1], "d": {}, "n": None}
- test_set = (
- # (database content, set-content, expected database content (None=fails), message)
- (example, {"a": "1"}, {"a": [1, 1], "d": {}, "n": None}, "pull one item"),
- (example, {"a": 1}, {"a": ["1"], "d": {}, "n": None}, "pull two items"),
- (example, {"a": "v"}, example, "pull non existing item"),
- (example, {"a.6": 1}, example, "pull non existing arrray"),
- (example, {"d.b.c": 1}, example, "pull non existing arrray2"),
- (example, {"b": 1}, example, "pull non existing arrray3"),
- (example, {"d": 1}, None, "pull over dict"),
- (example, {"n": 1}, None, "pull over None"),
- )
- db_men = DbMemory()
- db_men._find = Mock()
- for db_content, pull_dict, expected, message in test_set:
- db_content = deepcopy(db_content)
- db_men._find.return_value = ((0, db_content),)
- if expected is None:
- self.assertRaises(
- DbException,
- db_men.set_one,
- "table",
- {},
- None,
- fail_on_empty=False,
- pull=pull_dict,
- )
- else:
- db_men.set_one("table", {}, None, pull=pull_dict)
- self.assertEqual(db_content, expected, message)
-
- def test_set_one_push(self):
- example = {"a": [1, "1", 1], "d": {}, "n": None}
- test_set = (
- # (database content, set-content, expected database content (None=fails), message)
- (
- example,
- {"d.b.c": 1},
- {"a": [1, "1", 1], "d": {"b": {"c": [1]}}, "n": None},
- "push non existing arrray2",
- ),
- (
- example,
- {"b": 1},
- {"a": [1, "1", 1], "d": {}, "b": [1], "n": None},
- "push non existing arrray3",
- ),
- (
- example,
- {"a.6": 1},
- {"a": [1, "1", 1, None, None, None, [1]], "d": {}, "n": None},
- "push non existing arrray",
- ),
- (
- example,
- {"a": 2},
- {"a": [1, "1", 1, 2], "d": {}, "n": None},
- "push one item",
- ),
- (
- example,
- {"a": {1: 1}},
- {"a": [1, "1", 1, {1: 1}], "d": {}, "n": None},
- "push a dict",
- ),
- (example, {"d": 1}, None, "push over dict"),
- (example, {"n": 1}, None, "push over None"),
- )
- db_men = DbMemory()
- db_men._find = Mock()
- for db_content, push_dict, expected, message in test_set:
- db_content = deepcopy(db_content)
- db_men._find.return_value = ((0, db_content),)
- if expected is None:
- self.assertRaises(
- DbException,
- db_men.set_one,
- "table",
- {},
- None,
- fail_on_empty=False,
- push=push_dict,
- )
- else:
- db_men.set_one("table", {}, None, push=push_dict)
- self.assertEqual(db_content, expected, message)
-
- def test_set_one_push_list(self):
- example = {"a": [1, "1", 1], "d": {}, "n": None}
- test_set = (
- # (database content, set-content, expected database content (None=fails), message)
- (
- example,
- {"d.b.c": [1]},
- {"a": [1, "1", 1], "d": {"b": {"c": [1]}}, "n": None},
- "push non existing arrray2",
- ),
- (
- example,
- {"b": [1]},
- {"a": [1, "1", 1], "d": {}, "b": [1], "n": None},
- "push non existing arrray3",
- ),
- (
- example,
- {"a.6": [1]},
- {"a": [1, "1", 1, None, None, None, [1]], "d": {}, "n": None},
- "push non existing arrray",
- ),
- (
- example,
- {"a": [2, 3]},
- {"a": [1, "1", 1, 2, 3], "d": {}, "n": None},
- "push two item",
- ),
- (
- example,
- {"a": [{1: 1}]},
- {"a": [1, "1", 1, {1: 1}], "d": {}, "n": None},
- "push a dict",
- ),
- (example, {"d": [1]}, None, "push over dict"),
- (example, {"n": [1]}, None, "push over None"),
- (example, {"a": 1}, None, "invalid push list non an array"),
- )
- db_men = DbMemory()
- db_men._find = Mock()
- for db_content, push_list, expected, message in test_set:
- db_content = deepcopy(db_content)
- db_men._find.return_value = ((0, db_content),)
- if expected is None:
- self.assertRaises(
- DbException,
- db_men.set_one,
- "table",
- {},
- None,
- fail_on_empty=False,
- push_list=push_list,
- )
- else:
- db_men.set_one("table", {}, None, push_list=push_list)
- self.assertEqual(db_content, expected, message)
-
- def test_unset_one(self):
- example = {"a": [1, "1", 1], "d": {}, "n": None}
- test_set = (
- # (database content, set-content, expected database content (None=fails), message)
- (example, {"d.b.c": 1}, example, "unset non existing"),
- (example, {"b": 1}, example, "unset non existing"),
- (example, {"a.6": 1}, example, "unset non existing arrray"),
- (example, {"a": 2}, {"d": {}, "n": None}, "unset array"),
- (example, {"d": 1}, {"a": [1, "1", 1], "n": None}, "unset dict"),
- (example, {"n": 1}, {"a": [1, "1", 1], "d": {}}, "unset None"),
- )
- db_men = DbMemory()
- db_men._find = Mock()
- for db_content, unset_dict, expected, message in test_set:
- db_content = deepcopy(db_content)
- db_men._find.return_value = ((0, db_content),)
- if expected is None:
- self.assertRaises(
- DbException,
- db_men.set_one,
- "table",
- {},
- None,
- fail_on_empty=False,
- unset=unset_dict,
- )
- else:
- db_men.set_one("table", {}, None, unset=unset_dict)
- self.assertEqual(db_content, expected, message)
+++ /dev/null
-#######################################################################################
-# Copyright ETSI Contributors and Others.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#######################################################################################
-
-import logging
-from urllib.parse import quote
-
-from osm_common.dbbase import DbException, FakeLock
-from osm_common.dbmongo import DbMongo
-from pymongo import MongoClient
-import pytest
-
-
-def db_status_exception_message():
- return "database exception Wrong database status"
-
-
-def db_version_exception_message():
- return "database exception Invalid database version"
-
-
-def mock_get_one_status_not_enabled(a, b, c, fail_on_empty=False, fail_on_more=True):
- return {"status": "ERROR", "version": "", "serial": ""}
-
-
-def mock_get_one_wrong_db_version(a, b, c, fail_on_empty=False, fail_on_more=True):
- return {"status": "ENABLED", "version": "4.0", "serial": "MDY4OA=="}
-
-
-def db_generic_exception(exception):
- return exception
-
-
-def db_generic_exception_message(message):
- return f"database exception {message}"
-
-
-def test_constructor():
- db = DbMongo(lock=True)
- assert db.logger == logging.getLogger("db")
- assert db.db is None
- assert db.client is None
- assert db.database_key is None
- assert db.secret_obtained is False
- assert db.lock.acquire() is True
-
-
-def test_constructor_with_logger():
- logger_name = "db_mongo"
- db = DbMongo(logger_name=logger_name, lock=False)
- assert db.logger == logging.getLogger(logger_name)
- assert db.db is None
- assert db.client is None
- assert db.database_key is None
- assert db.secret_obtained is False
- assert type(db.lock) == FakeLock
-
-
-@pytest.mark.parametrize(
- "config, target_version, serial, lock",
- [
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "mongo:27017",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "5.0",
- "MDY=",
- True,
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "masterpassword": "master",
- "uri": "mongo:27017",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "5.0",
- "MDY=",
- False,
- ),
- (
- {
- "logger_name": "logger",
- "uri": "mongo:27017",
- "name": "newdb",
- "commonkey": "common",
- },
- "3.6",
- "",
- True,
- ),
- (
- {
- "uri": "mongo:27017",
- "commonkey": "common",
- "name": "newdb",
- },
- "5.0",
- "MDIy",
- False,
- ),
- (
- {
- "uri": "mongo:27017",
- "masterpassword": "common",
- "name": "newdb",
- "loglevel": "CRITICAL",
- },
- "4.4",
- "OTA=",
- False,
- ),
- (
- {
- "uri": "mongo",
- "masterpassword": "common",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "4.4",
- "OTA=",
- True,
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": quote("user4:password4@mongo"),
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "5.0",
- "NTM=",
- True,
- ),
- (
- {
- "logger_name": "logger",
- "uri": quote("user3:password3@mongo:27017"),
- "name": "newdb",
- "commonkey": "common",
- },
- "4.0",
- "NjEx",
- False,
- ),
- (
- {
- "uri": quote("user2:password2@mongo:27017"),
- "commonkey": "common",
- "name": "newdb",
- },
- "5.0",
- "cmV0MzI=",
- False,
- ),
- (
- {
- "uri": quote("user1:password1@mongo:27017"),
- "commonkey": "common",
- "masterpassword": "master",
- "name": "newdb",
- "loglevel": "CRITICAL",
- },
- "4.0",
- "MjMyNQ==",
- False,
- ),
- (
- {
- "uri": quote("user1:password1@mongo"),
- "masterpassword": "common",
- "name": "newdb",
- "loglevel": "CRITICAL",
- },
- "4.0",
- "MjMyNQ==",
- True,
- ),
- ],
-)
-def test_db_connection_with_valid_config(
- config, target_version, serial, lock, monkeypatch
-):
- def mock_get_one(a, b, c, fail_on_empty=False, fail_on_more=True):
- return {"status": "ENABLED", "version": target_version, "serial": serial}
-
- monkeypatch.setattr(DbMongo, "get_one", mock_get_one)
- db = DbMongo(lock=lock)
- db.db_connect(config, target_version)
- assert (
- db.logger == logging.getLogger(config.get("logger_name"))
- if config.get("logger_name")
- else logging.getLogger("db")
- )
- assert type(db.client) == MongoClient
- assert db.database_key == "common"
- assert db.logger.getEffectiveLevel() == 50 if config.get("loglevel") else 20
-
-
-@pytest.mark.parametrize(
- "config, target_version, version_data, expected_exception_message",
- [
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "mongo:27017",
- "replicaset": "rs2",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "4.0",
- mock_get_one_status_not_enabled,
- db_status_exception_message(),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "mongo:27017",
- "replicaset": "rs4",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "5.0",
- mock_get_one_wrong_db_version,
- db_version_exception_message(),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": quote("user2:pa@word2@mongo:27017"),
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "DEBUG",
- },
- "4.0",
- mock_get_one_status_not_enabled,
- db_status_exception_message(),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": quote("username:pass1rd@mongo:27017"),
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "DEBUG",
- },
- "5.0",
- mock_get_one_wrong_db_version,
- db_version_exception_message(),
- ),
- ],
-)
-def test_db_connection_db_status_error(
- config, target_version, version_data, expected_exception_message, monkeypatch
-):
- monkeypatch.setattr(DbMongo, "get_one", version_data)
- db = DbMongo(lock=False)
- with pytest.raises(DbException) as exception_info:
- db.db_connect(config, target_version)
- assert str(exception_info.value).startswith(expected_exception_message)
-
-
-@pytest.mark.parametrize(
- "config, target_version, lock, expected_exception",
- [
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "27017@/:",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "4.0",
- True,
- db_generic_exception(DbException),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "user@pass",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "4.0",
- False,
- db_generic_exception(DbException),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "user@pass:27017",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "4.0",
- True,
- db_generic_exception(DbException),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "5.0",
- False,
- db_generic_exception(TypeError),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "user2::@mon:27017",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "DEBUG",
- },
- "4.0",
- True,
- db_generic_exception(ValueError),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "replicaset": 33,
- "uri": "user2@@mongo:27017",
- "name": "osmdb",
- "loglevel": "DEBUG",
- },
- "5.0",
- False,
- db_generic_exception(TypeError),
- ),
- ],
-)
-def test_db_connection_with_invalid_uri(
- config, target_version, lock, expected_exception, monkeypatch
-):
- def mock_get_one(a, b, c, fail_on_empty=False, fail_on_more=True):
- pass
-
- monkeypatch.setattr(DbMongo, "get_one", mock_get_one)
- db = DbMongo(lock=lock)
- with pytest.raises(expected_exception) as exception_info:
- db.db_connect(config, target_version)
- assert type(exception_info.value) == expected_exception
-
-
-@pytest.mark.parametrize(
- "config, target_version, expected_exception",
- [
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "",
- db_generic_exception(TypeError),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "uri": "mongo:27017",
- "replicaset": "rs0",
- "loglevel": "CRITICAL",
- },
- "4.0",
- db_generic_exception(KeyError),
- ),
- (
- {
- "replicaset": "rs0",
- "loglevel": "CRITICAL",
- },
- None,
- db_generic_exception(KeyError),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "",
- "replicaset": "rs0",
- "name": "osmdb",
- "loglevel": "CRITICAL",
- },
- "5.0",
- db_generic_exception(TypeError),
- ),
- (
- {
- "logger_name": "mongo_logger",
- "name": "osmdb",
- },
- "4.0",
- db_generic_exception(TypeError),
- ),
- (
- {
- "logger_name": "logger",
- "replicaset": "",
- "uri": "user2@@mongo:27017",
- },
- "5.0",
- db_generic_exception(KeyError),
- ),
- ],
-)
-def test_db_connection_with_missing_parameters(
- config, target_version, expected_exception, monkeypatch
-):
- def mock_get_one(a, b, c, fail_on_empty=False, fail_on_more=True):
- return
-
- monkeypatch.setattr(DbMongo, "get_one", mock_get_one)
- db = DbMongo(lock=False)
- with pytest.raises(expected_exception) as exception_info:
- db.db_connect(config, target_version)
- assert type(exception_info.value) == expected_exception
-
-
-@pytest.mark.parametrize(
- "config, expected_exception_message",
- [
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "mongo:27017",
- "replicaset": "rs0",
- "name": "osmdb1",
- "loglevel": "CRITICAL",
- },
- "MongoClient crashed",
- ),
- (
- {
- "logger_name": "mongo_logger",
- "commonkey": "common",
- "uri": "username:pas1ed@mongo:27017",
- "replicaset": "rs1",
- "name": "osmdb2",
- "loglevel": "DEBUG",
- },
- "MongoClient crashed",
- ),
- ],
-)
-def test_db_connection_with_invalid_mongoclient(
- config, expected_exception_message, monkeypatch
-):
- def generate_exception(a, b, replicaSet=None):
- raise DbException(expected_exception_message)
-
- monkeypatch.setattr(MongoClient, "__init__", generate_exception)
- db = DbMongo()
- with pytest.raises(DbException) as exception_info:
- db.db_connect(config)
- assert str(exception_info.value) == db_generic_exception_message(
- expected_exception_message
- )
+++ /dev/null
-# Copyright 2018 Whitestack, LLC
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
-##
-
-import http
-
-from osm_common.fsbase import FsBase, FsException
-import pytest
-
-
-def exception_message(message):
- return "storage exception " + message
-
-
-@pytest.fixture
-def fs_base():
- return FsBase()
-
-
-def test_constructor():
- fs_base = FsBase()
- assert fs_base is not None
- assert isinstance(fs_base, FsBase)
-
-
-def test_get_params(fs_base):
- params = fs_base.get_params()
- assert isinstance(params, dict)
- assert len(params) == 0
-
-
-def test_fs_connect(fs_base):
- fs_base.fs_connect(None)
-
-
-def test_fs_disconnect(fs_base):
- fs_base.fs_disconnect()
-
-
-def test_mkdir(fs_base):
- with pytest.raises(FsException) as excinfo:
- fs_base.mkdir(None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'mkdir' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_file_exists(fs_base):
- with pytest.raises(FsException) as excinfo:
- fs_base.file_exists(None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'file_exists' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_file_size(fs_base):
- with pytest.raises(FsException) as excinfo:
- fs_base.file_size(None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'file_size' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_file_extract(fs_base):
- with pytest.raises(FsException) as excinfo:
- fs_base.file_extract(None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'file_extract' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_file_open(fs_base):
- with pytest.raises(FsException) as excinfo:
- fs_base.file_open(None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'file_open' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_file_delete(fs_base):
- with pytest.raises(FsException) as excinfo:
- fs_base.file_delete(None, None)
- assert str(excinfo.value).startswith(
- exception_message("Method 'file_delete' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+++ /dev/null
-# Copyright 2018 Whitestack, LLC
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
-##
-
-import http
-import io
-import logging
-import os
-import shutil
-import tarfile
-import tempfile
-import uuid
-
-from osm_common.fsbase import FsException
-from osm_common.fslocal import FsLocal
-import pytest
-
-__author__ = "Eduardo Sousa <eduardosousa@av.it.pt>"
-
-
-def valid_path():
- return tempfile.gettempdir() + "/"
-
-
-def invalid_path():
- return "/#tweeter/"
-
-
-@pytest.fixture(scope="function", params=[True, False])
-def fs_local(request):
- fs = FsLocal(lock=request.param)
- fs.fs_connect({"path": valid_path()})
- return fs
-
-
-def fs_connect_exception_message(path):
- return "storage exception Invalid configuration param at '[storage]': path '{}' does not exist".format(
- path
- )
-
-
-def file_open_file_not_found_exception(storage):
- f = storage if isinstance(storage, str) else "/".join(storage)
- return "storage exception File {} does not exist".format(f)
-
-
-def file_open_io_exception(storage):
- f = storage if isinstance(storage, str) else "/".join(storage)
- return "storage exception File {} cannot be opened".format(f)
-
-
-def dir_ls_not_a_directory_exception(storage):
- f = storage if isinstance(storage, str) else "/".join(storage)
- return "storage exception File {} does not exist".format(f)
-
-
-def dir_ls_io_exception(storage):
- f = storage if isinstance(storage, str) else "/".join(storage)
- return "storage exception File {} cannot be opened".format(f)
-
-
-def file_delete_exception_message(storage):
- return "storage exception File {} does not exist".format(storage)
-
-
-def test_constructor_without_logger():
- fs = FsLocal()
- assert fs.logger == logging.getLogger("fs")
- assert fs.path is None
-
-
-def test_constructor_with_logger():
- logger_name = "fs_local"
- fs = FsLocal(logger_name=logger_name)
- assert fs.logger == logging.getLogger(logger_name)
- assert fs.path is None
-
-
-def test_get_params(fs_local):
- params = fs_local.get_params()
- assert len(params) == 2
- assert "fs" in params
- assert "path" in params
- assert params["fs"] == "local"
- assert params["path"] == valid_path()
-
-
-@pytest.mark.parametrize(
- "config, exp_logger, exp_path",
- [
- ({"logger_name": "fs_local", "path": valid_path()}, "fs_local", valid_path()),
- (
- {"logger_name": "fs_local", "path": valid_path()[:-1]},
- "fs_local",
- valid_path(),
- ),
- ({"path": valid_path()}, "fs", valid_path()),
- ({"path": valid_path()[:-1]}, "fs", valid_path()),
- ],
-)
-def test_fs_connect_with_valid_config(config, exp_logger, exp_path):
- fs = FsLocal()
- fs.fs_connect(config)
- assert fs.logger == logging.getLogger(exp_logger)
- assert fs.path == exp_path
-
-
-@pytest.mark.parametrize(
- "config, exp_exception_message",
- [
- (
- {"logger_name": "fs_local", "path": invalid_path()},
- fs_connect_exception_message(invalid_path()),
- ),
- (
- {"logger_name": "fs_local", "path": invalid_path()[:-1]},
- fs_connect_exception_message(invalid_path()[:-1]),
- ),
- ({"path": invalid_path()}, fs_connect_exception_message(invalid_path())),
- (
- {"path": invalid_path()[:-1]},
- fs_connect_exception_message(invalid_path()[:-1]),
- ),
- ],
-)
-def test_fs_connect_with_invalid_path(config, exp_exception_message):
- fs = FsLocal()
- with pytest.raises(FsException) as excinfo:
- fs.fs_connect(config)
- assert str(excinfo.value) == exp_exception_message
-
-
-def test_fs_disconnect(fs_local):
- fs_local.fs_disconnect()
-
-
-def test_mkdir_with_valid_path(fs_local):
- folder_name = str(uuid.uuid4())
- folder_path = valid_path() + folder_name
- fs_local.mkdir(folder_name)
- assert os.path.exists(folder_path)
- # test idempotency
- fs_local.mkdir(folder_name)
- assert os.path.exists(folder_path)
- os.rmdir(folder_path)
-
-
-def test_mkdir_with_exception(fs_local):
- folder_name = str(uuid.uuid4())
- with pytest.raises(FsException) as excinfo:
- fs_local.mkdir(folder_name + "/" + folder_name)
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-@pytest.mark.parametrize(
- "storage, mode, expected",
- [
- (str(uuid.uuid4()), "file", False),
- ([str(uuid.uuid4())], "file", False),
- (str(uuid.uuid4()), "dir", False),
- ([str(uuid.uuid4())], "dir", False),
- ],
-)
-def test_file_exists_returns_false(fs_local, storage, mode, expected):
- assert fs_local.file_exists(storage, mode) == expected
-
-
-@pytest.mark.parametrize(
- "storage, mode, expected",
- [
- (str(uuid.uuid4()), "file", True),
- ([str(uuid.uuid4())], "file", True),
- (str(uuid.uuid4()), "dir", True),
- ([str(uuid.uuid4())], "dir", True),
- ],
-)
-def test_file_exists_returns_true(fs_local, storage, mode, expected):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- if mode == "file":
- os.mknod(path)
- elif mode == "dir":
- os.mkdir(path)
- assert fs_local.file_exists(storage, mode) == expected
- if mode == "file":
- os.remove(path)
- elif mode == "dir":
- os.rmdir(path)
-
-
-@pytest.mark.parametrize(
- "storage, mode",
- [
- (str(uuid.uuid4()), "file"),
- ([str(uuid.uuid4())], "file"),
- (str(uuid.uuid4()), "dir"),
- ([str(uuid.uuid4())], "dir"),
- ],
-)
-def test_file_size(fs_local, storage, mode):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- if mode == "file":
- os.mknod(path)
- elif mode == "dir":
- os.mkdir(path)
- size = os.path.getsize(path)
- assert fs_local.file_size(storage) == size
- if mode == "file":
- os.remove(path)
- elif mode == "dir":
- os.rmdir(path)
-
-
-@pytest.mark.parametrize(
- "files, path",
- [
- (["foo", "bar", "foobar"], str(uuid.uuid4())),
- (["foo", "bar", "foobar"], [str(uuid.uuid4())]),
- ],
-)
-def test_file_extract(fs_local, files, path):
- for f in files:
- os.mknod(valid_path() + f)
- tar_path = valid_path() + str(uuid.uuid4()) + ".tar"
- with tarfile.open(tar_path, "w") as tar:
- for f in files:
- tar.add(valid_path() + f, arcname=f)
- with tarfile.open(tar_path, "r") as tar:
- fs_local.file_extract(tar, path)
- extracted_path = valid_path() + (path if isinstance(path, str) else "/".join(path))
- ls_dir = os.listdir(extracted_path)
- assert len(ls_dir) == len(files)
- for f in files:
- assert f in ls_dir
- os.remove(tar_path)
- for f in files:
- os.remove(valid_path() + f)
- shutil.rmtree(extracted_path)
-
-
-@pytest.mark.parametrize(
- "storage, mode",
- [
- (str(uuid.uuid4()), "r"),
- (str(uuid.uuid4()), "w"),
- (str(uuid.uuid4()), "a"),
- (str(uuid.uuid4()), "rb"),
- (str(uuid.uuid4()), "wb"),
- (str(uuid.uuid4()), "ab"),
- ([str(uuid.uuid4())], "r"),
- ([str(uuid.uuid4())], "w"),
- ([str(uuid.uuid4())], "a"),
- ([str(uuid.uuid4())], "rb"),
- ([str(uuid.uuid4())], "wb"),
- ([str(uuid.uuid4())], "ab"),
- ],
-)
-def test_file_open(fs_local, storage, mode):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- os.mknod(path)
- file_obj = fs_local.file_open(storage, mode)
- assert isinstance(file_obj, io.IOBase)
- assert file_obj.closed is False
- os.remove(path)
-
-
-@pytest.mark.parametrize(
- "storage, mode",
- [
- (str(uuid.uuid4()), "r"),
- (str(uuid.uuid4()), "rb"),
- ([str(uuid.uuid4())], "r"),
- ([str(uuid.uuid4())], "rb"),
- ],
-)
-def test_file_open_file_not_found_exception(fs_local, storage, mode):
- with pytest.raises(FsException) as excinfo:
- fs_local.file_open(storage, mode)
- assert str(excinfo.value) == file_open_file_not_found_exception(storage)
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize(
- "storage, mode",
- [
- (str(uuid.uuid4()), "r"),
- (str(uuid.uuid4()), "w"),
- (str(uuid.uuid4()), "a"),
- (str(uuid.uuid4()), "rb"),
- (str(uuid.uuid4()), "wb"),
- (str(uuid.uuid4()), "ab"),
- ([str(uuid.uuid4())], "r"),
- ([str(uuid.uuid4())], "w"),
- ([str(uuid.uuid4())], "a"),
- ([str(uuid.uuid4())], "rb"),
- ([str(uuid.uuid4())], "wb"),
- ([str(uuid.uuid4())], "ab"),
- ],
-)
-def test_file_open_io_error(fs_local, storage, mode):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- os.mknod(path)
- os.chmod(path, 0)
- with pytest.raises(FsException) as excinfo:
- fs_local.file_open(storage, mode)
- assert str(excinfo.value) == file_open_io_exception(storage)
- assert excinfo.value.http_code == http.HTTPStatus.BAD_REQUEST
- os.remove(path)
-
-
-@pytest.mark.parametrize(
- "storage, with_files",
- [
- (str(uuid.uuid4()), True),
- (str(uuid.uuid4()), False),
- ([str(uuid.uuid4())], True),
- ([str(uuid.uuid4())], False),
- ],
-)
-def test_dir_ls(fs_local, storage, with_files):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- os.mkdir(path)
- if with_files is True:
- file_name = str(uuid.uuid4())
- file_path = path + "/" + file_name
- os.mknod(file_path)
- result = fs_local.dir_ls(storage)
-
- if with_files is True:
- assert len(result) == 1
- assert result[0] == file_name
- else:
- assert len(result) == 0
- shutil.rmtree(path)
-
-
-@pytest.mark.parametrize("storage", [(str(uuid.uuid4())), ([str(uuid.uuid4())])])
-def test_dir_ls_with_not_a_directory_error(fs_local, storage):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- os.mknod(path)
- with pytest.raises(FsException) as excinfo:
- fs_local.dir_ls(storage)
- assert str(excinfo.value) == dir_ls_not_a_directory_exception(storage)
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
- os.remove(path)
-
-
-@pytest.mark.parametrize("storage", [(str(uuid.uuid4())), ([str(uuid.uuid4())])])
-def test_dir_ls_with_io_error(fs_local, storage):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- os.mkdir(path)
- os.chmod(path, 0)
- with pytest.raises(FsException) as excinfo:
- fs_local.dir_ls(storage)
- assert str(excinfo.value) == dir_ls_io_exception(storage)
- assert excinfo.value.http_code == http.HTTPStatus.BAD_REQUEST
- os.rmdir(path)
-
-
-@pytest.mark.parametrize(
- "storage, with_files, ignore_non_exist",
- [
- (str(uuid.uuid4()), True, True),
- (str(uuid.uuid4()), False, True),
- (str(uuid.uuid4()), True, False),
- (str(uuid.uuid4()), False, False),
- ([str(uuid.uuid4())], True, True),
- ([str(uuid.uuid4())], False, True),
- ([str(uuid.uuid4())], True, False),
- ([str(uuid.uuid4())], False, False),
- ],
-)
-def test_file_delete_with_dir(fs_local, storage, with_files, ignore_non_exist):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- os.mkdir(path)
- if with_files is True:
- file_path = path + "/" + str(uuid.uuid4())
- os.mknod(file_path)
- fs_local.file_delete(storage, ignore_non_exist)
- assert os.path.exists(path) is False
-
-
-@pytest.mark.parametrize("storage", [(str(uuid.uuid4())), ([str(uuid.uuid4())])])
-def test_file_delete_expect_exception(fs_local, storage):
- with pytest.raises(FsException) as excinfo:
- fs_local.file_delete(storage)
- assert str(excinfo.value) == file_delete_exception_message(storage)
- assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
-
-
-@pytest.mark.parametrize("storage", [(str(uuid.uuid4())), ([str(uuid.uuid4())])])
-def test_file_delete_no_exception(fs_local, storage):
- path = (
- valid_path() + storage
- if isinstance(storage, str)
- else valid_path() + storage[0]
- )
- fs_local.file_delete(storage, ignore_non_exist=True)
- assert os.path.exists(path) is False
+++ /dev/null
-# Copyright 2019 Canonical
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: eduardo.sousa@canonical.com
-##
-
-from io import BytesIO
-import logging
-import os
-from pathlib import Path
-import subprocess
-import tarfile
-import tempfile
-from unittest.mock import Mock
-
-from gridfs import GridFSBucket
-from osm_common.fsbase import FsException
-from osm_common.fsmongo import FsMongo
-from pymongo import MongoClient
-import pytest
-
-__author__ = "Eduardo Sousa <eduardo.sousa@canonical.com>"
-
-
-def valid_path():
- return tempfile.gettempdir() + "/"
-
-
-def invalid_path():
- return "/#tweeter/"
-
-
-@pytest.fixture(scope="function", params=[True, False])
-def fs_mongo(request, monkeypatch):
- def mock_mongoclient_constructor(a, b):
- pass
-
- def mock_mongoclient_getitem(a, b):
- pass
-
- def mock_gridfs_constructor(a, b):
- pass
-
- monkeypatch.setattr(MongoClient, "__init__", mock_mongoclient_constructor)
- monkeypatch.setattr(MongoClient, "__getitem__", mock_mongoclient_getitem)
- monkeypatch.setattr(GridFSBucket, "__init__", mock_gridfs_constructor)
- fs = FsMongo(lock=request.param)
- fs.fs_connect({"path": valid_path(), "uri": "mongo:27017", "collection": "files"})
- return fs
-
-
-def generic_fs_exception_message(message):
- return "storage exception {}".format(message)
-
-
-def fs_connect_exception_message(path):
- return "storage exception Invalid configuration param at '[storage]': path '{}' does not exist".format(
- path
- )
-
-
-def file_open_file_not_found_exception(storage):
- f = storage if isinstance(storage, str) else "/".join(storage)
- return "storage exception File {} does not exist".format(f)
-
-
-def file_open_io_exception(storage):
- f = storage if isinstance(storage, str) else "/".join(storage)
- return "storage exception File {} cannot be opened".format(f)
-
-
-def dir_ls_not_a_directory_exception(storage):
- f = storage if isinstance(storage, str) else "/".join(storage)
- return "storage exception File {} does not exist".format(f)
-
-
-def dir_ls_io_exception(storage):
- f = storage if isinstance(storage, str) else "/".join(storage)
- return "storage exception File {} cannot be opened".format(f)
-
-
-def file_delete_exception_message(storage):
- return "storage exception File {} does not exist".format(storage)
-
-
-def test_constructor_without_logger():
- fs = FsMongo()
- assert fs.logger == logging.getLogger("fs")
- assert fs.path is None
- assert fs.client is None
- assert fs.fs is None
-
-
-def test_constructor_with_logger():
- logger_name = "fs_mongo"
- fs = FsMongo(logger_name=logger_name)
- assert fs.logger == logging.getLogger(logger_name)
- assert fs.path is None
- assert fs.client is None
- assert fs.fs is None
-
-
-def test_get_params(fs_mongo, monkeypatch):
- def mock_gridfs_find(self, search_query, **kwargs):
- return []
-
- monkeypatch.setattr(GridFSBucket, "find", mock_gridfs_find)
- params = fs_mongo.get_params()
- assert len(params) == 2
- assert "fs" in params
- assert "path" in params
- assert params["fs"] == "mongo"
- assert params["path"] == valid_path()
-
-
-@pytest.mark.parametrize(
- "config, exp_logger, exp_path",
- [
- (
- {
- "logger_name": "fs_mongo",
- "path": valid_path(),
- "uri": "mongo:27017",
- "collection": "files",
- },
- "fs_mongo",
- valid_path(),
- ),
- (
- {
- "logger_name": "fs_mongo",
- "path": valid_path()[:-1],
- "uri": "mongo:27017",
- "collection": "files",
- },
- "fs_mongo",
- valid_path(),
- ),
- (
- {"path": valid_path(), "uri": "mongo:27017", "collection": "files"},
- "fs",
- valid_path(),
- ),
- (
- {"path": valid_path()[:-1], "uri": "mongo:27017", "collection": "files"},
- "fs",
- valid_path(),
- ),
- ],
-)
-def test_fs_connect_with_valid_config(config, exp_logger, exp_path):
- fs = FsMongo()
- fs.fs_connect(config)
- assert fs.logger == logging.getLogger(exp_logger)
- assert fs.path == exp_path
- assert type(fs.client) == MongoClient
- assert type(fs.fs) == GridFSBucket
-
-
-@pytest.mark.parametrize(
- "config, exp_exception_message",
- [
- (
- {
- "logger_name": "fs_mongo",
- "path": invalid_path(),
- "uri": "mongo:27017",
- "collection": "files",
- },
- fs_connect_exception_message(invalid_path()),
- ),
- (
- {
- "logger_name": "fs_mongo",
- "path": invalid_path()[:-1],
- "uri": "mongo:27017",
- "collection": "files",
- },
- fs_connect_exception_message(invalid_path()[:-1]),
- ),
- (
- {"path": invalid_path(), "uri": "mongo:27017", "collection": "files"},
- fs_connect_exception_message(invalid_path()),
- ),
- (
- {"path": invalid_path()[:-1], "uri": "mongo:27017", "collection": "files"},
- fs_connect_exception_message(invalid_path()[:-1]),
- ),
- (
- {"path": "/", "uri": "mongo:27017", "collection": "files"},
- generic_fs_exception_message(
- "Invalid configuration param at '[storage]': path '/' is not writable"
- ),
- ),
- ],
-)
-def test_fs_connect_with_invalid_path(config, exp_exception_message):
- fs = FsMongo()
- with pytest.raises(FsException) as excinfo:
- fs.fs_connect(config)
- assert str(excinfo.value) == exp_exception_message
-
-
-@pytest.mark.parametrize(
- "config, exp_exception_message",
- [
- (
- {"logger_name": "fs_mongo", "uri": "mongo:27017", "collection": "files"},
- 'Missing parameter "path"',
- ),
- (
- {"logger_name": "fs_mongo", "path": valid_path(), "collection": "files"},
- 'Missing parameters: "uri"',
- ),
- (
- {"logger_name": "fs_mongo", "path": valid_path(), "uri": "mongo:27017"},
- 'Missing parameter "collection"',
- ),
- ],
-)
-def test_fs_connect_with_missing_parameters(config, exp_exception_message):
- fs = FsMongo()
- with pytest.raises(FsException) as excinfo:
- fs.fs_connect(config)
- assert str(excinfo.value) == generic_fs_exception_message(exp_exception_message)
-
-
-@pytest.mark.parametrize(
- "config, exp_exception_message",
- [
- (
- {
- "logger_name": "fs_mongo",
- "path": valid_path(),
- "uri": "mongo:27017",
- "collection": "files",
- },
- "MongoClient crashed",
- ),
- ],
-)
-def test_fs_connect_with_invalid_mongoclient(
- config, exp_exception_message, monkeypatch
-):
- def generate_exception(a, b=None):
- raise Exception(exp_exception_message)
-
- monkeypatch.setattr(MongoClient, "__init__", generate_exception)
-
- fs = FsMongo()
- with pytest.raises(FsException) as excinfo:
- fs.fs_connect(config)
- assert str(excinfo.value) == generic_fs_exception_message(exp_exception_message)
-
-
-@pytest.mark.parametrize(
- "config, exp_exception_message",
- [
- (
- {
- "logger_name": "fs_mongo",
- "path": valid_path(),
- "uri": "mongo:27017",
- "collection": "files",
- },
- "Collection unavailable",
- ),
- ],
-)
-def test_fs_connect_with_invalid_mongo_collection(
- config, exp_exception_message, monkeypatch
-):
- def mock_mongoclient_constructor(a, b=None):
- pass
-
- def generate_exception(a, b):
- raise Exception(exp_exception_message)
-
- monkeypatch.setattr(MongoClient, "__init__", mock_mongoclient_constructor)
- monkeypatch.setattr(MongoClient, "__getitem__", generate_exception)
-
- fs = FsMongo()
- with pytest.raises(FsException) as excinfo:
- fs.fs_connect(config)
- assert str(excinfo.value) == generic_fs_exception_message(exp_exception_message)
-
-
-@pytest.mark.parametrize(
- "config, exp_exception_message",
- [
- (
- {
- "logger_name": "fs_mongo",
- "path": valid_path(),
- "uri": "mongo:27017",
- "collection": "files",
- },
- "GridFsBucket crashed",
- ),
- ],
-)
-def test_fs_connect_with_invalid_gridfsbucket(
- config, exp_exception_message, monkeypatch
-):
- def mock_mongoclient_constructor(a, b=None):
- pass
-
- def mock_mongoclient_getitem(a, b):
- pass
-
- def generate_exception(a, b):
- raise Exception(exp_exception_message)
-
- monkeypatch.setattr(MongoClient, "__init__", mock_mongoclient_constructor)
- monkeypatch.setattr(MongoClient, "__getitem__", mock_mongoclient_getitem)
- monkeypatch.setattr(GridFSBucket, "__init__", generate_exception)
-
- fs = FsMongo()
- with pytest.raises(FsException) as excinfo:
- fs.fs_connect(config)
- assert str(excinfo.value) == generic_fs_exception_message(exp_exception_message)
-
-
-def test_fs_disconnect(fs_mongo):
- fs_mongo.fs_disconnect()
-
-
-# Example.tar.gz
-# example_tar/
-# ├── directory
-# │ └── file
-# └── symlinks
-# ├── directory_link -> ../directory/
-# └── file_link -> ../directory/file
-class FakeCursor:
- def __init__(self, id, filename, metadata):
- self._id = id
- self.filename = filename
- self.metadata = metadata
-
-
-class FakeFS:
- directory_metadata = {"type": "dir", "permissions": 509}
- file_metadata = {"type": "file", "permissions": 436}
- symlink_metadata = {"type": "sym", "permissions": 511}
-
- tar_info = {
- 1: {
- "cursor": FakeCursor(1, "example_tar", directory_metadata),
- "metadata": directory_metadata,
- "stream_content": b"",
- "stream_content_bad": b"Something",
- "path": "./tmp/example_tar",
- },
- 2: {
- "cursor": FakeCursor(2, "example_tar/directory", directory_metadata),
- "metadata": directory_metadata,
- "stream_content": b"",
- "stream_content_bad": b"Something",
- "path": "./tmp/example_tar/directory",
- },
- 3: {
- "cursor": FakeCursor(3, "example_tar/symlinks", directory_metadata),
- "metadata": directory_metadata,
- "stream_content": b"",
- "stream_content_bad": b"Something",
- "path": "./tmp/example_tar/symlinks",
- },
- 4: {
- "cursor": FakeCursor(4, "example_tar/directory/file", file_metadata),
- "metadata": file_metadata,
- "stream_content": b"Example test",
- "stream_content_bad": b"Example test2",
- "path": "./tmp/example_tar/directory/file",
- },
- 5: {
- "cursor": FakeCursor(5, "example_tar/symlinks/file_link", symlink_metadata),
- "metadata": symlink_metadata,
- "stream_content": b"../directory/file",
- "stream_content_bad": b"",
- "path": "./tmp/example_tar/symlinks/file_link",
- },
- 6: {
- "cursor": FakeCursor(
- 6, "example_tar/symlinks/directory_link", symlink_metadata
- ),
- "metadata": symlink_metadata,
- "stream_content": b"../directory/",
- "stream_content_bad": b"",
- "path": "./tmp/example_tar/symlinks/directory_link",
- },
- }
-
- def upload_from_stream(self, f, stream, metadata=None):
- found = False
- for i, v in self.tar_info.items():
- if f == v["path"]:
- assert metadata["type"] == v["metadata"]["type"]
- assert stream.read() == BytesIO(v["stream_content"]).read()
- stream.seek(0)
- assert stream.read() != BytesIO(v["stream_content_bad"]).read()
- found = True
- continue
- assert found
-
- def find(self, type, no_cursor_timeout=True, sort=None):
- list = []
- for i, v in self.tar_info.items():
- if type["metadata.type"] == "dir":
- if v["metadata"] == self.directory_metadata:
- list.append(v["cursor"])
- else:
- if v["metadata"] != self.directory_metadata:
- list.append(v["cursor"])
- return list
-
- def download_to_stream(self, id, file_stream):
- file_stream.write(BytesIO(self.tar_info[id]["stream_content"]).read())
-
-
-def test_file_extract():
- tar_path = "tmp/Example.tar.gz"
- folder_path = "tmp/example_tar"
-
- # Generate package
- subprocess.call(["rm", "-rf", "./tmp"])
- subprocess.call(["mkdir", "-p", "{}/directory".format(folder_path)])
- subprocess.call(["mkdir", "-p", "{}/symlinks".format(folder_path)])
- p = Path("{}/directory/file".format(folder_path))
- p.write_text("Example test")
- os.symlink("../directory/file", "{}/symlinks/file_link".format(folder_path))
- os.symlink("../directory/", "{}/symlinks/directory_link".format(folder_path))
- if os.path.exists(tar_path):
- os.remove(tar_path)
- subprocess.call(["tar", "-czvf", tar_path, folder_path])
-
- try:
- tar = tarfile.open(tar_path, "r")
- fs = FsMongo()
- fs.fs = FakeFS()
- fs.file_extract(compressed_object=tar, path=".")
- finally:
- os.remove(tar_path)
- subprocess.call(["rm", "-rf", "./tmp"])
-
-
-def test_upload_local_fs():
- path = "./tmp/"
-
- subprocess.call(["rm", "-rf", path])
- try:
- fs = FsMongo()
- fs.path = path
- fs.fs = FakeFS()
- fs.sync()
- assert os.path.isdir("{}example_tar".format(path))
- assert os.path.isdir("{}example_tar/directory".format(path))
- assert os.path.isdir("{}example_tar/symlinks".format(path))
- assert os.path.isfile("{}example_tar/directory/file".format(path))
- assert os.path.islink("{}example_tar/symlinks/file_link".format(path))
- assert os.path.islink("{}example_tar/symlinks/directory_link".format(path))
- finally:
- subprocess.call(["rm", "-rf", path])
-
-
-def test_upload_mongo_fs():
- path = "./tmp/"
-
- subprocess.call(["rm", "-rf", path])
- try:
- fs = FsMongo()
- fs.path = path
- fs.fs = Mock()
- fs.fs.find.return_value = {}
-
- file_content = "Test file content"
-
- # Create local dir and upload content to fakefs
- os.mkdir(path)
- os.mkdir("{}example_local".format(path))
- os.mkdir("{}example_local/directory".format(path))
- with open(
- "{}example_local/directory/test_file".format(path), "w+"
- ) as test_file:
- test_file.write(file_content)
- fs.reverse_sync("example_local")
-
- assert fs.fs.upload_from_stream.call_count == 2
-
- # first call to upload_from_stream, dir_name
- dir_name = "example_local/directory"
- call_args_0 = fs.fs.upload_from_stream.call_args_list[0]
- assert call_args_0[0][0] == dir_name
- assert call_args_0[1].get("metadata").get("type") == "dir"
-
- # second call to upload_from_stream, dir_name
- file_name = "example_local/directory/test_file"
- call_args_1 = fs.fs.upload_from_stream.call_args_list[1]
- assert call_args_1[0][0] == file_name
- assert call_args_1[1].get("metadata").get("type") == "file"
-
- finally:
- subprocess.call(["rm", "-rf", path])
- pass
+++ /dev/null
-# Copyright 2018 Whitestack, LLC
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
-##
-import asyncio
-import http
-
-from osm_common.msgbase import MsgBase, MsgException
-import pytest
-
-
-def exception_message(message):
- return "messaging exception " + message
-
-
-@pytest.fixture
-def msg_base():
- return MsgBase()
-
-
-def test_constructor():
- msgbase = MsgBase()
- assert msgbase is not None
- assert isinstance(msgbase, MsgBase)
-
-
-def test_connect(msg_base):
- msg_base.connect(None)
-
-
-def test_disconnect(msg_base):
- msg_base.disconnect()
-
-
-def test_write(msg_base):
- with pytest.raises(MsgException) as excinfo:
- msg_base.write("test", "test", "test")
- assert str(excinfo.value).startswith(
- exception_message("Method 'write' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_read(msg_base):
- with pytest.raises(MsgException) as excinfo:
- msg_base.read("test")
- assert str(excinfo.value).startswith(
- exception_message("Method 'read' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_aiowrite(msg_base):
- with pytest.raises(MsgException) as excinfo:
- asyncio.run(msg_base.aiowrite("test", "test", "test"))
- assert str(excinfo.value).startswith(
- exception_message("Method 'aiowrite' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_aioread(msg_base):
- with pytest.raises(MsgException) as excinfo:
- asyncio.run(msg_base.aioread("test"))
- assert str(excinfo.value).startswith(
- exception_message("Method 'aioread' not implemented")
- )
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+++ /dev/null
-# Copyright 2018 Whitestack, LLC
-# Copyright 2018 Telefonica S.A.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# For those usages not covered by the Apache License, Version 2.0 please
-# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
-##
-import asyncio
-import http
-import logging
-import os
-import shutil
-import tempfile
-import threading
-import time
-from unittest.mock import MagicMock
-import uuid
-
-from osm_common.msgbase import MsgException
-from osm_common.msglocal import MsgLocal
-import pytest
-import yaml
-
-__author__ = "Eduardo Sousa <eduardosousa@av.it.pt>"
-
-
-def valid_path():
- return tempfile.gettempdir() + "/"
-
-
-def invalid_path():
- return "/#tweeter/"
-
-
-@pytest.fixture(scope="function", params=[True, False])
-def msg_local(request):
- msg = MsgLocal(lock=request.param)
- yield msg
-
- msg.disconnect()
- if msg.path and msg.path != invalid_path() and msg.path != valid_path():
- shutil.rmtree(msg.path)
-
-
-@pytest.fixture(scope="function", params=[True, False])
-def msg_local_config(request):
- msg = MsgLocal(lock=request.param)
- msg.connect({"path": valid_path() + str(uuid.uuid4())})
- yield msg
-
- msg.disconnect()
- if msg.path != invalid_path():
- shutil.rmtree(msg.path)
-
-
-@pytest.fixture(scope="function", params=[True, False])
-def msg_local_with_data(request):
- msg = MsgLocal(lock=request.param)
- msg.connect({"path": valid_path() + str(uuid.uuid4())})
-
- msg.write("topic1", "key1", "msg1")
- msg.write("topic1", "key2", "msg1")
- msg.write("topic2", "key1", "msg1")
- msg.write("topic2", "key2", "msg1")
- msg.write("topic1", "key1", "msg2")
- msg.write("topic1", "key2", "msg2")
- msg.write("topic2", "key1", "msg2")
- msg.write("topic2", "key2", "msg2")
- yield msg
-
- msg.disconnect()
- if msg.path != invalid_path():
- shutil.rmtree(msg.path)
-
-
-def empty_exception_message():
- return "messaging exception "
-
-
-def test_constructor():
- msg = MsgLocal()
- assert msg.logger == logging.getLogger("msg")
- assert msg.path is None
- assert len(msg.files_read) == 0
- assert len(msg.files_write) == 0
- assert len(msg.buffer) == 0
-
-
-def test_constructor_with_logger():
- logger_name = "msg_local"
- msg = MsgLocal(logger_name=logger_name)
- assert msg.logger == logging.getLogger(logger_name)
- assert msg.path is None
- assert len(msg.files_read) == 0
- assert len(msg.files_write) == 0
- assert len(msg.buffer) == 0
-
-
-@pytest.mark.parametrize(
- "config, logger_name, path",
- [
- ({"logger_name": "msg_local", "path": valid_path()}, "msg_local", valid_path()),
- (
- {"logger_name": "msg_local", "path": valid_path()[:-1]},
- "msg_local",
- valid_path(),
- ),
- (
- {"logger_name": "msg_local", "path": valid_path() + "test_it/"},
- "msg_local",
- valid_path() + "test_it/",
- ),
- (
- {"logger_name": "msg_local", "path": valid_path() + "test_it"},
- "msg_local",
- valid_path() + "test_it/",
- ),
- ({"path": valid_path()}, "msg", valid_path()),
- ({"path": valid_path()[:-1]}, "msg", valid_path()),
- ({"path": valid_path() + "test_it/"}, "msg", valid_path() + "test_it/"),
- ({"path": valid_path() + "test_it"}, "msg", valid_path() + "test_it/"),
- ],
-)
-def test_connect(msg_local, config, logger_name, path):
- msg_local.connect(config)
- assert msg_local.logger == logging.getLogger(logger_name)
- assert msg_local.path == path
- assert len(msg_local.files_read) == 0
- assert len(msg_local.files_write) == 0
- assert len(msg_local.buffer) == 0
-
-
-@pytest.mark.parametrize(
- "config",
- [
- ({"logger_name": "msg_local", "path": invalid_path()}),
- ({"path": invalid_path()}),
- ],
-)
-def test_connect_with_exception(msg_local, config):
- with pytest.raises(MsgException) as excinfo:
- msg_local.connect(config)
- assert str(excinfo.value).startswith(empty_exception_message())
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_disconnect(msg_local_config):
- files_read = msg_local_config.files_read.copy()
- files_write = msg_local_config.files_write.copy()
- msg_local_config.disconnect()
- for f in files_read.values():
- assert f.closed
- for f in files_write.values():
- assert f.closed
-
-
-def test_disconnect_with_read(msg_local_config):
- msg_local_config.read("topic1", blocks=False)
- msg_local_config.read("topic2", blocks=False)
- files_read = msg_local_config.files_read.copy()
- files_write = msg_local_config.files_write.copy()
- msg_local_config.disconnect()
- for f in files_read.values():
- assert f.closed
- for f in files_write.values():
- assert f.closed
-
-
-def test_disconnect_with_write(msg_local_with_data):
- files_read = msg_local_with_data.files_read.copy()
- files_write = msg_local_with_data.files_write.copy()
- msg_local_with_data.disconnect()
-
- for f in files_read.values():
- assert f.closed
-
- for f in files_write.values():
- assert f.closed
-
-
-def test_disconnect_with_read_and_write(msg_local_with_data):
- msg_local_with_data.read("topic1", blocks=False)
- msg_local_with_data.read("topic2", blocks=False)
- files_read = msg_local_with_data.files_read.copy()
- files_write = msg_local_with_data.files_write.copy()
-
- msg_local_with_data.disconnect()
- for f in files_read.values():
- assert f.closed
- for f in files_write.values():
- assert f.closed
-
-
-@pytest.mark.parametrize(
- "topic, key, msg",
- [
- ("test_topic", "test_key", "test_msg"),
- ("test", "test_key", "test_msg"),
- ("test_topic", "test", "test_msg"),
- ("test_topic", "test_key", "test"),
- ("test_topic", "test_list", ["a", "b", "c"]),
- ("test_topic", "test_tuple", ("c", "b", "a")),
- ("test_topic", "test_dict", {"a": 1, "b": 2, "c": 3}),
- ("test_topic", "test_number", 123),
- ("test_topic", "test_float", 1.23),
- ("test_topic", "test_boolean", True),
- ("test_topic", "test_none", None),
- ],
-)
-def test_write(msg_local_config, topic, key, msg):
- file_path = msg_local_config.path + topic
- msg_local_config.write(topic, key, msg)
- assert os.path.exists(file_path)
-
- with open(file_path, "r") as stream:
- assert yaml.safe_load(stream) == {
- key: msg if not isinstance(msg, tuple) else list(msg)
- }
-
-
-@pytest.mark.parametrize(
- "topic, key, msg, times",
- [
- ("test_topic", "test_key", "test_msg", 2),
- ("test", "test_key", "test_msg", 3),
- ("test_topic", "test", "test_msg", 4),
- ("test_topic", "test_key", "test", 2),
- ("test_topic", "test_list", ["a", "b", "c"], 3),
- ("test_topic", "test_tuple", ("c", "b", "a"), 4),
- ("test_topic", "test_dict", {"a": 1, "b": 2, "c": 3}, 2),
- ("test_topic", "test_number", 123, 3),
- ("test_topic", "test_float", 1.23, 4),
- ("test_topic", "test_boolean", True, 2),
- ("test_topic", "test_none", None, 3),
- ],
-)
-def test_write_with_multiple_calls(msg_local_config, topic, key, msg, times):
- file_path = msg_local_config.path + topic
-
- for _ in range(times):
- msg_local_config.write(topic, key, msg)
- assert os.path.exists(file_path)
-
- with open(file_path, "r") as stream:
- for _ in range(times):
- data = stream.readline()
- assert yaml.safe_load(data) == {
- key: msg if not isinstance(msg, tuple) else list(msg)
- }
-
-
-def test_write_exception(msg_local_config):
- msg_local_config.files_write = MagicMock()
- msg_local_config.files_write.__contains__.side_effect = Exception()
-
- with pytest.raises(MsgException) as excinfo:
- msg_local_config.write("test", "test", "test")
- assert str(excinfo.value).startswith(empty_exception_message())
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-@pytest.mark.parametrize(
- "topics, datas",
- [
- (["topic"], [{"key": "value"}]),
- (["topic1"], [{"key": "value"}]),
- (["topic2"], [{"key": "value"}]),
- (["topic", "topic1"], [{"key": "value"}]),
- (["topic", "topic2"], [{"key": "value"}]),
- (["topic1", "topic2"], [{"key": "value"}]),
- (["topic", "topic1", "topic2"], [{"key": "value"}]),
- (["topic"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic1"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic1"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- ],
-)
-def test_read(msg_local_with_data, topics, datas):
- def write_to_topic(topics, datas):
- # Allow msglocal to block while waiting
- time.sleep(2)
- for topic in topics:
- for data in datas:
- with open(msg_local_with_data.path + topic, "a+") as fp:
- yaml.safe_dump(data, fp, default_flow_style=True, width=20000)
- fp.flush()
-
- # If file is not opened first, the messages written won't be seen
- for topic in topics:
- if topic not in msg_local_with_data.files_read:
- msg_local_with_data.read(topic, blocks=False)
-
- t = threading.Thread(target=write_to_topic, args=(topics, datas))
- t.start()
-
- for topic in topics:
- for data in datas:
- recv_topic, recv_key, recv_msg = msg_local_with_data.read(topic)
- key = list(data.keys())[0]
- val = data[key]
- assert recv_topic == topic
- assert recv_key == key
- assert recv_msg == val
- t.join()
-
-
-@pytest.mark.parametrize(
- "topics, datas",
- [
- (["topic"], [{"key": "value"}]),
- (["topic1"], [{"key": "value"}]),
- (["topic2"], [{"key": "value"}]),
- (["topic", "topic1"], [{"key": "value"}]),
- (["topic", "topic2"], [{"key": "value"}]),
- (["topic1", "topic2"], [{"key": "value"}]),
- (["topic", "topic1", "topic2"], [{"key": "value"}]),
- (["topic"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic1"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic1"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- ],
-)
-def test_read_non_block(msg_local_with_data, topics, datas):
- def write_to_topic(topics, datas):
- for topic in topics:
- for data in datas:
- with open(msg_local_with_data.path + topic, "a+") as fp:
- yaml.safe_dump(data, fp, default_flow_style=True, width=20000)
- fp.flush()
-
- # If file is not opened first, the messages written won't be seen
- for topic in topics:
- if topic not in msg_local_with_data.files_read:
- msg_local_with_data.read(topic, blocks=False)
-
- t = threading.Thread(target=write_to_topic, args=(topics, datas))
- t.start()
- t.join()
-
- for topic in topics:
- for data in datas:
- recv_topic, recv_key, recv_msg = msg_local_with_data.read(
- topic, blocks=False
- )
- key = list(data.keys())[0]
- val = data[key]
- assert recv_topic == topic
- assert recv_key == key
- assert recv_msg == val
-
-
-@pytest.mark.parametrize(
- "topics, datas",
- [
- (["topic"], [{"key": "value"}]),
- (["topic1"], [{"key": "value"}]),
- (["topic2"], [{"key": "value"}]),
- (["topic", "topic1"], [{"key": "value"}]),
- (["topic", "topic2"], [{"key": "value"}]),
- (["topic1", "topic2"], [{"key": "value"}]),
- (["topic", "topic1", "topic2"], [{"key": "value"}]),
- (["topic"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic1"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic1"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- ],
-)
-def test_read_non_block_none(msg_local_with_data, topics, datas):
- def write_to_topic(topics, datas):
- time.sleep(2)
- for topic in topics:
- for data in datas:
- with open(msg_local_with_data.path + topic, "a+") as fp:
- yaml.safe_dump(data, fp, default_flow_style=True, width=20000)
- fp.flush()
-
- # If file is not opened first, the messages written won't be seen
- for topic in topics:
- if topic not in msg_local_with_data.files_read:
- msg_local_with_data.read(topic, blocks=False)
- t = threading.Thread(target=write_to_topic, args=(topics, datas))
- t.start()
-
- for topic in topics:
- recv_data = msg_local_with_data.read(topic, blocks=False)
- assert recv_data is None
- t.join()
-
-
-@pytest.mark.parametrize("blocks", [(True), (False)])
-def test_read_exception(msg_local_with_data, blocks):
- msg_local_with_data.files_read = MagicMock()
- msg_local_with_data.files_read.__contains__.side_effect = Exception()
-
- with pytest.raises(MsgException) as excinfo:
- msg_local_with_data.read("topic1", blocks=blocks)
- assert str(excinfo.value).startswith(empty_exception_message())
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-@pytest.mark.parametrize(
- "topics, datas",
- [
- (["topic"], [{"key": "value"}]),
- (["topic1"], [{"key": "value"}]),
- (["topic2"], [{"key": "value"}]),
- (["topic", "topic1"], [{"key": "value"}]),
- (["topic", "topic2"], [{"key": "value"}]),
- (["topic1", "topic2"], [{"key": "value"}]),
- (["topic", "topic1", "topic2"], [{"key": "value"}]),
- (["topic"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic1"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic1"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- (["topic", "topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
- ],
-)
-def test_aioread(msg_local_with_data, topics, datas):
- def write_to_topic(topics, datas):
- time.sleep(2)
- for topic in topics:
- for data in datas:
- with open(msg_local_with_data.path + topic, "a+") as fp:
- yaml.safe_dump(data, fp, default_flow_style=True, width=20000)
- fp.flush()
-
- # If file is not opened first, the messages written won't be seen
- for topic in topics:
- if topic not in msg_local_with_data.files_read:
- msg_local_with_data.read(topic, blocks=False)
-
- t = threading.Thread(target=write_to_topic, args=(topics, datas))
- t.start()
- for topic in topics:
- for data in datas:
- recv = asyncio.run(msg_local_with_data.aioread(topic))
- recv_topic, recv_key, recv_msg = recv
- key = list(data.keys())[0]
- val = data[key]
- assert recv_topic == topic
- assert recv_key == key
- assert recv_msg == val
- t.join()
-
-
-def test_aioread_exception(msg_local_with_data):
- msg_local_with_data.files_read = MagicMock()
- msg_local_with_data.files_read.__contains__.side_effect = Exception()
-
- with pytest.raises(MsgException) as excinfo:
- asyncio.run(msg_local_with_data.aioread("topic1"))
- assert str(excinfo.value).startswith(empty_exception_message())
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-def test_aioread_general_exception(msg_local_with_data):
- msg_local_with_data.read = MagicMock()
- msg_local_with_data.read.side_effect = Exception()
-
- with pytest.raises(MsgException) as excinfo:
- asyncio.run(msg_local_with_data.aioread("topic1"))
- assert str(excinfo.value).startswith(empty_exception_message())
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
-
-
-@pytest.mark.parametrize(
- "topic, key, msg",
- [
- ("test_topic", "test_key", "test_msg"),
- ("test", "test_key", "test_msg"),
- ("test_topic", "test", "test_msg"),
- ("test_topic", "test_key", "test"),
- ("test_topic", "test_list", ["a", "b", "c"]),
- ("test_topic", "test_tuple", ("c", "b", "a")),
- ("test_topic", "test_dict", {"a": 1, "b": 2, "c": 3}),
- ("test_topic", "test_number", 123),
- ("test_topic", "test_float", 1.23),
- ("test_topic", "test_boolean", True),
- ("test_topic", "test_none", None),
- ],
-)
-def test_aiowrite(msg_local_config, topic, key, msg):
- file_path = msg_local_config.path + topic
- asyncio.run(msg_local_config.aiowrite(topic, key, msg))
- assert os.path.exists(file_path)
-
- with open(file_path, "r") as stream:
- assert yaml.safe_load(stream) == {
- key: msg if not isinstance(msg, tuple) else list(msg)
- }
-
-
-@pytest.mark.parametrize(
- "topic, key, msg, times",
- [
- ("test_topic", "test_key", "test_msg", 2),
- ("test", "test_key", "test_msg", 3),
- ("test_topic", "test", "test_msg", 4),
- ("test_topic", "test_key", "test", 2),
- ("test_topic", "test_list", ["a", "b", "c"], 3),
- ("test_topic", "test_tuple", ("c", "b", "a"), 4),
- ("test_topic", "test_dict", {"a": 1, "b": 2, "c": 3}, 2),
- ("test_topic", "test_number", 123, 3),
- ("test_topic", "test_float", 1.23, 4),
- ("test_topic", "test_boolean", True, 2),
- ("test_topic", "test_none", None, 3),
- ],
-)
-def test_aiowrite_with_multiple_calls(msg_local_config, topic, key, msg, times):
- file_path = msg_local_config.path + topic
- for _ in range(times):
- asyncio.run(msg_local_config.aiowrite(topic, key, msg))
- assert os.path.exists(file_path)
-
- with open(file_path, "r") as stream:
- for _ in range(times):
- data = stream.readline()
- assert yaml.safe_load(data) == {
- key: msg if not isinstance(msg, tuple) else list(msg)
- }
-
-
-def test_aiowrite_exception(msg_local_config):
- msg_local_config.files_write = MagicMock()
- msg_local_config.files_write.__contains__.side_effect = Exception()
-
- with pytest.raises(MsgException) as excinfo:
- asyncio.run(msg_local_config.aiowrite("test", "test", "test"))
- assert str(excinfo.value).startswith(empty_exception_message())
- assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
--- /dev/null
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+[build-system]
+requires = ["setuptools>=61", "wheel", "setuptools_scm>=8.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "osm_common"
+description = "OSM common utilities"
+readme = "README.rst"
+authors = [
+ {name = "ETSI OSM", email = "osmsupport@etsi.com"}
+]
+maintainers = [
+ {name = "ETSI OSM", email = "osmsupport@etsi.com"}
+]
+license = "Apache-2.0"
+dynamic = ["dependencies", "version"]
+
+[tool.setuptools_scm]
+write_to = "src/osm_common/_version.py"
+version_scheme = "guess-next-dev"
+local_scheme = "node-and-date"
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
+[tool.setuptools.dynamic]
+dependencies = { file = ["requirements.txt"] }
+
# See the License for the specific language governing permissions and
# limitations under the License.
+build
stdeb
-setuptools-version-command
-setuptools<60
\ No newline at end of file
+setuptools-scm
# See the License for the specific language governing permissions and
# limitations under the License.
#######################################################################################
-setuptools-version-command==99.9
+build==1.3.0
+ # via -r requirements-dist.in
+packaging==25.0
+ # via
+ # build
+ # setuptools-scm
+pyproject-hooks==1.2.0
+ # via build
+setuptools-scm==9.2.2
# via -r requirements-dist.in
stdeb==0.11.0
# via -r requirements-dist.in
# The following packages are considered to be unsafe in a requirements file:
-setuptools==59.8.0
- # via
- # -r requirements-dist.in
- # setuptools-version-command
+setuptools==80.9.0
+ # via setuptools-scm
#######################################################################################
coverage==7.10.7
# via -r requirements-test.in
-exceptiongroup==1.3.0
- # via pytest
iniconfig==2.1.0
# via pytest
nose2==0.15.1
# via pytest
pytest==8.4.2
# via -r requirements-test.in
-tomli==2.2.1
- # via pytest
-typing-extensions==4.15.0
- # via exceptiongroup
+++ /dev/null
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-# Copyright ETSI OSM Contributors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-# implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import os
-
-from setuptools import setup
-
-here = os.path.abspath(os.path.dirname(__file__))
-_name = "osm_common"
-README = open(os.path.join(here, "README.rst")).read()
-
-setup(
- name=_name,
- description="OSM common utilities",
- long_description=README,
- version_command=(
- "git describe --tags --long --dirty --match v*",
- "pep440-git-full",
- ),
- author="ETSI OSM",
- author_email="osmsupport@etsi.com",
- maintainer="ETSI OSM",
- maintainer_email="osmsupport@etsi.com",
- url="https://osm.etsi.org/gitweb/?p=osm/common.git;a=summary",
- license="Apache 2.0",
- setup_requires=["setuptools-version-command"],
- packages=[_name],
- include_package_data=True,
-)
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+
+
+class FakeLock:
+ """Implements a fake lock that can be called with the "with" statement or acquire, release methods"""
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ def acquire(self):
+ pass
+
+ def release(self):
+ pass
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from base64 import b64decode, b64encode
+from copy import deepcopy
+from http import HTTPStatus
+import logging
+import re
+from threading import Lock
+import typing
+
+
+from Crypto.Cipher import AES
+from motor.motor_asyncio import AsyncIOMotorClient
+from osm_common.common_utils import FakeLock
+import yaml
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+
+
+DB_NAME = "osm"
+
+
+class DbException(Exception):
+ def __init__(self, message, http_code=HTTPStatus.NOT_FOUND):
+ self.http_code = http_code
+ Exception.__init__(self, "database exception " + str(message))
+
+
+class DbBase(object):
+ def __init__(self, encoding_type="ascii", logger_name="db", lock=False):
+ """
+ Constructor of dbBase
+ :param logger_name: logging name
+ :param lock: Used to protect simultaneous access to the same instance class by several threads:
+ False, None: Do not protect, this object will only be accessed by one thread
+ True: This object needs to be protected by several threads accessing.
+ Lock object. Use this Lock for thread access protection
+ """
+ self.logger = logging.getLogger(logger_name)
+ self.secret_key = None # 32 bytes length array used for encrypt/decrypt
+ self.encrypt_mode = AES.MODE_ECB
+ self.encoding_type = encoding_type
+ if not lock:
+ self.lock = FakeLock()
+ elif lock is True:
+ self.lock = Lock()
+ elif isinstance(lock, Lock):
+ self.lock = lock
+ else:
+ raise ValueError("lock parameter must be a Lock class or boolean")
+
+ def db_connect(self, config, target_version=None):
+ """
+ Connect to database
+ :param config: Configuration of database. Contains among others:
+ host: database host (mandatory)
+ port: database port (mandatory)
+ name: database name (mandatory)
+ user: database username
+ password: database password
+ commonkey: common OSM key used for sensitive information encryption
+ masterpassword: same as commonkey, for backward compatibility. Deprecated, to be removed in the future
+ :param target_version: if provided it checks if database contains required version, raising exception otherwise.
+ :return: None or raises DbException on error
+ """
+ raise DbException("Method 'db_connect' not implemented")
+
+ def db_disconnect(self):
+ """
+ Disconnect from database
+ :return: None
+ """
+ pass
+
+ def get_list(self, table, q_filter=None):
+ """
+ Obtain a list of entries matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: a list (can be empty) with the found entries. Raises DbException on error
+ """
+ raise DbException("Method 'get_list' not implemented")
+
+ def count(self, table, q_filter=None):
+ """
+ Count the number of entries matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: number of entries found (can be zero)
+ :raise: DbException on error
+ """
+ raise DbException("Method 'count' not implemented")
+
+ def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
+ """
+ Obtain one entry matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :param fail_on_empty: If nothing matches filter it returns None unless this flag is set to True, in which case
+ it raises a DbException
+ :param fail_on_more: If more than one matches filter it returns one of them unless this flag is set to True, so
+ that it raises a DbException
+ :return: The requested element, or None
+ """
+ raise DbException("Method 'get_one' not implemented")
+
+ def del_list(self, table, q_filter=None):
+ """
+ Deletes all entries that match q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: Dict with the number of entries deleted
+ """
+ raise DbException("Method 'del_list' not implemented")
+
+ def del_one(self, table, q_filter=None, fail_on_empty=True):
+ """
+ Deletes one entry that matches q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :param fail_on_empty: If nothing matches filter it returns '0' deleted unless this flag is set to True, in
+ which case it raises a DbException
+ :return: Dict with the number of entries deleted
+ """
+ raise DbException("Method 'del_one' not implemented")
+
+ def create(self, table, indata):
+ """
+ Add a new entry at database
+ :param table: collection or table
+ :param indata: content to be added
+ :return: database '_id' of the inserted element. Raises a DbException on error
+ """
+ raise DbException("Method 'create' not implemented")
+
+ def create_list(self, table, indata_list):
+ """
+ Add several entries at once
+ :param table: collection or table
+ :param indata_list: list of elements to insert. Each element must be a dictionary.
+ An '_id' key based on random uuid is added at each element if missing
+ :return: list of inserted '_id's. Exception on error
+ """
+ raise DbException("Method 'create_list' not implemented")
+
+ def set_one(
+ self,
+ table,
+ q_filter,
+ update_dict,
+ fail_on_empty=True,
+ unset=None,
+ pull=None,
+ push=None,
+ push_list=None,
+ pull_list=None,
+ ):
+ """
+ Modifies an entry at database
+ :param table: collection or table
+ :param q_filter: Filter
+ :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
+ :param fail_on_empty: If nothing matches filter it returns None unless this flag is set to True, in which case
+ it raises a DbException
+ :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+ ignored. If not exist, it is ignored
+ :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
+ if exist in the array is removed. If not exist, it is ignored
+ :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
+ is appended to the end of the array
+ :param pull_list: Same as pull but values are arrays where each item is removed from the array
+ :param push_list: Same as push but values are arrays where each item is appended individually instead of
+ appending the whole array
+ :return: Dict with the number of entries modified. None if no matching is found.
+ """
+ raise DbException("Method 'set_one' not implemented")
+
+ def set_list(
+ self,
+ table,
+ q_filter,
+ update_dict,
+ unset=None,
+ pull=None,
+ push=None,
+ push_list=None,
+ pull_list=None,
+ ):
+ """
+ Modifies all matching entries at database
+ :param table: collection or table
+ :param q_filter: Filter
+ :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
+ :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+ ignored. If not exist, it is ignored
+ :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
+ if exist in the array is removed. If not exist, it is ignored
+ :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
+ is appended to the end of the array
+ :param pull_list: Same as pull but values are arrays where each item is removed from the array
+ :param push_list: Same as push but values are arrays where each item is appended individually instead of
+ appending the whole array
+ :return: Dict with the number of entries modified
+ """
+ raise DbException("Method 'set_list' not implemented")
+
+ def replace(self, table, _id, indata, fail_on_empty=True):
+ """
+ Replace the content of an entry
+ :param table: collection or table
+ :param _id: internal database id
+ :param indata: content to replace
+ :param fail_on_empty: If nothing matches filter it returns None unless this flag is set to True, in which case
+ it raises a DbException
+ :return: Dict with the number of entries replaced
+ """
+ raise DbException("Method 'replace' not implemented")
+
+ def _join_secret_key(self, update_key):
+ """
+ Returns a xor byte combination of the internal secret_key and the provided update_key.
+ It does not modify the internal secret_key. Used for adding salt, join keys, etc.
+ :param update_key: Can be a string, byte or None. Recommended a long one (e.g. 32 byte length)
+ :return: joined key in bytes with a 32 bytes length. Can be None if both internal secret_key and update_key
+ are None
+ """
+ if not update_key:
+ return self.secret_key
+ elif isinstance(update_key, str):
+ update_key_bytes = update_key.encode()
+ else:
+ update_key_bytes = update_key
+
+ new_secret_key = (
+ bytearray(self.secret_key) if self.secret_key else bytearray(32)
+ )
+ for i, b in enumerate(update_key_bytes):
+ new_secret_key[i % 32] ^= b
+ return bytes(new_secret_key)
+
+ def set_secret_key(self, new_secret_key, replace=False):
+ """
+ Updates internal secret_key used for encryption, with a byte xor
+ :param new_secret_key: string or byte array. It is recommended a 32 byte length
+ :param replace: if True, old value of internal secret_key is ignored and replaced. If false, a byte xor is used
+ :return: None
+ """
+ if replace:
+ self.secret_key = None
+ self.secret_key = self._join_secret_key(new_secret_key)
+
+ def get_secret_key(self):
+ """
+ Get the database secret key in case it is not done when "connect" is called. It can happen when the database
+ is empty after an initial install. It should skip if the secret is already obtained.
+ """
+ pass
+
+ @staticmethod
+ def pad_data(value: str) -> str:
+ if not isinstance(value, str):
+ raise DbException(
+ f"Incorrect data type: type({value}), string is expected."
+ )
+ return value + ("\0" * ((16 - len(value)) % 16))
+
+ @staticmethod
+ def unpad_data(value: str) -> str:
+ if not isinstance(value, str):
+ raise DbException(
+ f"Incorrect data type: type({value}), string is expected."
+ )
+ return value.rstrip("\0")
+
+ def _encrypt_value(self, value: str, schema_version: str, salt: str):
+ """Encrypt a value.
+
+ Args:
+ value (str): value to be encrypted. It is string/unicode
+ schema_version (str): used for version control. If None or '1.0' no encryption is done.
+ If '1.1' symmetric AES encryption is done
+ salt (str): optional salt to be used. Must be str
+
+ Returns:
+ Encrypted content of value (str)
+
+ """
+ if not self.secret_key or not schema_version or schema_version == "1.0":
+ return value
+
+ else:
+ # Secret key as bytes
+ secret_key = self._join_secret_key(salt)
+ cipher = AES.new(secret_key, self.encrypt_mode)
+ # Padded data as string
+ padded_private_msg = self.pad_data(value)
+ # Padded data as bytes
+ padded_private_msg_bytes = padded_private_msg.encode(self.encoding_type)
+ # Encrypt padded data
+ encrypted_msg = cipher.encrypt(padded_private_msg_bytes)
+ # Base64 encoded encrypted data
+ encoded_encrypted_msg = b64encode(encrypted_msg)
+ # Converting to string
+ return encoded_encrypted_msg.decode(self.encoding_type)
+
+ def encrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
+ """Encrypt a value.
+
+ Args:
+ value (str): value to be encrypted. It is string/unicode
+ schema_version (str): used for version control. If None or '1.0' no encryption is done.
+ If '1.1' symmetric AES encryption is done
+ salt (str): optional salt to be used. Must be str
+
+ Returns:
+ Encrypted content of value (str)
+
+ """
+ self.get_secret_key()
+ return self._encrypt_value(value, schema_version, salt)
+
+ def _decrypt_value(self, value: str, schema_version: str, salt: str) -> str:
+ """Decrypt an encrypted value.
+ Args:
+
+ value (str): value to be decrypted. It is a base64 string
+ schema_version (str): used for known encryption method used.
+ If None or '1.0' no encryption has been done.
+ If '1.1' symmetric AES encryption has been done
+ salt (str): optional salt to be used
+
+ Returns:
+ Plain content of value (str)
+
+ """
+ if not self.secret_key or not schema_version or schema_version == "1.0":
+ return value
+
+ else:
+ secret_key = self._join_secret_key(salt)
+ # Decoding encrypted data, output bytes
+ encrypted_msg = b64decode(value)
+ cipher = AES.new(secret_key, self.encrypt_mode)
+ # Decrypted data, output bytes
+ decrypted_msg = cipher.decrypt(encrypted_msg)
+ try:
+ # Converting to string
+ private_msg = decrypted_msg.decode(self.encoding_type)
+ except UnicodeDecodeError:
+ raise DbException(
+ "Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
+ http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+ )
+ # Unpadded data as string
+ return self.unpad_data(private_msg)
+
+ def decrypt(self, value: str, schema_version: str = None, salt: str = None) -> str:
+ """Decrypt an encrypted value.
+ Args:
+
+ value (str): value to be decrypted. It is a base64 string
+ schema_version (str): used for known encryption method used.
+ If None or '1.0' no encryption has been done.
+ If '1.1' symmetric AES encryption has been done
+ salt (str): optional salt to be used
+
+ Returns:
+ Plain content of value (str)
+
+ """
+ self.get_secret_key()
+ return self._decrypt_value(value, schema_version, salt)
+
+ def encrypt_decrypt_fields(
+ self, item, action, fields=None, flags=None, schema_version=None, salt=None
+ ):
+ if not fields:
+ return
+ self.get_secret_key()
+ actions = ["encrypt", "decrypt"]
+ if action.lower() not in actions:
+ raise DbException(
+ "Unknown action ({}): Must be one of {}".format(action, actions),
+ http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+ )
+ method = self.encrypt if action.lower() == "encrypt" else self.decrypt
+ if flags is None:
+ flags = re.I
+
+ def process(_item):
+ if isinstance(_item, list):
+ for elem in _item:
+ process(elem)
+ elif isinstance(_item, dict):
+ for key, val in _item.items():
+ if isinstance(val, str):
+ if any(re.search(f, key, flags) for f in fields):
+ _item[key] = method(val, schema_version, salt)
+ else:
+ process(val)
+
+ process(item)
+
+
+def deep_update_rfc7396(dict_to_change, dict_reference, key_list=None):
+ """
+ Modifies one dictionary with the information of the other following https://tools.ietf.org/html/rfc7396
+ Basically is a recursive python 'dict_to_change.update(dict_reference)', but a value of None is used to delete.
+ It implements an extra feature that allows modifying an array. RFC7396 only allows replacing the entire array.
+ For that, dict_reference should contain a dict with keys starting with "$" with the following meaning:
+ $[index] <index> is an integer for targeting a concrete index from dict_to_change array. If the value is None
+ the element of the array is deleted, otherwise it is edited.
+ $+[index] The value is inserted at this <index>. A value of None makes no sense and an exception is raised.
+ $+ The value is appended at the end. A value of None makes no sense and an exception is raised.
+ $val It looks for all the items in the array dict_to_change equal to <val>. <val> is evaluated as yaml,
+ that is, numbers are taken as type int, true/false as boolean, etc. Use quotes to force string.
+ Nothing happens if no match is found. If the value is None the matched elements are deleted.
+ $key: val In case a dictionary is passed in yaml format, it looks for all items in the array dict_to_change
+ that are dictionaries and contains this <key> equal to <val>. Several keys can be used by yaml
+ format '{key: val, key: val, ...}'; and all of them must match. Nothing happens if no match is
+ found. If value is None the matched items are deleted, otherwise they are edited.
+ $+val If no match is found (see '$val'), the value is appended to the array. If any match is found nothing
+ is changed. A value of None makes no sense.
+ $+key: val If no match is found (see '$key: val'), the value is appended to the array. If any match is found
+ nothing is changed. A value of None makes no sense.
+ If there are several editions, insertions and deletions; editions and deletions are done first in reverse index
+ order; then insertions also in reverse index order; and finally appends in any order. So indexes used at
+ insertions must take into account the deleted items.
+ :param dict_to_change: Target dictionary to be changed.
+ :param dict_reference: Dictionary that contains changes to be applied.
+ :param key_list: This is used internally for recursive calls. Do not fill this parameter.
+ :return: none or raises an exception only at array modification when there is a bad format or conflict.
+ """
+
+ def _deep_update_array(array_to_change, _dict_reference, _key_list):
+ to_append = {}
+ to_insert_at_index = {}
+ values_to_edit_delete = {}
+ indexes_to_edit_delete = []
+ array_edition = None
+ _key_list.append("")
+ for k in _dict_reference:
+ _key_list[-1] = str(k)
+ if not isinstance(k, str) or not k.startswith("$"):
+ if array_edition is True:
+ raise DbException(
+ "Found array edition (keys starting with '$') and pure dictionary edition in the"
+ " same dict at '{}'".format(":".join(_key_list[:-1]))
+ )
+ array_edition = False
+ continue
+ else:
+ if array_edition is False:
+ raise DbException(
+ "Found array edition (keys starting with '$') and pure dictionary edition in the"
+ " same dict at '{}'".format(":".join(_key_list[:-1]))
+ )
+ array_edition = True
+ insert = False
+ indexes = [] # indexes to edit or insert
+ kitem = k[1:]
+ if kitem.startswith("+"):
+ insert = True
+ kitem = kitem[1:]
+ if _dict_reference[k] is None:
+ raise DbException(
+ "A value of None makes no sense for insertions at '{}'".format(
+ ":".join(_key_list)
+ )
+ )
+
+ if kitem.startswith("[") and kitem.endswith("]"):
+ try:
+ index = int(kitem[1:-1])
+ if index < 0:
+ index += len(array_to_change)
+ if index < 0:
+ index = 0 # skip outside index edition
+ indexes.append(index)
+ except Exception:
+ raise DbException(
+ "Wrong format at '{}'. Expecting integer index inside quotes".format(
+ ":".join(_key_list)
+ )
+ )
+ elif kitem:
+ # match_found_skip = False
+ try:
+ filter_in = yaml.safe_load(kitem)
+ except Exception:
+ raise DbException(
+ "Wrong format at '{}'. Expecting '$<yaml-format>'".format(
+ ":".join(_key_list)
+ )
+ )
+ if isinstance(filter_in, dict):
+ for index, item in enumerate(array_to_change):
+ for filter_k, filter_v in filter_in.items():
+ if (
+ not isinstance(item, dict)
+ or filter_k not in item
+ or item[filter_k] != filter_v
+ ):
+ break
+ else: # match found
+ if insert:
+ # match_found_skip = True
+ insert = False
+ break
+ else:
+ indexes.append(index)
+ else:
+ index = 0
+ try:
+ while True: # if not match a ValueError exception will be raise
+ index = array_to_change.index(filter_in, index)
+ if insert:
+ # match_found_skip = True
+ insert = False
+ break
+ indexes.append(index)
+ index += 1
+ except ValueError:
+ pass
+
+ # if match_found_skip:
+ # continue
+ elif not insert:
+ raise DbException(
+ "Wrong format at '{}'. Expecting '$+', '$[<index>]' or '$[<filter>]'".format(
+ ":".join(_key_list)
+ )
+ )
+ for index in indexes:
+ if insert:
+ if (
+ index in to_insert_at_index
+ and to_insert_at_index[index] != _dict_reference[k]
+ ):
+ # Several different insertions on the same item of the array
+ raise DbException(
+ "Conflict at '{}'. Several insertions on same array index {}".format(
+ ":".join(_key_list), index
+ )
+ )
+ to_insert_at_index[index] = _dict_reference[k]
+ else:
+ if (
+ index in indexes_to_edit_delete
+ and values_to_edit_delete[index] != _dict_reference[k]
+ ):
+ # Several different editions on the same item of the array
+ raise DbException(
+ "Conflict at '{}'. Several editions on array index {}".format(
+ ":".join(_key_list), index
+ )
+ )
+ indexes_to_edit_delete.append(index)
+ values_to_edit_delete[index] = _dict_reference[k]
+ if not indexes:
+ if insert:
+ to_append[k] = _dict_reference[k]
+ # elif _dict_reference[k] is not None:
+ # raise DbException("Not found any match to edit in the array, or wrong format at '{}'".format(
+ # ":".join(_key_list)))
+
+ # edition/deletion is done before insertion
+ indexes_to_edit_delete.sort(reverse=True)
+ for index in indexes_to_edit_delete:
+ _key_list[-1] = str(index)
+ try:
+ if values_to_edit_delete[index] is None: # None->Anything
+ try:
+ del array_to_change[index]
+ except IndexError:
+ pass # it is not consider an error if this index does not exist
+ elif not isinstance(
+ values_to_edit_delete[index], dict
+ ): # NotDict->Anything
+ array_to_change[index] = deepcopy(values_to_edit_delete[index])
+ elif isinstance(array_to_change[index], dict): # Dict->Dict
+ deep_update_rfc7396(
+ array_to_change[index], values_to_edit_delete[index], _key_list
+ )
+ else: # Dict->NotDict
+ if isinstance(
+ array_to_change[index], list
+ ): # Dict->List. Check extra array edition
+ if _deep_update_array(
+ array_to_change[index],
+ values_to_edit_delete[index],
+ _key_list,
+ ):
+ continue
+ array_to_change[index] = deepcopy(values_to_edit_delete[index])
+ # calling deep_update_rfc7396 to delete the None values
+ deep_update_rfc7396(
+ array_to_change[index], values_to_edit_delete[index], _key_list
+ )
+ except IndexError:
+ raise DbException(
+ "Array edition index out of range at '{}'".format(
+ ":".join(_key_list)
+ )
+ )
+
+ # insertion with indexes
+ to_insert_indexes = list(to_insert_at_index.keys())
+ to_insert_indexes.sort(reverse=True)
+ for index in to_insert_indexes:
+ array_to_change.insert(index, to_insert_at_index[index])
+
+ # append
+ for k, insert_value in to_append.items():
+ _key_list[-1] = str(k)
+ insert_value_copy = deepcopy(insert_value)
+ if isinstance(insert_value_copy, dict):
+ # calling deep_update_rfc7396 to delete the None values
+ deep_update_rfc7396(insert_value_copy, insert_value, _key_list)
+ array_to_change.append(insert_value_copy)
+
+ _key_list.pop()
+ if array_edition:
+ return True
+ return False
+
+ if key_list is None:
+ key_list = []
+ key_list.append("")
+ for k in dict_reference:
+ key_list[-1] = str(k)
+ if dict_reference[k] is None: # None->Anything
+ if k in dict_to_change:
+ del dict_to_change[k]
+ elif not isinstance(dict_reference[k], dict): # NotDict->Anything
+ dict_to_change[k] = deepcopy(dict_reference[k])
+ elif k not in dict_to_change: # Dict->Empty
+ dict_to_change[k] = deepcopy(dict_reference[k])
+ # calling deep_update_rfc7396 to delete the None values
+ deep_update_rfc7396(dict_to_change[k], dict_reference[k], key_list)
+ elif isinstance(dict_to_change[k], dict): # Dict->Dict
+ deep_update_rfc7396(dict_to_change[k], dict_reference[k], key_list)
+ else: # Dict->NotDict
+ if isinstance(
+ dict_to_change[k], list
+ ): # Dict->List. Check extra array edition
+ if _deep_update_array(dict_to_change[k], dict_reference[k], key_list):
+ continue
+ dict_to_change[k] = deepcopy(dict_reference[k])
+ # calling deep_update_rfc7396 to delete the None values
+ deep_update_rfc7396(dict_to_change[k], dict_reference[k], key_list)
+ key_list.pop()
+
+
+def deep_update(dict_to_change, dict_reference):
+ """Maintained for backward compatibility. Use deep_update_rfc7396 instead"""
+ return deep_update_rfc7396(dict_to_change, dict_reference)
+
+
+class Encryption(DbBase):
+ def __init__(self, uri, config, encoding_type="ascii", logger_name="db"):
+ """Constructor.
+
+ Args:
+ uri (str): Connection string to connect to the database.
+ config (dict): Additional database info
+ encoding_type (str): ascii, utf-8 etc.
+ logger_name (str): Logger name
+
+ """
+ self._secret_key = None # 32 bytes length array used for encrypt/decrypt
+ self.encrypt_mode = AES.MODE_ECB
+ super(Encryption, self).__init__(
+ encoding_type=encoding_type, logger_name=logger_name
+ )
+ self._client = AsyncIOMotorClient(uri)
+ self._config = config
+
+ @property
+ def secret_key(self):
+ return self._secret_key
+
+ @secret_key.setter
+ def secret_key(self, value):
+ self._secret_key = value
+
+ @property
+ def _database(self):
+ return self._client[DB_NAME]
+
+ @property
+ def _admin_collection(self):
+ return self._database["admin"]
+
+ @property
+ def database_key(self):
+ return self._config.get("database_commonkey")
+
+ async def decrypt_fields(
+ self,
+ item: dict,
+ fields: typing.List[str],
+ schema_version: str = None,
+ salt: str = None,
+ ) -> None:
+ """Decrypt fields from a dictionary. Follows the same logic as in osm_common.
+
+ Args:
+
+ item (dict): Dictionary with the keys to be decrypted
+ fields (list): List of keys to decrypt
+ schema_version (str): Schema version (e.g. "1.1")
+ salt (str): Salt for the decryption
+
+ """
+ flags = re.I
+
+ async def process(_item):
+ if isinstance(_item, list):
+ for elem in _item:
+ await process(elem)
+ elif isinstance(_item, dict):
+ for key, val in _item.items():
+ if isinstance(val, str):
+ if any(re.search(f, key, flags) for f in fields):
+ _item[key] = await self.decrypt(val, schema_version, salt)
+ else:
+ await process(val)
+
+ await process(item)
+
    async def encrypt(
        self, value: str, schema_version: str = None, salt: str = None
    ) -> str:
        """Encrypt a value.

        Args:
            value (str): value to be encrypted. It is string/unicode
            schema_version (str): used for version control. If None or '1.0' no encryption is done.
                If '1.1' symmetric AES encryption is done
            salt (str): optional salt to be used. Must be str

        Returns:
            Encrypted content of value (str)

        """
        # Lazily derive/fetch the secret key; no-op if already cached.
        await self.get_secret_key()
        return self._encrypt_value(value, schema_version, salt)
+
    async def decrypt(
        self, value: str, schema_version: str = None, salt: str = None
    ) -> str:
        """Decrypt an encrypted value.

        Args:
            value (str): value to be decrypted. It is a base64 string
            schema_version (str): used for known encryption method used.
                If None or '1.0' no encryption has been done.
                If '1.1' symmetric AES encryption has been done
            salt (str): optional salt to be used

        Returns:
            Plain content of value (str)

        """
        # Lazily derive/fetch the secret key; no-op if already cached.
        await self.get_secret_key()
        return self._decrypt_value(value, schema_version, salt)
+
    def _join_secret_key(self, update_key: typing.Any) -> bytes:
        """Join key with the currently cached secret key.

        Args:

            update_key (str or bytes): str or bytes with the key to fold in

        Returns:

            Joined key (bytes)
        """
        # Delegates to _join_keys using the cached secret key as the base.
        return self._join_keys(update_key, self.secret_key)
+
+ def _join_keys(self, key: typing.Any, secret_key: bytes) -> bytes:
+ """Join key with secret_key.
+
+ Args:
+
+ key (str or bytes): str or bytes of the key to update
+ secret_key (bytes): bytes of the secret key
+
+ Returns:
+
+ Joined key (bytes)
+ """
+ if isinstance(key, str):
+ update_key_bytes = key.encode(self.encoding_type)
+ else:
+ update_key_bytes = key
+ new_secret_key = bytearray(secret_key) if secret_key else bytearray(32)
+ for i, b in enumerate(update_key_bytes):
+ new_secret_key[i % 32] ^= b
+ return bytes(new_secret_key)
+
    async def get_secret_key(self):
        """Get secret key using the database key and the serial key in the DB.
        The key is populated in the property self.secret_key.
        """
        # Already derived: nothing to do.
        if self.secret_key:
            return
        secret_key = None
        if self.database_key:
            secret_key = self._join_keys(self.database_key, None)
        # Fold the base64 "serial" from the admin version document on top, so the
        # effective key combines the configured common key and the DB serial.
        version_data = await self._admin_collection.find_one({"_id": "version"})
        if version_data and version_data.get("serial"):
            secret_key = self._join_keys(b64decode(version_data["serial"]), secret_key)
        self._secret_key = secret_key
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+from http import HTTPStatus
+import logging
+from uuid import uuid4
+
+from osm_common.dbbase import DbBase, DbException
+from osm_common.dbmongo import deep_update
+
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+
+
class DbMemory(DbBase):
    """In-memory implementation of the common DB interface (tests / standalone use)."""

    def __init__(self, logger_name="db", lock=False):
        super().__init__(logger_name=logger_name, lock=lock)
        # Tables are plain lists of dicts, keyed by table/collection name.
        self.db = {}
+
    def db_connect(self, config):
        """
        Connect to database
        :param config: Configuration of database; may contain 'logger_name' and
               'commonkey' (or legacy 'masterpassword') for encryption
        :return: None or raises DbException on error
        """
        if "logger_name" in config:
            self.logger = logging.getLogger(config["logger_name"])
        # 'commonkey' is the current name; 'masterpassword' kept for backward compat.
        master_key = config.get("commonkey") or config.get("masterpassword")
        if master_key:
            self.set_secret_key(master_key)
+
+ @staticmethod
+ def _format_filter(q_filter):
+ db_filter = {}
+ # split keys with ANYINDEX in this way:
+ # {"A.B.ANYINDEX.C.D.ANYINDEX.E": v } -> {"A.B.ANYINDEX": {"C.D.ANYINDEX": {"E": v}}}
+ if q_filter:
+ for k, v in q_filter.items():
+ db_v = v
+ kleft, _, kright = k.rpartition(".ANYINDEX.")
+ while kleft:
+ k = kleft + ".ANYINDEX"
+ db_v = {kright: db_v}
+ kleft, _, kright = k.rpartition(".ANYINDEX.")
+ deep_update(db_filter, {k: db_v})
+
+ return db_filter
+
    def _find(self, table, q_filter):
        """Yield (index, row) for every row of self.db[table] matching q_filter.

        Filter keys are dot-separated paths whose last element may be an operator
        suffix (eq/ne/neq/gt/gte/lt/lte/cont/ncont); the default operator is 'eq'.
        """

        def recursive_find(key_list, key_next_index, content, oper, target):
            # End of the key path (or a dead end): compare content against target.
            if key_next_index == len(key_list) or content is None:
                try:
                    if oper in ("eq", "cont"):
                        if isinstance(target, list):
                            if isinstance(content, list):
                                return any(
                                    content_item in target for content_item in content
                                )
                            return content in target
                        elif isinstance(content, list):
                            return target in content
                        else:
                            return content == target
                    elif oper in ("neq", "ne", "ncont"):
                        if isinstance(target, list):
                            if isinstance(content, list):
                                return all(
                                    content_item not in target
                                    for content_item in content
                                )
                            return content not in target
                        elif isinstance(content, list):
                            return target not in content
                        else:
                            return content != target
                    if oper == "gt":
                        return content > target
                    elif oper == "gte":
                        return content >= target
                    elif oper == "lt":
                        return content < target
                    elif oper == "lte":
                        return content <= target
                    else:
                        raise DbException(
                            "Unknown filter operator '{}' in key '{}'".format(
                                oper, ".".join(key_list)
                            ),
                            http_code=HTTPStatus.BAD_REQUEST,
                        )
                except TypeError:
                    # Incomparable types (e.g. None > 3) simply do not match.
                    return False

            elif isinstance(content, dict):
                # Descend one level following the next key-path element.
                return recursive_find(
                    key_list,
                    key_next_index + 1,
                    content.get(key_list[key_next_index]),
                    oper,
                    target,
                )
            elif isinstance(content, list):
                look_for_match = True  # when there is a match return immediately
                if (target is None) != (
                    oper in ("neq", "ne", "ncont")
                ):  # one True and other False (Xor)
                    look_for_match = (
                        False  # when there is not a match return immediately
                    )

                for content_item in content:
                    # NOTE(review): 'v' here is the outer filter value captured by
                    # closure; at the top-level call it equals 'target', but in nested
                    # calls they may differ — confirm 'target' was not intended.
                    if key_list[key_next_index] == "ANYINDEX" and isinstance(v, dict):
                        # ANYINDEX: all sub-conditions must hold on the SAME element.
                        matches = True
                        if target:
                            for k2, v2 in target.items():
                                k_new_list = k2.split(".")
                                new_operator = "eq"
                                if k_new_list[-1] in (
                                    "eq",
                                    "ne",
                                    "gt",
                                    "gte",
                                    "lt",
                                    "lte",
                                    "cont",
                                    "ncont",
                                    "neq",
                                ):
                                    new_operator = k_new_list.pop()
                                if not recursive_find(
                                    k_new_list, 0, content_item, new_operator, v2
                                ):
                                    matches = False
                                    break

                        else:
                            matches = recursive_find(
                                key_list, key_next_index, content_item, oper, target
                            )
                        if matches == look_for_match:
                            return matches
                # A decimal path element addresses a concrete list index.
                if key_list[key_next_index].isdecimal() and int(
                    key_list[key_next_index]
                ) < len(content):
                    matches = recursive_find(
                        key_list,
                        key_next_index + 1,
                        content[int(key_list[key_next_index])],
                        oper,
                        target,
                    )
                    if matches == look_for_match:
                        return matches
                return not look_for_match
            else:  # content is not dict, nor list neither None, so not found
                if oper in ("neq", "ne", "ncont"):
                    return target is not None
                else:
                    return target is None

        for i, row in enumerate(self.db.get(table, ())):
            q_filter = q_filter or {}
            for k, v in q_filter.items():
                k_list = k.split(".")
                operator = "eq"
                if k_list[-1] in (
                    "eq",
                    "ne",
                    "gt",
                    "gte",
                    "lt",
                    "lte",
                    "cont",
                    "ncont",
                    "neq",
                ):
                    operator = k_list.pop()
                matches = recursive_find(k_list, 0, row, operator, v)
                if not matches:
                    break
            else:
                # match
                yield i, row
+
+ def get_list(self, table, q_filter=None):
+ """
+ Obtain a list of entries matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: a list (can be empty) with the found entries. Raises DbException on error
+ """
+ try:
+ result = []
+ with self.lock:
+ for _, row in self._find(table, self._format_filter(q_filter)):
+ result.append(deepcopy(row))
+ return result
+ except DbException:
+ raise
+ except Exception as e: # TODO refine
+ raise DbException(str(e))
+
+ def count(self, table, q_filter=None):
+ """
+ Count the number of entries matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: number of entries found (can be zero)
+ :raise: DbException on error
+ """
+ try:
+ with self.lock:
+ return sum(1 for x in self._find(table, self._format_filter(q_filter)))
+ except DbException:
+ raise
+ except Exception as e: # TODO refine
+ raise DbException(str(e))
+
+ def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
+ """
+ Obtain one entry matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
+ it raises a DbException
+ :param fail_on_more: If more than one matches filter it returns one of then unless this flag is set tu True, so
+ that it raises a DbException
+ :return: The requested element, or None
+ """
+ try:
+ result = None
+ with self.lock:
+ for _, row in self._find(table, self._format_filter(q_filter)):
+ if not fail_on_more:
+ return deepcopy(row)
+ if result:
+ raise DbException(
+ "Found more than one entry with filter='{}'".format(
+ q_filter
+ ),
+ HTTPStatus.CONFLICT.value,
+ )
+ result = row
+ if not result and fail_on_empty:
+ raise DbException(
+ "Not found entry with filter='{}'".format(q_filter),
+ HTTPStatus.NOT_FOUND,
+ )
+ return deepcopy(result)
+ except Exception as e: # TODO refine
+ raise DbException(str(e))
+
+ def del_list(self, table, q_filter=None):
+ """
+ Deletes all entries that match q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: Dict with the number of entries deleted
+ """
+ try:
+ id_list = []
+ with self.lock:
+ for i, _ in self._find(table, self._format_filter(q_filter)):
+ id_list.append(i)
+ deleted = len(id_list)
+ for i in reversed(id_list):
+ del self.db[table][i]
+ return {"deleted": deleted}
+ except DbException:
+ raise
+ except Exception as e: # TODO refine
+ raise DbException(str(e))
+
+ def del_one(self, table, q_filter=None, fail_on_empty=True):
+ """
+ Deletes one entry that matches q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :param fail_on_empty: If nothing matches filter it returns '0' deleted unless this flag is set tu True, in
+ which case it raises a DbException
+ :return: Dict with the number of entries deleted
+ """
+ try:
+ with self.lock:
+ for i, _ in self._find(table, self._format_filter(q_filter)):
+ break
+ else:
+ if fail_on_empty:
+ raise DbException(
+ "Not found entry with filter='{}'".format(q_filter),
+ HTTPStatus.NOT_FOUND,
+ )
+ return None
+ del self.db[table][i]
+ return {"deleted": 1}
+ except Exception as e: # TODO refine
+ raise DbException(str(e))
+
    def _update(
        self,
        db_item,
        update_dict,
        unset=None,
        pull=None,
        push=None,
        push_list=None,
        pull_list=None,
    ):
        """
        Modifies an entry at database
        :param db_item: entry of the table to update
        :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
               ignored. If not exist, it is ignored
        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
               if exist in the array is removed. If not exist, it is ignored
        :param pull_list: Same as pull but values are arrays where each item is removed from the array
        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
               is appended to the end of the array
        :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
               whole array
        :return: True if database has been changed, False if not; Exception on error
        """

        def _iterate_keys(k, db_nested, populate=True):
            # Walk the dot-separated path k through db_nested, optionally creating
            # missing levels; returns (parent container, final key, created_flag).
            k_list = k.split(".")
            k_item_prev = k_list[0]
            populated = False
            if k_item_prev not in db_nested and populate:
                populated = True
                db_nested[k_item_prev] = None
            for k_item in k_list[1:]:
                if isinstance(db_nested[k_item_prev], dict):
                    if k_item not in db_nested[k_item_prev]:
                        if not populate:
                            raise DbException(
                                "Cannot set '{}', not existing '{}'".format(k, k_item)
                            )
                        populated = True
                        db_nested[k_item_prev][k_item] = None
                elif isinstance(db_nested[k_item_prev], list) and k_item.isdigit():
                    # extend list with Nones if index greater than list
                    k_item = int(k_item)
                    if k_item >= len(db_nested[k_item_prev]):
                        if not populate:
                            raise DbException(
                                "Cannot set '{}', index too large '{}'".format(
                                    k, k_item
                                )
                            )
                        populated = True
                        db_nested[k_item_prev] += [None] * (
                            k_item - len(db_nested[k_item_prev]) + 1
                        )
                elif db_nested[k_item_prev] is None:
                    if not populate:
                        raise DbException(
                            "Cannot set '{}', not existing '{}'".format(k, k_item)
                        )
                    populated = True
                    db_nested[k_item_prev] = {k_item: None}
                else:  # number, string, boolean, ... or list but with not integer key
                    raise DbException(
                        "Cannot set '{}' on existing '{}={}'".format(
                            k, k_item_prev, db_nested[k_item_prev]
                        )
                    )
                db_nested = db_nested[k_item_prev]
                k_item_prev = k_item
            return db_nested, k_item_prev, populated

        updated = False
        try:
            if update_dict:
                for dot_k, v in update_dict.items():
                    dict_to_update, key_to_update, _ = _iterate_keys(dot_k, db_item)
                    dict_to_update[key_to_update] = v
                    updated = True
            if unset:
                for dot_k in unset:
                    try:
                        dict_to_update, key_to_update, _ = _iterate_keys(
                            dot_k, db_item, populate=False
                        )
                        del dict_to_update[key_to_update]
                        updated = True
                    except Exception as unset_error:
                        # Missing keys are logged but ignored (unset is best-effort).
                        self.logger.error(f"{unset_error} occured while updating DB.")
            if pull:
                for dot_k, v in pull.items():
                    try:
                        dict_to_update, key_to_update, _ = _iterate_keys(
                            dot_k, db_item, populate=False
                        )
                    except Exception as pull_error:
                        self.logger.error(f"{pull_error} occured while updating DB.")
                        continue

                    if key_to_update not in dict_to_update:
                        continue
                    if not isinstance(dict_to_update[key_to_update], list):
                        raise DbException(
                            "Cannot pull '{}'. Target is not a list".format(dot_k)
                        )
                    # Remove EVERY occurrence of the value, not just the first.
                    while v in dict_to_update[key_to_update]:
                        dict_to_update[key_to_update].remove(v)
                        updated = True
            if pull_list:
                for dot_k, v in pull_list.items():
                    if not isinstance(v, list):
                        raise DbException(
                            "Invalid content at pull_list, '{}' must be an array".format(
                                dot_k
                            ),
                            http_code=HTTPStatus.BAD_REQUEST,
                        )
                    try:
                        dict_to_update, key_to_update, _ = _iterate_keys(
                            dot_k, db_item, populate=False
                        )
                    except Exception as iterate_error:
                        self.logger.error(
                            f"{iterate_error} occured while iterating keys in db update."
                        )
                        continue

                    if key_to_update not in dict_to_update:
                        continue
                    if not isinstance(dict_to_update[key_to_update], list):
                        raise DbException(
                            "Cannot pull_list '{}'. Target is not a list".format(dot_k)
                        )
                    for single_v in v:
                        while single_v in dict_to_update[key_to_update]:
                            dict_to_update[key_to_update].remove(single_v)
                            updated = True
            if push:
                for dot_k, v in push.items():
                    dict_to_update, key_to_update, populated = _iterate_keys(
                        dot_k, db_item
                    )
                    if (
                        isinstance(dict_to_update, dict)
                        and key_to_update not in dict_to_update
                    ):
                        # Target absent: create a one-element list.
                        dict_to_update[key_to_update] = [v]
                        updated = True
                    elif populated and dict_to_update[key_to_update] is None:
                        dict_to_update[key_to_update] = [v]
                        updated = True
                    elif not isinstance(dict_to_update[key_to_update], list):
                        raise DbException(
                            "Cannot push '{}'. Target is not a list".format(dot_k)
                        )
                    else:
                        dict_to_update[key_to_update].append(v)
                        updated = True
            if push_list:
                for dot_k, v in push_list.items():
                    if not isinstance(v, list):
                        raise DbException(
                            "Invalid content at push_list, '{}' must be an array".format(
                                dot_k
                            ),
                            http_code=HTTPStatus.BAD_REQUEST,
                        )
                    dict_to_update, key_to_update, populated = _iterate_keys(
                        dot_k, db_item
                    )
                    if (
                        isinstance(dict_to_update, dict)
                        and key_to_update not in dict_to_update
                    ):
                        # Copy so later mutation of 'v' does not alter stored data.
                        dict_to_update[key_to_update] = v.copy()
                        updated = True
                    elif populated and dict_to_update[key_to_update] is None:
                        dict_to_update[key_to_update] = v.copy()
                        updated = True
                    elif not isinstance(dict_to_update[key_to_update], list):
                        raise DbException(
                            "Cannot push '{}'. Target is not a list".format(dot_k),
                            http_code=HTTPStatus.CONFLICT,
                        )
                    else:
                        dict_to_update[key_to_update] += v
                        updated = True

            return updated
        except DbException:
            raise
        except Exception as e:  # TODO refine
            raise DbException(str(e))
+
    def set_one(
        self,
        table,
        q_filter,
        update_dict,
        fail_on_empty=True,
        unset=None,
        pull=None,
        push=None,
        push_list=None,
        pull_list=None,
    ):
        """
        Modifies an entry at database
        :param table: collection or table
        :param q_filter: Filter
        :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
        :param fail_on_empty: If nothing matches filter it returns None unless this flag is set to True, in which case
               it raises a DbException
        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
               ignored. If not exist, it is ignored
        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and value
               if exist in the array is removed. If not exist, it is ignored
        :param pull_list: Same as pull but values are arrays where each item is removed from the array
        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and value
               is appended to the end of the array
        :param push_list: Same as push but values are arrays where each item is and appended instead of appending the
               whole array
        :return: Dict with the number of entries modified. None if no matching is found.
        """
        with self.lock:
            # Only the FIRST matching row is updated (the return exits the loop);
            # use set_list to update every match.
            for i, db_item in self._find(table, self._format_filter(q_filter)):
                updated = self._update(
                    db_item,
                    update_dict,
                    unset=unset,
                    pull=pull,
                    push=push,
                    push_list=push_list,
                    pull_list=pull_list,
                )
                return {"updated": 1 if updated else 0}
            else:
                # for/else: reached only when no row matched the filter.
                if fail_on_empty:
                    raise DbException(
                        "Not found entry with _id='{}'".format(q_filter),
                        HTTPStatus.NOT_FOUND,
                    )
                return None
+
    def set_list(
        self,
        table,
        q_filter,
        update_dict,
        unset=None,
        pull=None,
        push=None,
        push_list=None,
        pull_list=None,
    ):
        """Modifies all matching entries at database. Same params as set_one. Does not fail if nothing matches."""
        with self.lock:
            updated = 0
            found = 0
            for _, db_item in self._find(table, self._format_filter(q_filter)):
                found += 1
                if self._update(
                    db_item,
                    update_dict,
                    unset=unset,
                    pull=pull,
                    push=push,
                    push_list=push_list,
                    pull_list=pull_list,
                ):
                    updated += 1
            # if not found and fail_on_empty:
            #     raise DbException("Not found entry with '{}'".format(q_filter), HTTPStatus.NOT_FOUND)
            # None signals "nothing matched"; {"updated": 0} means matched-but-unchanged.
            return {"updated": updated} if found else None
+
+ def replace(self, table, _id, indata, fail_on_empty=True):
+ """
+ Replace the content of an entry
+ :param table: collection or table
+ :param _id: internal database id
+ :param indata: content to replace
+ :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
+ it raises a DbException
+ :return: Dict with the number of entries replaced
+ """
+ try:
+ with self.lock:
+ for i, _ in self._find(table, self._format_filter({"_id": _id})):
+ break
+ else:
+ if fail_on_empty:
+ raise DbException(
+ "Not found entry with _id='{}'".format(_id),
+ HTTPStatus.NOT_FOUND,
+ )
+ return None
+ self.db[table][i] = deepcopy(indata)
+ return {"updated": 1}
+ except DbException:
+ raise
+ except Exception as e: # TODO refine
+ raise DbException(str(e))
+
+ def create(self, table, indata):
+ """
+ Add a new entry at database
+ :param table: collection or table
+ :param indata: content to be added
+ :return: database '_id' of the inserted element. Raises a DbException on error
+ """
+ try:
+ id = indata.get("_id")
+ if not id:
+ id = str(uuid4())
+ indata["_id"] = id
+ with self.lock:
+ if table not in self.db:
+ self.db[table] = []
+ self.db[table].append(deepcopy(indata))
+ return id
+ except Exception as e: # TODO refine
+ raise DbException(str(e))
+
+ def create_list(self, table, indata_list):
+ """
+ Add a new entry at database
+ :param table: collection or table
+ :param indata_list: list content to be added
+ :return: list of inserted 'id's. Raises a DbException on error
+ """
+ try:
+ _ids = []
+ with self.lock:
+ for indata in indata_list:
+ _id = indata.get("_id")
+ if not _id:
+ _id = str(uuid4())
+ indata["_id"] = _id
+ with self.lock:
+ if table not in self.db:
+ self.db[table] = []
+ self.db[table].append(deepcopy(indata))
+ _ids.append(_id)
+ return _ids
+ except Exception as e: # TODO refine
+ raise DbException(str(e))
+
+
if __name__ == "__main__":
    # Quick manual smoke test of the in-memory backend.
    db = DbMemory()
    db.create("test", {"_id": 1, "data": 1})
    db.create("test", {"_id": 2, "data": 2})
    db.create("test", {"_id": 3, "data": 3})
    print("must be 3 items:", db.get_list("test"))
    print("must return item 2:", db.get_list("test", {"_id": 2}))
    db.del_one("test", {"_id": 2})
    # BUGFIX: message typo "emtpy" -> "empty"
    print("must be empty:", db.get_list("test", {"_id": 2}))
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from base64 import b64decode
+from copy import deepcopy
+from http import HTTPStatus
+import logging
+from time import sleep, time
+from uuid import uuid4
+
+from osm_common.dbbase import DbBase, DbException
+from pymongo import errors, MongoClient
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+
+# TODO consider use this decorator for database access retries
+# @retry_mongocall
+# def retry_mongocall(call):
+# def _retry_mongocall(*args, **kwargs):
+# retry = 1
+# while True:
+# try:
+# return call(*args, **kwargs)
+# except pymongo.AutoReconnect as e:
+# if retry == 4:
+# raise DbException(e)
+# sleep(retry)
+# return _retry_mongocall
+
+
def deep_update(to_update, update_with):
    """
    Recursively merge 'update_with' into 'to_update': nested dicts are merged,
    any other value is deep-copied over. 'update_with' is never modified.
    :param to_update: must be a dictionary to be modified
    :param update_with: must be a dictionary. It is not changed
    :return: to_update
    """
    for key, new_value in update_with.items():
        existing = to_update.get(key)
        if isinstance(existing, dict) and isinstance(new_value, dict):
            # Both sides are dicts: merge in place, one level deeper.
            deep_update(existing, new_value)
        else:
            # Deep-copy so later mutation of update_with cannot leak through.
            to_update[key] = deepcopy(new_value)
    return to_update
+
+
class DbMongo(DbBase):
    """MongoDB implementation of the common DB interface (via pymongo)."""

    # Seconds to keep retrying the very first connection / later operations.
    conn_initial_timout = 120
    conn_timout = 10

    def __init__(self, logger_name="db", lock=False):
        super().__init__(logger_name=logger_name, lock=lock)
        self.client = None  # MongoClient, set in db_connect
        self.db = None  # pymongo Database handle, set in db_connect
        self.database_key = None  # common key from config, if provided
        self.secret_obtained = False
        # ^ This is used to know if database serial has been got. Database is inited by NBI, who generates the serial
        # In case it is not ready when connected, it should be got later on before any decrypt operation
+
    def get_secret_key(self):
        """Derive the secret key from the configured database key plus the
        base64 'serial' stored in the admin/version document; idempotent."""
        if self.secret_obtained:
            return

        self.secret_key = None
        if self.database_key:
            self.set_secret_key(self.database_key)
        # The serial may not exist yet (NBI creates it); retry on next call if absent.
        version_data = self.get_one(
            "admin", {"_id": "version"}, fail_on_empty=False, fail_on_more=True
        )
        if version_data and version_data.get("serial"):
            self.set_secret_key(b64decode(version_data["serial"]))
        self.secret_obtained = True
+
    def db_connect(self, config, target_version=None):
        """
        Connect to database
        :param config: Configuration of database; keys used: 'uri' or 'host'/'port',
               'name', optional 'replicaset', 'logger_name', 'loglevel',
               'commonkey'/'masterpassword'
        :param target_version: if provided it checks if database contains required version, raising exception otherwise.
        :return: None or raises DbException on error
        """
        try:
            if "logger_name" in config:
                self.logger = logging.getLogger(config["logger_name"])
            master_key = config.get("commonkey") or config.get("masterpassword")
            if master_key:
                self.database_key = master_key
                self.set_secret_key(master_key)
            if config.get("uri"):
                self.client = MongoClient(
                    config["uri"], replicaSet=config.get("replicaset", None)
                )
            # when all modules are ready
            self.db = self.client[config["name"]]
            if "loglevel" in config:
                self.logger.setLevel(getattr(logging, config["loglevel"]))
            # get data to try a connection
            now = time()
            # Retry loop: keep probing until the DB answers or the initial
            # timeout (conn_initial_timout seconds) elapses.
            while True:
                try:
                    version_data = self.get_one(
                        "admin",
                        {"_id": "version"},
                        fail_on_empty=False,
                        fail_on_more=True,
                    )
                    # check database status is ok
                    if version_data and version_data.get("status") != "ENABLED":
                        raise DbException(
                            "Wrong database status '{}'".format(
                                version_data.get("status")
                            ),
                            http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                        )
                    # check version
                    db_version = (
                        None if not version_data else version_data.get("version")
                    )
                    if target_version and target_version != db_version:
                        raise DbException(
                            "Invalid database version {}. Expected {}".format(
                                db_version, target_version
                            )
                        )
                    # get serial
                    if version_data and version_data.get("serial"):
                        self.secret_obtained = True
                        self.set_secret_key(b64decode(version_data["serial"]))
                    self.logger.info(
                        "Connected to database {} version {}".format(
                            config["name"], db_version
                        )
                    )
                    return
                except errors.ConnectionFailure as e:
                    if time() - now >= self.conn_initial_timout:
                        raise
                    self.logger.info("Waiting to database up {}".format(e))
                    sleep(2)
        except errors.PyMongoError as e:
            raise DbException(e)
+
+ @staticmethod
+ def _format_filter(q_filter):
+ """
+ Translate query string q_filter into mongo database filter
+ :param q_filter: Query string content. Follows SOL005 section 4.3.2 guidelines, with the follow extensions and
+ differences:
+ It accept ".nq" (not equal) in addition to ".neq".
+ For arrays you can specify index (concrete index must match), nothing (any index may match) or 'ANYINDEX'
+ (two or more matches applies for the same array element). Examples:
+ with database register: {A: [{B: 1, C: 2}, {B: 6, C: 9}]}
+ query 'A.B=6' matches because array A contains one element with B equal to 6
+ query 'A.0.B=6' does no match because index 0 of array A contains B with value 1, but not 6
+ query 'A.B=6&A.C=2' matches because one element of array matches B=6 and other matchesC=2
+ query 'A.ANYINDEX.B=6&A.ANYINDEX.C=2' does not match because it is needed the same element of the
+ array matching both
+
+ Examples of translations from SOL005 to >> mongo # comment
+ A=B; A.eq=B >> A: B # must contain key A and equal to B or be a list that contains B
+ A.cont=B >> A: B
+ A=B&A=C; A=B,C >> A: {$in: [B, C]} # must contain key A and equal to B or C or be a list that contains
+ # B or C
+ A.cont=B&A.cont=C; A.cont=B,C >> A: {$in: [B, C]}
+ A.ncont=B >> A: {$nin: B} # must not contain key A or if present not equal to B or if a list,
+ # it must not not contain B
+ A.ncont=B,C; A.ncont=B&A.ncont=C >> A: {$nin: [B,C]} # must not contain key A or if present not equal
+ # neither B nor C; or if a list, it must not contain neither B nor C
+ A.ne=B&A.ne=C; A.ne=B,C >> A: {$nin: [B, C]}
+ A.gt=B >> A: {$gt: B} # must contain key A and greater than B
+ A.ne=B; A.neq=B >> A: {$ne: B} # must not contain key A or if present not equal to B, or if
+ # an array not contain B
+ A.ANYINDEX.B=C >> A: {$elemMatch: {B=C}
+ :return: database mongo filter
+ """
+ try:
+ db_filter = {}
+ if not q_filter:
+ return db_filter
+ for query_k, query_v in q_filter.items():
+ dot_index = query_k.rfind(".")
+ if dot_index > 1 and query_k[dot_index + 1 :] in (
+ "eq",
+ "ne",
+ "gt",
+ "gte",
+ "lt",
+ "lte",
+ "cont",
+ "ncont",
+ "neq",
+ ):
+ operator = "$" + query_k[dot_index + 1 :]
+ if operator == "$neq":
+ operator = "$ne"
+ k = query_k[:dot_index]
+ else:
+ operator = "$eq"
+ k = query_k
+
+ v = query_v
+ if isinstance(v, list):
+ if operator in ("$eq", "$cont"):
+ operator = "$in"
+ v = query_v
+ elif operator in ("$ne", "$ncont"):
+ operator = "$nin"
+ v = query_v
+ else:
+ v = query_v.join(",")
+
+ if operator in ("$eq", "$cont"):
+ # v cannot be a comma separated list, because operator would have been changed to $in
+ db_v = v
+ elif operator == "$ncount":
+ # v cannot be a comma separated list, because operator would have been changed to $nin
+ db_v = {"$ne": v}
+ else:
+ db_v = {operator: v}
+
+ # process the ANYINDEX word at k.
+ kleft, _, kright = k.rpartition(".ANYINDEX.")
+ while kleft:
+ k = kleft
+ db_v = {"$elemMatch": {kright: db_v}}
+ kleft, _, kright = k.rpartition(".ANYINDEX.")
+
+ # insert in db_filter
+ # maybe db_filter[k] exist. e.g. in the query string for values between 5 and 8: "a.gt=5&a.lt=8"
+ deep_update(db_filter, {k: db_v})
+
+ return db_filter
+ except Exception as e:
+ raise DbException(
+ "Invalid query string filter at {}:{}. Error: {}".format(query_k, v, e),
+ http_code=HTTPStatus.BAD_REQUEST,
+ )
+
+ def get_list(self, table, q_filter=None):
+ """
+ Obtain a list of entries matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: a list (can be empty) with the found entries. Raises DbException on error
+ """
+ try:
+ result = []
+ with self.lock:
+ collection = self.db[table]
+ db_filter = self._format_filter(q_filter)
+ rows = collection.find(db_filter)
+ for row in rows:
+ result.append(row)
+ return result
+ except DbException:
+ raise
+ except Exception as e: # TODO refine
+ raise DbException(e)
+
+ def count(self, table, q_filter=None):
+ """
+ Count the number of entries matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: number of entries found (can be zero)
+ :raise: DbException on error
+ """
+ try:
+ with self.lock:
+ collection = self.db[table]
+ db_filter = self._format_filter(q_filter)
+ count = collection.count_documents(db_filter)
+ return count
+ except DbException:
+ raise
+ except Exception as e: # TODO refine
+ raise DbException(e)
+
+ def get_one(self, table, q_filter=None, fail_on_empty=True, fail_on_more=True):
+ """
+ Obtain one entry matching q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :param fail_on_empty: If nothing matches filter it returns None unless this flag is set tu True, in which case
+ it raises a DbException
+ :param fail_on_more: If more than one matches filter it returns one of then unless this flag is set tu True, so
+ that it raises a DbException
+ :return: The requested element, or None
+ """
+ try:
+ db_filter = self._format_filter(q_filter)
+ with self.lock:
+ collection = self.db[table]
+ if not (fail_on_empty and fail_on_more):
+ return collection.find_one(db_filter)
+ rows = list(collection.find(db_filter))
+ if len(rows) == 0:
+ if fail_on_empty:
+ raise DbException(
+ "Not found any {} with filter='{}'".format(
+ table[:-1], q_filter
+ ),
+ HTTPStatus.NOT_FOUND,
+ )
+
+ return None
+ elif len(rows) > 1:
+ if fail_on_more:
+ raise DbException(
+ "Found more than one {} with filter='{}'".format(
+ table[:-1], q_filter
+ ),
+ HTTPStatus.CONFLICT,
+ )
+ return rows[0]
+ except Exception as e: # TODO refine
+ raise DbException(e)
+
+ def del_list(self, table, q_filter=None):
+ """
+ Deletes all entries that match q_filter
+ :param table: collection or table
+ :param q_filter: Filter
+ :return: Dict with the number of entries deleted
+ """
+ try:
+ with self.lock:
+ collection = self.db[table]
+ rows = collection.delete_many(self._format_filter(q_filter))
+ return {"deleted": rows.deleted_count}
+ except DbException:
+ raise
+ except Exception as e: # TODO refine
+ raise DbException(e)
+
+    def del_one(self, table, q_filter=None, fail_on_empty=True):
+        """
+        Deletes one entry that matches q_filter
+        :param table: collection or table
+        :param q_filter: Filter
+        :param fail_on_empty: If nothing matches filter it returns '0' deleted unless this flag is set to True, in
+            which case it raises a DbException
+        :return: Dict with the number of entries deleted
+        """
+        try:
+            with self.lock:
+                collection = self.db[table]
+                rows = collection.delete_one(self._format_filter(q_filter))
+                if rows.deleted_count == 0:
+                    if fail_on_empty:
+                        raise DbException(
+                            "Not found any {} with filter='{}'".format(
+                                table[:-1], q_filter
+                            ),
+                            HTTPStatus.NOT_FOUND,
+                        )
+                    return None
+                return {"deleted": rows.deleted_count}
+        except DbException:
+            # keep the NOT_FOUND http_code instead of re-wrapping with the default 500
+            raise
+        except Exception as e:  # TODO refine
+            raise DbException(e)
+
+    def create(self, table, indata):
+        """
+        Add a new entry at database
+        :param table: collection or table
+        :param indata: content to be added
+        :return: database id of the inserted element. Raises a DbException on error
+        """
+        try:
+            with self.lock:
+                result = self.db[table].insert_one(indata)
+            return result.inserted_id
+        except Exception as e:  # TODO refine
+            raise DbException(e)
+
+    def create_list(self, table, indata_list):
+        """
+        Add several entries at once
+        :param table: collection or table
+        :param indata_list: content list to be added.
+        :return: the list of inserted '_id's. Exception on error
+        """
+        try:
+            # make sure every entry carries an '_id' before insertion
+            for entry in indata_list:
+                if entry.get("_id") is None:
+                    entry["_id"] = str(uuid4())
+            with self.lock:
+                result = self.db[table].insert_many(indata_list)
+            return result.inserted_ids
+        except Exception as e:  # TODO refine
+            raise DbException(e)
+
+    def set_one(
+        self,
+        table,
+        q_filter,
+        update_dict,
+        fail_on_empty=True,
+        unset=None,
+        pull=None,
+        push=None,
+        push_list=None,
+        pull_list=None,
+        upsert=False,
+    ):
+        """
+        Modifies an entry at database
+        :param table: collection or table
+        :param q_filter: Filter
+        :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
+        :param fail_on_empty: If nothing matches filter it returns None unless this flag is set to True, in which case
+            it raises a DbException
+        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+            ignored. If not exist, it is ignored
+        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and
+            value if exist in the array is removed. If not exist, it is ignored
+        :param pull_list: Same as pull but values are arrays where each item is removed from the array
+        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys and
+            value is appended to the end of the array
+        :param push_list: Same as push but values are arrays where each item is appended instead of appending the
+            whole array
+        :param upsert: If this parameter is set to True and no document is found using 'q_filter' it will be created.
+            By default this is false.
+        :return: Dict with the number of entries modified. None if no matching is found.
+        """
+        try:
+            # translate the plain-dict parameters into a mongo update operator document
+            db_oper = {}
+            if update_dict:
+                db_oper["$set"] = update_dict
+            if unset:
+                db_oper["$unset"] = unset
+            if pull or pull_list:
+                db_oper["$pull"] = pull or {}
+                if pull_list:
+                    db_oper["$pull"].update(
+                        {k: {"$in": v} for k, v in pull_list.items()}
+                    )
+            if push or push_list:
+                db_oper["$push"] = push or {}
+                if push_list:
+                    db_oper["$push"].update(
+                        {k: {"$each": v} for k, v in push_list.items()}
+                    )
+
+            with self.lock:
+                collection = self.db[table]
+                rows = collection.update_one(
+                    self._format_filter(q_filter), db_oper, upsert=upsert
+                )
+                if rows.matched_count == 0:
+                    if fail_on_empty:
+                        raise DbException(
+                            "Not found any {} with filter='{}'".format(
+                                table[:-1], q_filter
+                            ),
+                            HTTPStatus.NOT_FOUND,
+                        )
+                    return None
+                return {"modified": rows.modified_count}
+        except DbException:
+            # keep the NOT_FOUND http_code instead of re-wrapping with the default 500
+            raise
+        except Exception as e:  # TODO refine
+            raise DbException(e)
+
+    def set_list(
+        self,
+        table,
+        q_filter,
+        update_dict,
+        unset=None,
+        pull=None,
+        push=None,
+        push_list=None,
+        pull_list=None,
+    ):
+        """
+        Modifies all matching entries at database
+        :param table: collection or table
+        :param q_filter: Filter
+        :param update_dict: Plain dictionary with the content to be updated. It is a dot separated keys and a value
+        :param unset: Plain dictionary with the content to be removed if exist. It is a dot separated keys, value is
+            ignored. If not exist, it is ignored
+        :param pull: Plain dictionary with the content to be removed from an array. It is a dot separated keys and
+            value if exist in the array is removed. If not exist, it is ignored
+        :param push: Plain dictionary with the content to be appended to an array. It is a dot separated keys, the
+            single value is appended to the end of the array
+        :param pull_list: Same as pull but values are arrays where each item is removed from the array
+        :param push_list: Same as push but values are arrays where each item is appended instead of appending the
+            whole array
+        :return: Dict with the number of entries modified
+        """
+        try:
+            # build the mongo update operator document from the plain-dict parameters
+            db_oper = {}
+            if update_dict:
+                db_oper["$set"] = update_dict
+            if unset:
+                db_oper["$unset"] = unset
+            if pull or pull_list:
+                pull_oper = pull if pull else {}
+                for key, values in (pull_list or {}).items():
+                    pull_oper[key] = {"$in": values}
+                db_oper["$pull"] = pull_oper
+            if push or push_list:
+                push_oper = push if push else {}
+                for key, values in (push_list or {}).items():
+                    push_oper[key] = {"$each": values}
+                db_oper["$push"] = push_oper
+            with self.lock:
+                result = self.db[table].update_many(
+                    self._format_filter(q_filter), db_oper
+                )
+            return {"modified": result.modified_count}
+        except Exception as e:  # TODO refine
+            raise DbException(e)
+
+    def replace(self, table, _id, indata, fail_on_empty=True):
+        """
+        Replace the content of an entry
+        :param table: collection or table
+        :param _id: internal database id
+        :param indata: content to replace
+        :param fail_on_empty: If nothing matches filter it returns None unless this flag is set to True, in which case
+            it raises a DbException
+        :return: Dict with the number of entries replaced
+        """
+        try:
+            db_filter = {"_id": _id}
+            with self.lock:
+                collection = self.db[table]
+                rows = collection.replace_one(db_filter, indata)
+                if rows.matched_count == 0:
+                    if fail_on_empty:
+                        raise DbException(
+                            "Not found any {} with _id='{}'".format(table[:-1], _id),
+                            HTTPStatus.NOT_FOUND,
+                        )
+                    return None
+                return {"replaced": rows.modified_count}
+        except DbException:
+            # keep the NOT_FOUND http_code instead of re-wrapping with the default 500
+            raise
+        except Exception as e:  # TODO refine
+            raise DbException(e)
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from http import HTTPStatus
+import logging
+from threading import Lock
+
+from osm_common.common_utils import FakeLock
+
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+
+
+class FsException(Exception):
+    """Base exception for file-system storage errors, carrying an HTTP status code."""
+
+    def __init__(self, message, http_code=HTTPStatus.INTERNAL_SERVER_ERROR):
+        # keep the http_code so upper layers can map the error to an HTTP response
+        self.http_code = http_code
+        super().__init__("storage exception " + message)
+
+
+class FsBase(object):
+    """
+    Abstract base class for OSM file-storage backends (e.g. FsLocal, FsMongo).
+    Subclasses override the file/dir operations; the base implementations raise
+    FsException("... not implemented").
+    """
+
+    def __init__(self, logger_name="fs", lock=False):
+        """
+        Constructor of FsBase
+        :param logger_name: logging name
+        :param lock: Used to protect simultaneous access to the same instance class by several threads:
+            False, None: Do not protect, this object will only be accessed by one thread
+            True: This object needs to be protected by several threads accessing.
+            Lock object. Use this Lock for the threads access protection
+        """
+        self.logger = logging.getLogger(logger_name)
+        if not lock:
+            self.lock = FakeLock()
+        elif lock is True:
+            self.lock = Lock()
+        elif isinstance(lock, type(Lock())):
+            # threading.Lock is a factory function, not a class, so isinstance()
+            # against it raises TypeError; compare against the type of an instance
+            self.lock = lock
+        else:
+            raise ValueError("lock parameter must be a Lock class or boolean")
+
+    def get_params(self):
+        # backend-specific connection/config info; overridden by subclasses
+        return {}
+
+    def fs_connect(self, config):
+        # establish the storage connection; overridden by subclasses
+        pass
+
+    def fs_disconnect(self):
+        # release the storage connection; overridden by subclasses
+        pass
+
+    def mkdir(self, folder):
+        raise FsException("Method 'mkdir' not implemented")
+
+    def dir_rename(self, src, dst):
+        raise FsException("Method 'dir_rename' not implemented")
+
+    def dir_ls(self, storage):
+        raise FsException("Method 'dir_ls' not implemented")
+
+    def file_exists(self, storage):
+        raise FsException("Method 'file_exists' not implemented")
+
+    def file_size(self, storage):
+        raise FsException("Method 'file_size' not implemented")
+
+    def file_extract(self, tar_object, path):
+        raise FsException("Method 'file_extract' not implemented")
+
+    def file_open(self, storage, mode):
+        raise FsException("Method 'file_open' not implemented")
+
+    def file_delete(self, storage, ignore_non_exist=False):
+        raise FsException("Method 'file_delete' not implemented")
+
+    def sync(self, from_path=None):
+        raise FsException("Method 'sync' not implemented")
+
+    def reverse_sync(self, from_path):
+        raise FsException("Method 'reverse_sync' not implemented")
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+import logging
+import os
+from shutil import rmtree
+import tarfile
+import zipfile
+
+from osm_common.fsbase import FsBase, FsException
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+
+
+class FsLocal(FsBase):
+    """File-storage backend over a plain local directory (self.path)."""
+
+    def __init__(self, logger_name="fs", lock=False):
+        super().__init__(logger_name, lock)
+        self.path = None  # base folder, set in fs_connect
+
+    def get_params(self):
+        return {"fs": "local", "path": self.path}
+
+    def fs_connect(self, config):
+        """
+        Configure the local backend.
+        :param config: dict with mandatory key "path" and optional "logger_name"
+        :return: None. Raises FsException on invalid configuration
+        """
+        try:
+            if "logger_name" in config:
+                self.logger = logging.getLogger(config["logger_name"])
+            if "path" not in config:
+                # explicit error, consistent with FsMongo.fs_connect, instead of a
+                # wrapped KeyError
+                raise FsException('Missing parameter "path"')
+            self.path = config["path"]
+            if not self.path.endswith("/"):
+                self.path += "/"
+            if not os.path.exists(self.path):
+                raise FsException(
+                    "Invalid configuration param at '[storage]': path '{}' does not exist".format(
+                        config["path"]
+                    )
+                )
+        except FsException:
+            raise
+        except Exception as e:  # TODO refine
+            raise FsException(str(e))
+
+    def fs_disconnect(self):
+        pass  # TODO
+
+    def mkdir(self, folder):
+        """
+        Creates a folder or parent object location
+        :param folder:
+        :return: None or raises an exception
+        """
+        try:
+            os.mkdir(self.path + folder)
+        except FileExistsError:  # make it idempotent
+            pass
+        except Exception as e:
+            raise FsException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    def dir_rename(self, src, dst):
+        """
+        Rename one directory name. If dst exist, it replaces (deletes) existing directory
+        :param src: source directory
+        :param dst: destination directory
+        :return: None or raises an exception
+        """
+        try:
+            if os.path.exists(self.path + dst):
+                rmtree(self.path + dst)
+
+            os.rename(self.path + src, self.path + dst)
+
+        except Exception as e:
+            raise FsException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    def file_exists(self, storage, mode=None):
+        """
+        Indicates if "storage" file exist
+        :param storage: can be a str or a str list
+        :param mode: can be 'file' exist as a regular file; 'dir' exists as a directory or; 'None' just exists
+        :return: True, False
+        """
+        if isinstance(storage, str):
+            f = storage
+        else:
+            f = "/".join(storage)
+        if os.path.exists(self.path + f):
+            if not mode:
+                return True
+            if mode == "file" and os.path.isfile(self.path + f):
+                return True
+            if mode == "dir" and os.path.isdir(self.path + f):
+                return True
+        return False
+
+    def file_size(self, storage):
+        """
+        return file size
+        :param storage: can be a str or a str list
+        :return: file size
+        """
+        if isinstance(storage, str):
+            f = storage
+        else:
+            f = "/".join(storage)
+        return os.path.getsize(self.path + f)
+
+    def file_extract(self, compressed_object, path):
+        """
+        extract a tar file
+        :param compressed_object: object of type tar or zip
+        :param path: can be a str or a str list, or a tar object where to extract the tar_object
+        :return: None
+        """
+        if isinstance(path, str):
+            f = self.path + path
+        else:
+            f = self.path + "/".join(path)
+
+        if type(compressed_object) is tarfile.TarFile:
+            compressed_object.extractall(path=f)
+        elif (
+            type(compressed_object) is zipfile.ZipFile
+        ):  # Just a check to know if this works with both tar and zip
+            compressed_object.extractall(path=f)
+
+    def file_open(self, storage, mode):
+        """
+        Open a file
+        :param storage: can be a str or list of str
+        :param mode: file mode
+        :return: file object
+        """
+        try:
+            if isinstance(storage, str):
+                f = storage
+            else:
+                f = "/".join(storage)
+            return open(self.path + f, mode)
+        except FileNotFoundError:
+            raise FsException(
+                "File {} does not exist".format(f), http_code=HTTPStatus.NOT_FOUND
+            )
+        except IOError:
+            raise FsException(
+                "File {} cannot be opened".format(f), http_code=HTTPStatus.BAD_REQUEST
+            )
+
+    def dir_ls(self, storage):
+        """
+        return folder content
+        :param storage: can be a str or list of str
+        :return: folder content
+        """
+        try:
+            if isinstance(storage, str):
+                f = storage
+            else:
+                f = "/".join(storage)
+            return os.listdir(self.path + f)
+        except (NotADirectoryError, FileNotFoundError):
+            # a missing directory is NOT_FOUND, same as a non-directory entry;
+            # previously FileNotFoundError fell into the IOError branch below
+            raise FsException(
+                "File {} does not exist".format(f), http_code=HTTPStatus.NOT_FOUND
+            )
+        except IOError:
+            raise FsException(
+                "File {} cannot be opened".format(f), http_code=HTTPStatus.BAD_REQUEST
+            )
+
+    def file_delete(self, storage, ignore_non_exist=False):
+        """
+        Delete storage content recursively
+        :param storage: can be a str or list of str
+        :param ignore_non_exist: not raise exception if storage does not exist
+        :return: None
+        """
+        try:
+            if isinstance(storage, str):
+                f = self.path + storage
+            else:
+                f = self.path + "/".join(storage)
+            if os.path.exists(f):
+                rmtree(f)
+            elif not ignore_non_exist:
+                raise FsException(
+                    "File {} does not exist".format(storage),
+                    http_code=HTTPStatus.NOT_FOUND,
+                )
+        except (IOError, PermissionError) as e:
+            raise FsException(
+                "File {} cannot be deleted: {}".format(f, e),
+                http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+            )
+
+    def sync(self, from_path=None):
+        pass  # Not needed in fslocal
+
+    def reverse_sync(self, from_path):
+        pass  # Not needed in fslocal
--- /dev/null
+# Copyright 2019 Canonical
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: eduardo.sousa@canonical.com
+##
+import datetime
+import errno
+from http import HTTPStatus
+from io import BytesIO, StringIO
+import logging
+import os
+import tarfile
+import zipfile
+
+from gridfs import errors, GridFSBucket
+from osm_common.fsbase import FsBase, FsException
+from pymongo import MongoClient
+
+
+__author__ = "Eduardo Sousa <eduardo.sousa@canonical.com>"
+
+
+class GridByteStream(BytesIO):
+    """In-memory binary stream bound to a GridFS entry: downloads the existing
+    content on creation and uploads the buffer back to GridFS on close() when
+    opened in a write mode."""
+
+    def __init__(self, filename, fs, mode):
+        BytesIO.__init__(self)
+        self._id = None  # GridFS _id of the backing file, if one exists
+        self.filename = filename
+        self.fs = fs  # GridFSBucket
+        self.mode = mode  # file mode string, e.g. "rb"/"wb"
+        self.file_type = "file"  # Set "file" as default file_type
+
+        self.__initialize__()
+
+    def __initialize__(self):
+        # Load the existing GridFS content (if any) into this buffer.
+        grid_file = None
+
+        cursor = self.fs.find({"filename": self.filename})
+
+        for requested_file in cursor:
+            # peeking at the next cursor item detects duplicate filenames
+            exception_file = next(cursor, None)
+
+            if exception_file:
+                raise FsException(
+                    "Multiple files found", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+
+            if requested_file.metadata["type"] in ("file", "sym"):
+                grid_file = requested_file
+                self.file_type = requested_file.metadata["type"]
+            else:
+                raise FsException(
+                    "Type isn't file", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+
+        if grid_file:
+            self._id = grid_file._id
+            self.fs.download_to_stream(self._id, self)
+
+        if "r" in self.mode:
+            # position at the start so reads see the downloaded content
+            self.seek(0, 0)
+
+    def close(self):
+        """Flush the buffer back to GridFS (write modes) and close the stream."""
+        if "r" in self.mode:
+            # read-only streams are never written back
+            super(GridByteStream, self).close()
+            return
+
+        if self._id:
+            # replace semantics: drop the previous revision before uploading
+            self.fs.delete(self._id)
+
+        cursor = self.fs.find(
+            {"filename": self.filename.split("/")[0], "metadata": {"type": "dir"}}
+        )
+
+        parent_dir = next(cursor, None)
+
+        if not parent_dir:
+            # NOTE(review): when no parent dir entry exists, the last character of
+            # the first path segment is dropped — presumably to compensate for a
+            # legacy naming convention; confirm against the writers of these paths
+            parent_dir_name = self.filename.split("/")[0]
+            self.filename = self.filename.replace(
+                parent_dir_name, parent_dir_name[:-1], 1
+            )
+
+        self.seek(0, 0)
+        if self._id:
+            # re-use the original _id so references to this file stay valid
+            self.fs.upload_from_stream_with_id(
+                self._id, self.filename, self, metadata={"type": self.file_type}
+            )
+        else:
+            self.fs.upload_from_stream(
+                self.filename, self, metadata={"type": self.file_type}
+            )
+        super(GridByteStream, self).close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+
+class GridStringStream(StringIO):
+    """Text counterpart of GridByteStream: bridges a GridFS entry through a
+    StringIO buffer, encoding/decoding UTF-8 on upload/download."""
+
+    def __init__(self, filename, fs, mode):
+        StringIO.__init__(self)
+        self._id = None  # GridFS _id of the backing file, if one exists
+        self.filename = filename
+        self.fs = fs  # GridFSBucket
+        self.mode = mode  # file mode string, e.g. "r"/"w"
+        self.file_type = "file"  # Set "file" as default file_type
+
+        self.__initialize__()
+
+    def __initialize__(self):
+        # Load the existing GridFS content (if any) into this buffer.
+        grid_file = None
+
+        cursor = self.fs.find({"filename": self.filename})
+
+        for requested_file in cursor:
+            # peeking at the next cursor item detects duplicate filenames
+            exception_file = next(cursor, None)
+
+            if exception_file:
+                raise FsException(
+                    "Multiple files found", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+
+            # NOTE(review): this accepts ("file", "dir") while GridByteStream
+            # accepts ("file", "sym") — confirm whether the asymmetry is intended
+            if requested_file.metadata["type"] in ("file", "dir"):
+                grid_file = requested_file
+                self.file_type = requested_file.metadata["type"]
+            else:
+                raise FsException(
+                    "File type isn't file", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+
+        if grid_file:
+            # GridFS streams are binary; decode through an intermediate buffer
+            stream = BytesIO()
+            self._id = grid_file._id
+            self.fs.download_to_stream(self._id, stream)
+            stream.seek(0)
+            self.write(stream.read().decode("utf-8"))
+            stream.close()
+
+        if "r" in self.mode:
+            # position at the start so reads see the downloaded content
+            self.seek(0, 0)
+
+    def close(self):
+        """Flush the buffer back to GridFS (write modes) and close the stream."""
+        if "r" in self.mode:
+            # read-only streams are never written back
+            super(GridStringStream, self).close()
+            return
+
+        if self._id:
+            # replace semantics: drop the previous revision before uploading
+            self.fs.delete(self._id)
+
+        cursor = self.fs.find(
+            {"filename": self.filename.split("/")[0], "metadata": {"type": "dir"}}
+        )
+
+        parent_dir = next(cursor, None)
+
+        if not parent_dir:
+            # NOTE(review): same last-character-stripping fallback as
+            # GridByteStream.close(); confirm against the writers of these paths
+            parent_dir_name = self.filename.split("/")[0]
+            self.filename = self.filename.replace(
+                parent_dir_name, parent_dir_name[:-1], 1
+            )
+
+        self.seek(0, 0)
+        # re-encode the text buffer to bytes for the GridFS upload
+        stream = BytesIO()
+        stream.write(self.read().encode("utf-8"))
+        stream.seek(0, 0)
+        if self._id:
+            self.fs.upload_from_stream_with_id(
+                self._id, self.filename, stream, metadata={"type": self.file_type}
+            )
+        else:
+            self.fs.upload_from_stream(
+                self.filename, stream, metadata={"type": self.file_type}
+            )
+        stream.close()
+        super(GridStringStream, self).close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+
+class FsMongo(FsBase):
+    """
+    File-storage backend that keeps content in MongoDB GridFS and mirrors it
+    into a local directory (self.path), so other components can use plain
+    filesystem access after calling sync().
+    """
+
+    def __init__(self, logger_name="fs", lock=False):
+        super().__init__(logger_name, lock)
+        self.path = None  # local mirror base folder, set in fs_connect
+        self.client = None  # MongoClient, set in fs_connect
+        self.fs = None  # GridFSBucket, set in fs_connect
+
+    def __update_local_fs(self, from_path=None):
+        # Mirror GridFS content into the local filesystem under self.path.
+        dir_cursor = self.fs.find({"metadata.type": "dir"}, no_cursor_timeout=True)
+
+        valid_paths = []
+
+        # materialize all directories first so files have a place to land
+        for directory in dir_cursor:
+            if from_path and not directory.filename.startswith(from_path):
+                continue
+            self.logger.debug("Making dir {}".format(self.path + directory.filename))
+            os.makedirs(self.path + directory.filename, exist_ok=True)
+            valid_paths.append(self.path + directory.filename)
+
+        file_cursor = self.fs.find(
+            {"metadata.type": {"$in": ["file", "sym"]}}, no_cursor_timeout=True
+        )
+
+        for writing_file in file_cursor:
+            if from_path and not writing_file.filename.startswith(from_path):
+                continue
+            file_path = self.path + writing_file.filename
+
+            if writing_file.metadata["type"] == "sym":
+                # symlink entries store the link target as the file content
+                with BytesIO() as b:
+                    self.fs.download_to_stream(writing_file._id, b)
+                    b.seek(0)
+                    link = b.read().decode("utf-8")
+
+                try:
+                    # remove any stale entry before re-creating the symlink
+                    self.logger.debug("Sync removing {}".format(file_path))
+                    os.remove(file_path)
+                except OSError as e:
+                    if e.errno != errno.ENOENT:
+                        # This is probably permission denied or worse
+                        raise
+                os.symlink(
+                    link, os.path.realpath(os.path.normpath(os.path.abspath(file_path)))
+                )
+            else:
+                # regular file: make sure the parent folder exists, then download
+                folder = os.path.dirname(file_path)
+                if folder not in valid_paths:
+                    self.logger.debug("Sync local directory {}".format(file_path))
+                    os.makedirs(folder, exist_ok=True)
+                with open(file_path, "wb+") as file_stream:
+                    self.logger.debug("Sync download {}".format(file_path))
+                    self.fs.download_to_stream(writing_file._id, file_stream)
+                if "permissions" in writing_file.metadata:
+                    os.chmod(file_path, writing_file.metadata["permissions"])
+
+    def get_params(self):
+        return {"fs": "mongo", "path": self.path}
+
+    def fs_connect(self, config):
+        """
+        Configure the mongo-backed storage.
+        :param config: dict with mandatory keys "path", "uri", "collection" and
+            optional "logger_name"
+        :return: None. Raises FsException on invalid configuration
+        """
+        try:
+            if "logger_name" in config:
+                self.logger = logging.getLogger(config["logger_name"])
+            if "path" in config:
+                self.path = config["path"]
+            else:
+                raise FsException('Missing parameter "path"')
+            if not self.path.endswith("/"):
+                self.path += "/"
+            if not os.path.exists(self.path):
+                raise FsException(
+                    "Invalid configuration param at '[storage]': path '{}' does not exist".format(
+                        config["path"]
+                    )
+                )
+            elif not os.access(self.path, os.W_OK):
+                raise FsException(
+                    "Invalid configuration param at '[storage]': path '{}' is not writable".format(
+                        config["path"]
+                    )
+                )
+            if all(key in config.keys() for key in ["uri", "collection"]):
+                self.client = MongoClient(config["uri"])
+                self.fs = GridFSBucket(self.client[config["collection"]])
+            else:
+                if "collection" not in config.keys():
+                    raise FsException('Missing parameter "collection"')
+                else:
+                    raise FsException('Missing parameters: "uri"')
+        except FsException:
+            raise
+        except Exception as e:  # TODO refine
+            raise FsException(str(e))
+
+    def fs_disconnect(self):
+        pass  # TODO
+
+    def mkdir(self, folder):
+        """
+        Creates a folder or parent object location
+        :param folder:
+        :return: None or raises an exception
+        """
+        folder = folder.rstrip("/")
+        try:
+            # a directory is stored as an empty GridFS entry with metadata type "dir"
+            self.fs.upload_from_stream(folder, BytesIO(), metadata={"type": "dir"})
+        except errors.FileExists:  # make it idempotent
+            pass
+        except Exception as e:
+            raise FsException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    def dir_rename(self, src, dst):
+        """
+        Rename one directory name. If dst exist, it replaces (deletes) existing directory
+        :param src: source directory
+        :param dst: destination directory
+        :return: None or raises and exception
+        """
+        dst = dst.rstrip("/")
+        src = src.rstrip("/")
+
+        try:
+            # NOTE(review): dst/src are interpolated into the regex unescaped —
+            # names containing regex metacharacters may mis-match; consider re.escape
+            dst_cursor = self.fs.find(
+                {"filename": {"$regex": "^{}(/|$)".format(dst)}}, no_cursor_timeout=True
+            )
+
+            # delete any existing destination tree before renaming into it
+            for dst_file in dst_cursor:
+                self.fs.delete(dst_file._id)
+
+            src_cursor = self.fs.find(
+                {"filename": {"$regex": "^{}(/|$)".format(src)}}, no_cursor_timeout=True
+            )
+
+            for src_file in src_cursor:
+                self.fs.rename(src_file._id, src_file.filename.replace(src, dst, 1))
+        except Exception as e:
+            raise FsException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    def file_exists(self, storage, mode=None):
+        """
+        Indicates if "storage" file exist
+        :param storage: can be a str or a str list
+        :param mode: can be 'file' exist as a regular file; 'dir' exists as a directory or; 'None' just exists
+        :return: True, False
+        """
+        f = storage if isinstance(storage, str) else "/".join(storage)
+        f = f.rstrip("/")
+
+        cursor = self.fs.find({"filename": f})
+
+        for requested_file in cursor:
+            # peeking at the next cursor item detects duplicate filenames
+            exception_file = next(cursor, None)
+
+            if exception_file:
+                raise FsException(
+                    "Multiple files found", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+
+            self.logger.debug("Entry {} metadata {}".format(f, requested_file.metadata))
+
+            # if no special mode is required just check it does exists
+            if not mode:
+                return True
+
+            if requested_file.metadata["type"] == mode:
+                return True
+
+            # a symlink is reported as existing when asked for a regular file
+            if requested_file.metadata["type"] == "sym" and mode == "file":
+                return True
+
+        return False
+
+    def file_size(self, storage):
+        """
+        return file size
+        :param storage: can be a str or a str list
+        :return: file size
+        """
+        f = storage if isinstance(storage, str) else "/".join(storage)
+        f = f.rstrip("/")
+
+        cursor = self.fs.find({"filename": f})
+
+        for requested_file in cursor:
+            # peeking at the next cursor item detects duplicate filenames
+            exception_file = next(cursor, None)
+
+            if exception_file:
+                raise FsException(
+                    "Multiple files found", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+
+            return requested_file.length
+        # NOTE(review): falls through and returns None when no entry matches
+
+    def file_extract(self, compressed_object, path):
+        """
+        extract a tar file
+        :param compressed_object: object of type tar or zip
+        :param path: can be a str or a str list, or a tar object where to extract the tar_object
+        :return: None
+        """
+        f = path if isinstance(path, str) else "/".join(path)
+        f = f.rstrip("/")
+
+        if type(compressed_object) is tarfile.TarFile:
+            for member in compressed_object.getmembers():
+                # choose the upload stream per member kind
+                if member.isfile():
+                    stream = compressed_object.extractfile(member)
+                elif member.issym():
+                    # symlinks store the link target as content
+                    stream = BytesIO(member.linkname.encode("utf-8"))
+                else:
+                    stream = BytesIO()
+
+                if member.isfile():
+                    file_type = "file"
+                elif member.issym():
+                    file_type = "sym"
+                else:
+                    file_type = "dir"
+
+                metadata = {"type": file_type, "permissions": member.mode}
+                member.name = member.name.rstrip("/")
+
+                self.logger.debug("Uploading {}/{}".format(f, member.name))
+                self.fs.upload_from_stream(
+                    f + "/" + member.name, stream, metadata=metadata
+                )
+
+                stream.close()
+        elif type(compressed_object) is zipfile.ZipFile:
+            for member in compressed_object.infolist():
+                # ZipFile.read() returns bytes, accepted by upload_from_stream
+                if member.is_dir():
+                    stream = BytesIO()
+                else:
+                    stream = compressed_object.read(member)
+
+                if member.is_dir():
+                    file_type = "dir"
+                else:
+                    file_type = "file"
+
+                metadata = {"type": file_type}
+                member.filename = member.filename.rstrip("/")
+
+                self.logger.debug("Uploading {}/{}".format(f, member.filename))
+                self.fs.upload_from_stream(
+                    f + "/" + member.filename, stream, metadata=metadata
+                )
+
+                # only BytesIO (dir case) needs closing; bytes have no close()
+                if member.is_dir():
+                    stream.close()
+
+    def file_open(self, storage, mode):
+        """
+        Open a file
+        :param storage: can be a str or list of str
+        :param mode: file mode
+        :return: file object
+        """
+        try:
+            f = storage if isinstance(storage, str) else "/".join(storage)
+            f = f.rstrip("/")
+
+            # binary modes use the bytes wrapper, text modes the string wrapper
+            if "b" in mode:
+                return GridByteStream(f, self.fs, mode)
+            else:
+                return GridStringStream(f, self.fs, mode)
+        except errors.NoFile:
+            raise FsException(
+                "File {} does not exist".format(f), http_code=HTTPStatus.NOT_FOUND
+            )
+        except IOError:
+            raise FsException(
+                "File {} cannot be opened".format(f), http_code=HTTPStatus.BAD_REQUEST
+            )
+
+    def dir_ls(self, storage):
+        """
+        return folder content
+        :param storage: can be a str or list of str
+        :return: folder content
+        """
+        try:
+            f = storage if isinstance(storage, str) else "/".join(storage)
+            f = f.rstrip("/")
+
+            files = []
+            dir_cursor = self.fs.find({"filename": f})
+            for requested_dir in dir_cursor:
+                # peeking at the next cursor item detects duplicate entries
+                exception_dir = next(dir_cursor, None)
+
+                if exception_dir:
+                    raise FsException(
+                        "Multiple directories found",
+                        http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                    )
+
+                if requested_dir.metadata["type"] != "dir":
+                    raise FsException(
+                        "File {} does not exist".format(f),
+                        http_code=HTTPStatus.NOT_FOUND,
+                    )
+
+                # NOTE(review): unreachable — f was already rstrip'ed above
+                if f.endswith("/"):
+                    f = f[:-1]
+
+                # NOTE(review): f is interpolated into the regex unescaped —
+                # names with regex metacharacters may mis-match; consider re.escape
+                files_cursor = self.fs.find(
+                    {"filename": {"$regex": "^{}/([^/])*".format(f)}}
+                )
+                for children_file in files_cursor:
+                    files += [children_file.filename.replace(f + "/", "", 1)]
+
+            return files
+        except IOError:
+            raise FsException(
+                "File {} cannot be opened".format(f), http_code=HTTPStatus.BAD_REQUEST
+            )
+
+    def file_delete(self, storage, ignore_non_exist=False):
+        """
+        Delete storage content recursively
+        :param storage: can be a str or list of str
+        :param ignore_non_exist: not raise exception if storage does not exist
+        :return: None
+        """
+        try:
+            f = storage if isinstance(storage, str) else "/".join(storage)
+            f = f.rstrip("/")
+
+            file_cursor = self.fs.find({"filename": f})
+            found = False
+            for requested_file in file_cursor:
+                found = True
+                # peeking at the next cursor item detects duplicate filenames
+                exception_file = next(file_cursor, None)
+
+                if exception_file:
+                    self.logger.error(
+                        "Cannot delete duplicate file: {} and {}".format(
+                            requested_file.filename, exception_file.filename
+                        )
+                    )
+                    raise FsException(
+                        "Multiple files found",
+                        http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                    )
+
+                if requested_file.metadata["type"] == "dir":
+                    # NOTE(review): unescaped regex again; consider re.escape
+                    dir_cursor = self.fs.find(
+                        {"filename": {"$regex": "^{}/".format(f)}}
+                    )
+
+                    # recursively delete everything below the directory entry
+                    for tmp in dir_cursor:
+                        self.logger.debug("Deleting {}".format(tmp.filename))
+                        self.fs.delete(tmp._id)
+
+                self.logger.debug("Deleting {}".format(requested_file.filename))
+                self.fs.delete(requested_file._id)
+            if not found and not ignore_non_exist:
+                raise FsException(
+                    "File {} does not exist".format(storage),
+                    http_code=HTTPStatus.NOT_FOUND,
+                )
+        except IOError as e:
+            raise FsException(
+                "File {} cannot be deleted: {}".format(f, e),
+                http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+            )
+
+    def sync(self, from_path=None):
+        """
+        Sync from FSMongo to local storage
+        :param from_path: if supplied, only copy content from this path, not all
+        :return: None
+        """
+        if from_path:
+            if os.path.isabs(from_path):
+                # GridFS filenames are relative to self.path
+                from_path = os.path.relpath(from_path, self.path)
+        self.__update_local_fs(from_path=from_path)
+
+    def _update_mongo_fs(self, from_path):
+        # Upload local filesystem content under from_path into GridFS.
+        os_path = self.path + from_path
+        # Obtain list of files and dirs in filesystem
+        members = []
+        for root, dirs, files in os.walk(os_path):
+            for folder in dirs:
+                member = {"filename": os.path.join(root, folder), "type": "dir"}
+                if os.path.islink(member["filename"]):
+                    member["type"] = "sym"
+                members.append(member)
+            for file in files:
+                filename = os.path.join(root, file)
+                if os.path.islink(filename):
+                    file_type = "sym"
+                else:
+                    file_type = "file"
+                member = {"filename": os.path.join(root, file), "type": file_type}
+                members.append(member)
+
+        # Obtain files in mongo dict
+        remote_files = self._get_mongo_files(from_path)
+
+        # Upload members if they do not exists or have been modified
+        # We will do this for performance (avoid updating unmodified files) and to avoid
+        # updating a file with an older one in case there are two sources for synchronization
+        # in high availability scenarios
+        for member in members:
+            # obtain permission
+            mask = int(oct(os.stat(member["filename"]).st_mode)[-3:], 8)
+
+            # convert to relative path
+            rel_filename = os.path.relpath(member["filename"], self.path)
+            # get timestamp in UTC because mongo stores upload date in UTC:
+            # https://www.mongodb.com/docs/v4.0/tutorial/model-time-data/#overview
+            last_modified_date = datetime.datetime.utcfromtimestamp(
+                os.path.getmtime(member["filename"])
+            )
+
+            remote_file = remote_files.get(rel_filename)
+            upload_date = (
+                remote_file[0].uploadDate if remote_file else datetime.datetime.min
+            )
+            # remove processed files from dict
+            remote_files.pop(rel_filename, None)
+
+            if last_modified_date >= upload_date:
+                stream = None
+                fh = None
+                try:
+                    file_type = member["type"]
+                    if file_type == "dir":
+                        stream = BytesIO()
+                    elif file_type == "sym":
+                        # symlink entries store the link target as content
+                        stream = BytesIO(
+                            os.readlink(member["filename"]).encode("utf-8")
+                        )
+                    else:
+                        fh = open(member["filename"], "rb")
+                        stream = BytesIO(fh.read())
+
+                    metadata = {"type": file_type, "permissions": mask}
+
+                    self.logger.debug("Sync upload {}".format(rel_filename))
+                    self.fs.upload_from_stream(rel_filename, stream, metadata=metadata)
+
+                    # delete old files
+                    if remote_file:
+                        for file in remote_file:
+                            self.logger.debug("Sync deleting {}".format(file.filename))
+                            self.fs.delete(file._id)
+                finally:
+                    if fh:
+                        fh.close()
+                    if stream:
+                        stream.close()
+
+        # delete files that are not anymore in local fs
+        for remote_file in remote_files.values():
+            for file in remote_file:
+                self.fs.delete(file._id)
+
+    def _get_mongo_files(self, from_path=None):
+        # Map filename -> list of GridFS entries, newest first (sorted by uploadDate
+        # descending, so index 0 is the latest revision).
+        file_dict = {}
+        file_cursor = self.fs.find(no_cursor_timeout=True, sort=[("uploadDate", -1)])
+        for file in file_cursor:
+            if from_path and not file.filename.startswith(from_path):
+                continue
+            if file.filename in file_dict:
+                file_dict[file.filename].append(file)
+            else:
+                file_dict[file.filename] = [file]
+        return file_dict
+
+    def reverse_sync(self, from_path: str):
+        """
+        Sync from local storage to FSMongo
+        :param from_path: base directory to upload content to mongo fs
+        :return: None
+        """
+        if os.path.isabs(from_path):
+            # GridFS filenames are relative to self.path
+            from_path = os.path.relpath(from_path, self.path)
+        self._update_mongo_fs(from_path=from_path)
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+import logging
+from threading import Lock
+
+from osm_common.common_utils import FakeLock
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+
+
class MsgException(Exception):
    """
    Base Exception class for all msgXXXX exceptions
    """

    def __init__(self, message, http_code=HTTPStatus.SERVICE_UNAVAILABLE):
        """
        General messaging exception.

        :param message: descriptive text
        :param http_code: <http.HTTPStatus> type. It contains ".value" (http
            error code) and ".name" (http error name)
        """
        self.http_code = http_code
        super().__init__("messaging exception " + message)
+
+
class MsgBase(object):
    """
    Base class for all msgXXXX classes. Subclasses implement the actual bus
    driver; this class provides only logger/lock plumbing and
    not-implemented stubs.
    """

    def __init__(self, logger_name="msg", lock=False):
        """
        Constructor of MsgBase
        :param logger_name: logging name
        :param lock: Used to protect simultaneous access to the same instance class by several threads:
            False, None: Do not protect, this object will only be accessed by one thread
            True: This object needs to be protected by several threads accessing.
            Lock object. Use this Lock for the threads access protection
        """
        self.logger = logging.getLogger(logger_name)
        if not lock:
            self.lock = FakeLock()
        elif lock is True:
            self.lock = Lock()
        elif isinstance(lock, type(Lock())):
            # Bugfix: compare against the concrete lock type. Before Python
            # 3.13, threading.Lock is a factory function (not a class), so
            # isinstance(lock, Lock) raised TypeError instead of accepting a
            # real lock object.
            self.lock = lock
        else:
            raise ValueError("lock parameter must be a Lock class or boolean")

    def connect(self, config):
        """Connect to the message bus; no-op in the base class."""
        pass

    def disconnect(self):
        """Disconnect from the message bus; no-op in the base class."""
        pass

    def write(self, topic, key, msg):
        """Synchronous write; must be overridden by subclasses."""
        raise MsgException(
            "Method 'write' not implemented", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
        )

    def read(self, topic):
        """Synchronous read; must be overridden by subclasses."""
        raise MsgException(
            "Method 'read' not implemented", http_code=HTTPStatus.INTERNAL_SERVER_ERROR
        )

    async def aiowrite(self, topic, key, msg):
        """Asyncio write; must be overridden by subclasses."""
        raise MsgException(
            "Method 'aiowrite' not implemented",
            http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
        )

    async def aioread(
        self, topic, callback=None, aiocallback=None, group_id=None, **kwargs
    ):
        """Asyncio read; must be overridden by subclasses."""
        raise MsgException(
            "Method 'aioread' not implemented",
            http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
        )
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+
+from aiokafka import AIOKafkaConsumer
+from aiokafka import AIOKafkaProducer
+from aiokafka.errors import KafkaError
+from osm_common.msgbase import MsgBase, MsgException
+import yaml
+
+__author__ = (
+ "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>, "
+ "Guillermo Calvino <guillermo.calvinosanchez@altran.com>"
+)
+
+
class MsgKafka(MsgBase):
    """Kafka driver for the OSM message bus, based on aiokafka.

    connect() only stores the broker coordinates; a producer/consumer is
    created and stopped inside each write/read call.
    """

    def __init__(self, logger_name="msg", lock=False):
        super().__init__(logger_name, lock)
        self.host = None
        self.port = None
        self.consumer = None
        self.producer = None
        self.broker = None
        self.group_id = None

    def connect(self, config):
        """
        Store the kafka connection parameters; no connection is opened here.
        :param config: dict with "host" and "port", optionally "logger_name"
            and "group_id"
        :return: None or raises MsgException
        """
        try:
            if "logger_name" in config:
                self.logger = logging.getLogger(config["logger_name"])
            self.host = config["host"]
            self.port = config["port"]
            self.broker = str(self.host) + ":" + str(self.port)
            self.group_id = config.get("group_id")

        except Exception as e:  # TODO refine
            raise MsgException(str(e))

    def disconnect(self):
        # Nothing to release: producers/consumers are created and stopped per
        # call (the original try/pass/except here could never raise).
        pass

    def write(self, topic, key, msg):
        """
        Write a message at kafka bus
        :param topic: message topic, must be string
        :param key: message key, must be string
        :param msg: message content, can be string or dictionary
        :return: None or raises MsgException on failing
        """
        retry = 2  # Try two times
        while retry:
            try:
                asyncio.run(self.aiowrite(topic=topic, key=key, msg=msg))
                break
            except Exception as e:
                retry -= 1
                if retry == 0:
                    raise MsgException(
                        "Error writing {} topic: {}".format(topic, str(e))
                    )

    def read(self, topic):
        """
        Read from one or several topics.
        :param topic: can be str: single topic; or str list: several topics
        :return: topic, key, message; or None
        """
        try:
            return asyncio.run(self.aioread(topic))
        except MsgException:
            raise
        except Exception as e:
            raise MsgException("Error reading {} topic: {}".format(topic, str(e)))

    async def aiowrite(self, topic, key, msg):
        """
        Asyncio write
        :param topic: str kafka topic
        :param key: str kafka key
        :param msg: str or dictionary kafka message
        :return: None
        """
        try:
            self.producer = AIOKafkaProducer(
                key_serializer=str.encode,
                value_serializer=str.encode,
                bootstrap_servers=self.broker,
            )
            await self.producer.start()
            await self.producer.send(
                topic=topic, key=key, value=yaml.safe_dump(msg, default_flow_style=True)
            )
        except Exception as e:
            raise MsgException(
                "Error publishing topic '{}', key '{}': {}".format(topic, key, e)
            )
        finally:
            # Bugfix: if AIOKafkaProducer() itself raised on the first call,
            # self.producer is still None and stop() would raise
            # AttributeError, masking the original error.
            if self.producer:
                await self.producer.stop()

    async def aioread(
        self,
        topic,
        callback=None,
        aiocallback=None,
        group_id=None,
        from_beginning=None,
        **kwargs
    ):
        """
        Asyncio read from one or several topics.
        :param topic: can be str: single topic; or str list: several topics
        :param callback: synchronous callback function that will handle the message in kafka bus
        :param aiocallback: async callback function that will handle the message in kafka bus
        :param group_id: kafka group_id to use. Can be False (set group_id to None), None (use general group_id provided
            at connect inside config), or a group_id string
        :param from_beginning: if True, messages will be obtained from beginning instead of only new ones.
            If group_id is supplied, only the not processed messages by other worker are obtained.
            If group_id is None, all messages stored at kafka are obtained.
        :param kwargs: optional keyword arguments for callback function
        :return: If no callback defined, it returns (topic, key, message)
        """
        if group_id is False:
            group_id = None
        elif group_id is None:
            group_id = self.group_id
        try:
            if isinstance(topic, (list, tuple)):
                topic_list = topic
            else:
                topic_list = (topic,)
            self.consumer = AIOKafkaConsumer(
                bootstrap_servers=self.broker,
                group_id=group_id,
                auto_offset_reset="earliest" if from_beginning else "latest",
            )
            await self.consumer.start()
            self.consumer.subscribe(topic_list)

            async for message in self.consumer:
                if callback:
                    callback(
                        message.topic,
                        yaml.safe_load(message.key),
                        yaml.safe_load(message.value),
                        **kwargs
                    )
                elif aiocallback:
                    await aiocallback(
                        message.topic,
                        yaml.safe_load(message.key),
                        yaml.safe_load(message.value),
                        **kwargs
                    )
                else:
                    return (
                        message.topic,
                        yaml.safe_load(message.key),
                        yaml.safe_load(message.value),
                    )
        except KafkaError as e:
            raise MsgException(str(e))
        finally:
            # Bugfix: guard against a consumer that was never created (see
            # the matching guard in aiowrite).
            if self.consumer:
                await self.consumer.stop()
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+from http import HTTPStatus
+import logging
+import os
+from time import sleep
+
+from osm_common.msgbase import MsgBase, MsgException
+import yaml
+
+__author__ = "Alfonso Tierno <alfonso.tiernosepulveda@telefonica.com>"
+"""
+This emulates a kafka bus by just using a shared file system. Useful for testing or devops.
+One file is used per topic. Only one producer and one consumer is allowed per topic. Both consumer and producer
+access to the same file. e.g. same volume if running with docker.
+One text line per message is used in yaml format.
+"""
+
+
class MsgLocal(MsgBase):
    """Emulated kafka bus over a shared file system (one file per topic).

    See the module docstring: one producer and one consumer per topic, one
    YAML line per message.
    """

    def __init__(self, logger_name="msg", lock=False):
        super().__init__(logger_name, lock)
        self.path = None
        # create a different file for each topic
        self.files_read = {}  # topic -> file object opened by read()
        self.files_write = {}  # topic -> file object opened by write()
        self.buffer = {}  # topic -> partial line accumulated by read()

    def connect(self, config):
        """
        Prepare the directory holding the topic files.
        :param config: dict with "path" and optionally "logger_name"
        :return: None or raises MsgException
        """
        try:
            if "logger_name" in config:
                self.logger = logging.getLogger(config["logger_name"])
            self.path = config["path"]
            if not self.path.endswith("/"):
                self.path += "/"
            if not os.path.exists(self.path):
                os.mkdir(self.path)

        except MsgException:
            raise
        except Exception as e:  # TODO refine
            raise MsgException(str(e), http_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    def _close_topic_files(self, topic_files, kind):
        """Close every open file in topic_files and mark the entry closed.

        :param topic_files: one of self.files_read / self.files_write
        :param kind: "read" or "write"; used only to build log messages
        """
        for topic, f in topic_files.items():
            try:
                # Bugfix: entry may already be None after a previous
                # disconnect; calling None.close() raised AttributeError.
                if f:
                    f.close()
                topic_files[topic] = None
            except (IOError, FileNotFoundError) as e:
                self.logger.exception(f"{e} occured while closing {kind} topic files.")
            except KeyError as e:
                self.logger.exception(
                    f"{e} occured while reading from files_{kind} dictionary."
                )
            except Exception as e:
                self.logger.exception(f"{e} occured while closing {kind} topics.")

    def disconnect(self):
        """Close all topic files opened so far by read() and write()."""
        self._close_topic_files(self.files_read, "read")
        self._close_topic_files(self.files_write, "write")

    def write(self, topic, key, msg):
        """
        Insert a message into topic
        :param topic: topic
        :param key: key text to be inserted
        :param msg: value object to be inserted, can be str, object ...
        :return: None or raises and exception
        """
        try:
            with self.lock:
                if topic not in self.files_write:
                    self.files_write[topic] = open(self.path + topic, "a+")
                # one flow-style YAML mapping per line; width avoids wrapping
                yaml.safe_dump(
                    {key: msg},
                    self.files_write[topic],
                    default_flow_style=True,
                    width=20000,
                )
                self.files_write[topic].flush()
        except Exception as e:  # TODO refine
            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)

    def read(self, topic, blocks=True):
        """
        Read from one or several topics. it is non blocking returning None if nothing is available
        :param topic: can be str: single topic; or str list: several topics
        :param blocks: indicates if it should wait and block until a message is present or returns None
        :return: topic, key, message; or None if blocks==True
        """
        try:
            if isinstance(topic, (list, tuple)):
                topic_list = topic
            else:
                topic_list = (topic,)
            while True:
                for single_topic in topic_list:
                    with self.lock:
                        if single_topic not in self.files_read:
                            # "a+" creates the file if missing; position is at
                            # EOF so only new messages are seen
                            self.files_read[single_topic] = open(
                                self.path + single_topic, "a+"
                            )
                            self.buffer[single_topic] = ""
                        # accumulate until a complete line is available
                        self.buffer[single_topic] += self.files_read[
                            single_topic
                        ].readline()
                        if not self.buffer[single_topic].endswith("\n"):
                            continue
                        msg_dict = yaml.safe_load(self.buffer[single_topic])
                        self.buffer[single_topic] = ""
                        if len(msg_dict) != 1:
                            raise ValueError(
                                "Length of message dictionary is not equal to 1"
                            )
                        for k, v in msg_dict.items():
                            return single_topic, k, v
                if not blocks:
                    return None
                sleep(2)
        except Exception as e:  # TODO refine
            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)

    async def aioread(
        self, topic, callback=None, aiocallback=None, group_id=None, **kwargs
    ):
        """
        Asyncio read from one or several topics. It blocks
        :param topic: can be str: single topic; or str list: several topics
        :param callback: synchronous callback function that will handle the message
        :param aiocallback: async callback function that will handle the message
        :param group_id: group_id to use for load balancing. Can be False (set group_id to None), None (use general
            group_id provided at connect inside config), or a group_id string
        :param kwargs: optional keyword arguments for callback function
        :return: If no callback defined, it returns (topic, key, message)
        """
        try:
            while True:
                msg = self.read(topic, blocks=False)
                if msg:
                    if callback:
                        callback(*msg, **kwargs)
                    elif aiocallback:
                        await aiocallback(*msg, **kwargs)
                    else:
                        return msg
                await asyncio.sleep(2)
        except MsgException:
            raise
        except Exception as e:  # TODO refine
            raise MsgException(str(e), HTTPStatus.INTERNAL_SERVER_ERROR)

    async def aiowrite(self, topic, key, msg):
        """
        Asyncio write. It blocks
        :param topic: str
        :param key: str
        :param msg: message, can be str or yaml
        :return: nothing if ok or raises an exception
        """
        return self.write(topic, key, msg)
--- /dev/null
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#######################################################################################
+aiokafka==0.12.0
+ # via -r requirements.in
+async-timeout==4.0.3
+ # via
+ # -r requirements.in
+ # aiokafka
+dataclasses==0.6
+ # via -r requirements.in
+dnspython==2.8.0
+ # via pymongo
+motor==3.7.1
+ # via -r requirements.in
+packaging==25.0
+ # via aiokafka
+pycryptodome==3.23.0
+ # via -r requirements.in
+pymongo==4.15.2
+ # via
+ # -r requirements.in
+ # motor
+pyyaml==6.0.3
+ # via -r requirements.in
+typing-extensions==4.15.0
+ # via aiokafka
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: agarcia@whitestack.com or fbravo@whitestack.com
+##
+
+"""Python module for interacting with ETSI GS NFV-SOL004 compliant packages
+
+This module provides a SOL004Package class for validating and interacting with
+ETSI SOL004 packages. A valid SOL004 package may have its files arranged according
+to one of the following two structures:
+
+SOL004 with metadata directory SOL004 without metadata directory
+
+native_charm_vnf/ native_charm_vnf/
+├── TOSCA-Metadata ├── native_charm_vnfd.mf
+│ └── TOSCA.meta ├── native_charm_vnfd.yaml
+├── manifest.mf ├── ChangeLog.txt
+├── Definitions ├── Licenses
+│ └── native_charm_vnfd.yaml │ └── license.lic
+├── Files ├── Files
+│ ├── icons │ └── icons
+│ │ └── osm.png │ └── osm.png
+│ ├── Licenses └── Scripts
+│ │ └── license.lic ├── cloud_init
+│ └── changelog.txt │ └── cloud-config.txt
+└── Scripts └── charms
+ ├── cloud_init └── simple
+ │ └── cloud-config.txt ├── config.yaml
+ └── charms ├── hooks
+ └── simple │ ├── install
+ ├── config.yaml ...
+ ├── hooks │
+ │ ├── install └── src
+ ... └── charm.py
+ └── src
+ └── charm.py
+"""
+
+import datetime
+import os
+
+import yaml
+
+from .sol_package import SOLPackage
+
+
class SOL004PackageException(Exception):
    """Raised for errors specific to SOL004 (VNF) package handling."""

    pass
+
+
class SOL004Package(SOLPackage):
    """SOL004 (VNF) package: manifest field names and VNFD-based metadata."""

    _MANIFEST_VNFD_ID = "vnfd_id"
    _MANIFEST_VNFD_PRODUCT_NAME = "vnfd_product_name"
    _MANIFEST_VNFD_PROVIDER_ID = "vnfd_provider_id"
    _MANIFEST_VNFD_SOFTWARE_VERSION = "vnfd_software_version"
    _MANIFEST_VNFD_PACKAGE_VERSION = "vnfd_package_version"
    _MANIFEST_VNFD_RELEASE_DATE_TIME = "vnfd_release_date_time"
    _MANIFEST_VNFD_COMPATIBLE_SPECIFICATION_VERSIONS = (
        "compatible_specification_versions"
    )
    _MANIFEST_VNFM_INFO = "vnfm_info"

    _MANIFEST_ALL_FIELDS = [
        _MANIFEST_VNFD_ID,
        _MANIFEST_VNFD_PRODUCT_NAME,
        _MANIFEST_VNFD_PROVIDER_ID,
        _MANIFEST_VNFD_SOFTWARE_VERSION,
        _MANIFEST_VNFD_PACKAGE_VERSION,
        _MANIFEST_VNFD_RELEASE_DATE_TIME,
        _MANIFEST_VNFD_COMPATIBLE_SPECIFICATION_VERSIONS,
        _MANIFEST_VNFM_INFO,
    ]

    def __init__(self, package_path=""):
        super().__init__(package_path)

    def generate_manifest_data_from_descriptor(self):
        """Populate self._manifest_metadata from the package's VNFD YAML."""
        descriptor_path = os.path.join(
            self._package_path, self.get_descriptor_location()
        )
        with open(descriptor_path, "r") as descriptor:
            try:
                vnfd_data = yaml.safe_load(descriptor)["vnfd"]
            except yaml.YAMLError as e:
                print("Error reading descriptor {}: {}".format(descriptor_path, e))
                return

        # Build the whole metadata mapping in one literal, using descriptor
        # values where present and the documented defaults otherwise.
        self._manifest_metadata = {
            self._MANIFEST_VNFD_ID: vnfd_data.get("id", "default-id"),
            self._MANIFEST_VNFD_PRODUCT_NAME: vnfd_data.get(
                "product-name", "default-product-name"
            ),
            self._MANIFEST_VNFD_PROVIDER_ID: vnfd_data.get("provider", "OSM"),
            self._MANIFEST_VNFD_SOFTWARE_VERSION: vnfd_data.get("version", "1.0"),
            self._MANIFEST_VNFD_PACKAGE_VERSION: "1.0.0",
            self._MANIFEST_VNFD_RELEASE_DATE_TIME: (
                datetime.datetime.now().astimezone().isoformat()
            ),
            self._MANIFEST_VNFD_COMPATIBLE_SPECIFICATION_VERSIONS: "2.7.1",
            self._MANIFEST_VNFM_INFO: "OSM",
        }
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: fbravo@whitestack.com
+##
+
+"""Python module for interacting with ETSI GS NFV-SOL007 compliant packages
+
+This module provides a SOL007Package class for validating and interacting with
+ETSI SOL007 packages. A valid SOL007 package may have its files arranged according
+to one of the following two structures:
+
+SOL007 with metadata directory SOL007 without metadata directory
+
+native_charm_ns/ native_charm_ns/
+├── TOSCA-Metadata ├── native_charm_nsd.mf
+│ └── TOSCA.meta ├── native_charm_nsd.yaml
+├── manifest.mf ├── ChangeLog.txt
+├── Definitions ├── Licenses
+│ └── native_charm_nsd.yaml │ └── license.lic
+├── Files ├── Files
+│ ├── icons │ └── icons
+│ │ └── osm.png │ └── osm.png
+│ ├── Licenses └── Scripts
+│ │ └── license.lic ├── cloud_init
+│ └── changelog.txt │ └── cloud-config.txt
+└── Scripts └── charms
+ ├── cloud_init └── simple
+ │ └── cloud-config.txt ├── config.yaml
+ └── charms ├── hooks
+ └── simple │ ├── install
+ ├── config.yaml ...
+ ├── hooks │
+ │ ├── install └── src
+ ... └── charm.py
+ └── src
+ └── charm.py
+"""
+
+import datetime
+import os
+
+import yaml
+
+from .sol_package import SOLPackage
+
+
class SOL007PackageException(Exception):
    """Raised for errors specific to SOL007 (NS) package handling."""

    pass
+
+
class SOL007Package(SOLPackage):
    """SOL007 (NS) package: manifest field names and NSD-based metadata."""

    _MANIFEST_NSD_INVARIANT_ID = "nsd_invariant_id"
    _MANIFEST_NSD_NAME = "nsd_name"
    _MANIFEST_NSD_DESIGNER = "nsd_designer"
    _MANIFEST_NSD_FILE_STRUCTURE_VERSION = "nsd_file_structure_version"
    _MANIFEST_NSD_RELEASE_DATE_TIME = "nsd_release_date_time"
    _MANIFEST_NSD_COMPATIBLE_SPECIFICATION_VERSIONS = (
        "compatible_specification_versions"
    )

    _MANIFEST_ALL_FIELDS = [
        _MANIFEST_NSD_INVARIANT_ID,
        _MANIFEST_NSD_NAME,
        _MANIFEST_NSD_DESIGNER,
        _MANIFEST_NSD_FILE_STRUCTURE_VERSION,
        _MANIFEST_NSD_RELEASE_DATE_TIME,
        _MANIFEST_NSD_COMPATIBLE_SPECIFICATION_VERSIONS,
    ]

    def __init__(self, package_path=""):
        super().__init__(package_path)

    def generate_manifest_data_from_descriptor(self):
        """Populate self._manifest_metadata from the package's NSD YAML."""
        descriptor_path = os.path.join(
            self._package_path, self.get_descriptor_location()
        )
        with open(descriptor_path, "r") as descriptor:
            try:
                nsd_data = yaml.safe_load(descriptor)["nsd"]
            except yaml.YAMLError as e:
                print("Error reading descriptor {}: {}".format(descriptor_path, e))
                return

        # Build the whole metadata mapping in one literal, using descriptor
        # values where present and the documented defaults otherwise.
        self._manifest_metadata = {
            self._MANIFEST_NSD_INVARIANT_ID: nsd_data.get("id", "default-id"),
            self._MANIFEST_NSD_NAME: nsd_data.get("name", "default-name"),
            self._MANIFEST_NSD_DESIGNER: nsd_data.get("designer", "OSM"),
            self._MANIFEST_NSD_FILE_STRUCTURE_VERSION: nsd_data.get("version", "1.0"),
            self._MANIFEST_NSD_RELEASE_DATE_TIME: (
                datetime.datetime.now().astimezone().isoformat()
            ),
            self._MANIFEST_NSD_COMPATIBLE_SPECIFICATION_VERSIONS: "2.7.1",
        }
--- /dev/null
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: fbravo@whitestack.com or agarcia@whitestack.com
+##
+import hashlib
+import os
+
+import yaml
+
+
class SOLPackageException(Exception):
    """Raised for errors while parsing or validating a SOL004/SOL007 package."""

    pass
+
+
+class SOLPackage:
+ _METADATA_FILE_PATH = "TOSCA-Metadata/TOSCA.meta"
+ _METADATA_DESCRIPTOR_FIELD = "Entry-Definitions"
+ _METADATA_MANIFEST_FIELD = "ETSI-Entry-Manifest"
+ _METADATA_CHANGELOG_FIELD = "ETSI-Entry-Change-Log"
+ _METADATA_LICENSES_FIELD = "ETSI-Entry-Licenses"
+ _METADATA_DEFAULT_CHANGELOG_PATH = "ChangeLog.txt"
+ _METADATA_DEFAULT_LICENSES_PATH = "Licenses"
+ _MANIFEST_FILE_PATH_FIELD = "Source"
+ _MANIFEST_FILE_HASH_ALGORITHM_FIELD = "Algorithm"
+ _MANIFEST_FILE_HASH_DIGEST_FIELD = "Hash"
+
+ _MANIFEST_ALL_FIELDS = []
+
    def __init__(self, package_path=""):
        """
        :param package_path: root directory of the (already extracted) package
        """
        self._package_path = package_path

        # list of YAML blocks from TOSCA.meta (or a synthesized equivalent)
        self._package_metadata = self._parse_package_metadata()

        # Best-effort: manifest data/metadata are optional at construction
        # time; errors leave the attribute as None instead of failing.
        try:
            self._manifest_data = self._parse_manifest_data()
        except Exception:
            self._manifest_data = None

        try:
            self._manifest_metadata = self._parse_manifest_metadata()
        except Exception:
            self._manifest_metadata = None
+
    def _parse_package_metadata(self):
        """Parse the package metadata.

        Tries the TOSCA-Metadata directory first; if TOSCA.meta is missing,
        falls back to synthesizing metadata from the package root layout.
        """
        try:
            return self._parse_package_metadata_with_metadata_dir()
        except FileNotFoundError:
            return self._parse_package_metadata_without_metadata_dir()
+
+ def _parse_package_metadata_with_metadata_dir(self):
+ try:
+ return self._parse_file_in_blocks(self._METADATA_FILE_PATH)
+ except FileNotFoundError as e:
+ raise e
+ except (Exception, OSError) as e:
+ raise SOLPackageException(
+ "Error parsing {}: {}".format(self._METADATA_FILE_PATH, e)
+ )
+
+ def _parse_package_metadata_without_metadata_dir(self):
+ package_root_files = {f for f in os.listdir(self._package_path)}
+ package_root_yamls = [
+ f for f in package_root_files if f.endswith(".yml") or f.endswith(".yaml")
+ ]
+ if len(package_root_yamls) != 1:
+ error_msg = "Error parsing package metadata: there should be exactly 1 descriptor YAML, found {}"
+ raise SOLPackageException(error_msg.format(len(package_root_yamls)))
+
+ base_manifest = [
+ {
+ SOLPackage._METADATA_DESCRIPTOR_FIELD: package_root_yamls[0],
+ SOLPackage._METADATA_MANIFEST_FIELD: "{}.mf".format(
+ os.path.splitext(package_root_yamls[0])[0]
+ ),
+ SOLPackage._METADATA_CHANGELOG_FIELD: SOLPackage._METADATA_DEFAULT_CHANGELOG_PATH,
+ SOLPackage._METADATA_LICENSES_FIELD: SOLPackage._METADATA_DEFAULT_LICENSES_PATH,
+ }
+ ]
+
+ return base_manifest
+
+ def _parse_manifest_data(self):
+ manifest_path = None
+ for tosca_meta in self._package_metadata:
+ if SOLPackage._METADATA_MANIFEST_FIELD in tosca_meta:
+ manifest_path = tosca_meta[SOLPackage._METADATA_MANIFEST_FIELD]
+ break
+ else:
+ error_msg = "Error parsing {}: no {} field on path".format(
+ self._METADATA_FILE_PATH, self._METADATA_MANIFEST_FIELD
+ )
+ raise SOLPackageException(error_msg)
+
+ try:
+ return self._parse_file_in_blocks(manifest_path)
+
+ except (Exception, OSError) as e:
+ raise SOLPackageException("Error parsing {}: {}".format(manifest_path, e))
+
+ def _parse_manifest_metadata(self):
+ try:
+ base_manifest = {}
+ manifest_file = os.open(
+ os.path.join(
+ self._package_path, base_manifest[self._METADATA_MANIFEST_FIELD]
+ ),
+ "rw",
+ )
+ for line in manifest_file:
+ fields_in_line = line.split(":", maxsplit=1)
+ fields_in_line[0] = fields_in_line[0].strip()
+ fields_in_line[1] = fields_in_line[1].strip()
+ if fields_in_line[0] in self._MANIFEST_ALL_FIELDS:
+ base_manifest[fields_in_line[0]] = fields_in_line[1]
+ return base_manifest
+ except (Exception, OSError) as e:
+ raise SOLPackageException(
+ "Error parsing {}: {}".format(
+ base_manifest[SOLPackage._METADATA_MANIFEST_FIELD], e
+ )
+ )
+
    def _get_package_file_full_path(self, file_relative_path):
        """Return the path of a package file, joined under the package root."""
        return os.path.join(self._package_path, file_relative_path)
+
+ def _parse_file_in_blocks(self, file_relative_path):
+ file_path = self._get_package_file_full_path(file_relative_path)
+ with open(file_path) as f:
+ blocks = f.read().split("\n\n")
+ parsed_blocks = map(yaml.safe_load, blocks)
+ return [block for block in parsed_blocks if block is not None]
+
+ def _get_package_file_manifest_data(self, file_relative_path):
+ for file_data in self._manifest_data:
+ if (
+ file_data.get(SOLPackage._MANIFEST_FILE_PATH_FIELD, "")
+ == file_relative_path
+ ):
+ return file_data
+
+ error_msg = (
+ "Error parsing {} manifest data: file not found on manifest file".format(
+ file_relative_path
+ )
+ )
+ raise SOLPackageException(error_msg)
+
    def get_package_file_hash_digest_from_manifest(self, file_relative_path):
        """Returns the hash digest of a file inside this package as specified on the manifest file.

        :param file_relative_path: path relative to the package root
        :raises SOLPackageException: if the file or its Hash field is missing
        """
        file_manifest_data = self._get_package_file_manifest_data(file_relative_path)
        try:
            return file_manifest_data[SOLPackage._MANIFEST_FILE_HASH_DIGEST_FIELD]
        except Exception as e:
            raise SOLPackageException(
                "Error parsing {} hash digest: {}".format(file_relative_path, e)
            )
+
+ def get_package_file_hash_algorithm_from_manifest(self, file_relative_path):
+ """Returns the hash algorithm of a file inside this package as specified on the manifest file."""
+ file_manifest_data = self._get_package_file_manifest_data(file_relative_path)
+ try:
+ return file_manifest_data[SOLPackage._MANIFEST_FILE_HASH_ALGORITHM_FIELD]
+ except Exception as e:
+ raise SOLPackageException(
+ "Error parsing {} hash digest: {}".format(file_relative_path, e)
+ )
+
+ @staticmethod
+ def _get_hash_function_from_hash_algorithm(hash_algorithm):
+ function_to_algorithm = {"SHA-256": hashlib.sha256, "SHA-512": hashlib.sha512}
+ if hash_algorithm not in function_to_algorithm:
+ error_msg = (
+ "Error checking hash function: hash algorithm {} not supported".format(
+ hash_algorithm
+ )
+ )
+ raise SOLPackageException(error_msg)
+ return function_to_algorithm[hash_algorithm]
+
    def _calculate_file_hash(self, file_relative_path, hash_algorithm):
        """Compute the hex digest of a package file with the given algorithm.

        :param file_relative_path: path relative to the package root
        :param hash_algorithm: "SHA-256" or "SHA-512"
        :raises SOLPackageException: if the file cannot be read or hashed
        """
        file_path = self._get_package_file_full_path(file_relative_path)
        hash_function = self._get_hash_function_from_hash_algorithm(hash_algorithm)
        try:
            # whole file is read into memory; package files are assumed small
            with open(file_path, "rb") as f:
                return hash_function(f.read()).hexdigest()
        except Exception as e:
            raise SOLPackageException(
                "Error hashing {}: {}".format(file_relative_path, e)
            )
+
    def validate_package_file_hash(self, file_relative_path):
        """Validates the integrity of a file using the hash algorithm and digest on the package manifest.

        :param file_relative_path: path relative to the package root
        :raises SOLPackageException: if the computed hash differs from the
            manifest hash, or the manifest entry is missing/invalid
        """
        hash_algorithm = self.get_package_file_hash_algorithm_from_manifest(
            file_relative_path
        )
        file_hash = self._calculate_file_hash(file_relative_path, hash_algorithm)
        expected_file_hash = self.get_package_file_hash_digest_from_manifest(
            file_relative_path
        )
        if file_hash != expected_file_hash:
            error_msg = "Error validating {} hash: calculated hash {} is different than manifest hash {}"
            raise SOLPackageException(
                error_msg.format(file_relative_path, file_hash, expected_file_hash)
            )
+
+    def validate_package_hashes(self):
+        """Validates the integrity of all files listed on the package manifest.
+
+        Raises:
+            SOLPackageException: on the first file whose digest does not match
+                its manifest entry (validation stops at the first failure).
+        """
+        # Manifest entries without a Source/path field (e.g. metadata-only
+        # blocks) are skipped; only file entries are validated.
+        for file_data in self._manifest_data:
+            if SOLPackage._MANIFEST_FILE_PATH_FIELD in file_data:
+                file_relative_path = file_data[SOLPackage._MANIFEST_FILE_PATH_FIELD]
+                self.validate_package_file_hash(file_relative_path)
+
+    def create_or_update_metadata_file(self):
+        """
+        Creates or updates the metadata file with the hashes calculated for each one of the package's files
+        """
+        # If no manifest metadata is loaded yet, derive it from the descriptor
+        # first (generate_manifest_data_from_descriptor is currently a no-op
+        # hook — see its definition below).
+        if not self._manifest_metadata:
+            self.generate_manifest_data_from_descriptor()
+
+        self.write_manifest_data_into_file()
+
+    def generate_manifest_data_from_descriptor(self):
+        """Hook for building manifest metadata from the package descriptor.
+
+        Currently a no-op placeholder; subclasses or future revisions are
+        expected to populate self._manifest_metadata here.
+        """
+        pass
+
+    def write_manifest_data_into_file(self):
+        """Write manifest metadata plus a SHA-512 hash block per package file.
+
+        Overwrites the manifest at get_manifest_location() with, first, all
+        key/value metadata entries and then one
+        "Source/Algorithm/Hash" block for every file found under the package
+        root.
+        """
+        with open(self.get_manifest_location(), "w") as metadata_file:
+            # Write manifest metadata
+            for metadata_entry in self._manifest_metadata:
+                metadata_file.write(
+                    "{}: {}\n".format(
+                        metadata_entry, self._manifest_metadata[metadata_entry]
+                    )
+                )
+
+            # Write package's files hashes
+            # NOTE(review): os.walk over the package root will also visit the
+            # manifest file currently being written, so its (partially
+            # written) content gets hashed too — confirm whether the manifest
+            # should be excluded from its own hash list.
+            file_hashes = {}
+            for root, dirs, files in os.walk(self._package_path):
+                for a_file in files:
+                    file_path = os.path.join(root, a_file)
+                    # Strip the package-root prefix (and any leading "/") to
+                    # obtain a path relative to the package root.
+                    file_relative_path = file_path[len(self._package_path) :]
+                    if file_relative_path.startswith("/"):
+                        file_relative_path = file_relative_path[1:]
+                    file_hashes[file_relative_path] = self._calculate_file_hash(
+                        file_relative_path, "SHA-512"
+                    )
+
+            # NOTE(review): loop variables "file" and "hash" shadow the
+            # built-ins of the same names; harmless here but worth renaming.
+            for file, hash in file_hashes.items():
+                file_block = "Source: {}\nAlgorithm: SHA-512\nHash: {}\n\n".format(
+                    file, hash
+                )
+                metadata_file.write(file_block)
+
+    def get_descriptor_location(self):
+        """Returns this package descriptor location as a relative path from the package root.
+
+        Returns:
+            The descriptor path taken from the first TOSCA.meta entry that
+            contains the descriptor field.
+
+        Raises:
+            SOLPackageException: if no such entry exists in the metadata.
+        """
+        for tosca_meta in self._package_metadata:
+            if SOLPackage._METADATA_DESCRIPTOR_FIELD in tosca_meta:
+                return tosca_meta[SOLPackage._METADATA_DESCRIPTOR_FIELD]
+
+        error_msg = "Error: no {} entry found on {}".format(
+            SOLPackage._METADATA_DESCRIPTOR_FIELD, SOLPackage._METADATA_FILE_PATH
+        )
+        raise SOLPackageException(error_msg)
+
+    def get_manifest_location(self):
+        """Return the VNF/NS manifest location as a relative path from the package root.
+
+        Returns:
+            The manifest path taken from the first TOSCA.meta entry that
+            contains the manifest field.
+
+        Raises:
+            SOLPackageException: if no manifest entry exists in the metadata.
+        """
+        for tosca_meta in self._package_metadata:
+            if SOLPackage._METADATA_MANIFEST_FIELD in tosca_meta:
+                return tosca_meta[SOLPackage._METADATA_MANIFEST_FIELD]
+
+        raise SOLPackageException("No manifest file defined for this package")
# limitations under the License.
[DEFAULT]
-X-Python3-Version : >= 3.5
+X-Python3-Version : >= 3.13
+X-Debian-Depends : python3, python3-pkg-resources
\ No newline at end of file
--- /dev/null
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+vnfd:
+ description: A VNF consisting of 1 VDU connected to two external VL, and one for
+ data and another one for management
+ df:
+ - id: default-df
+ instantiation-level:
+ - id: default-instantiation-level
+ vdu-level:
+ - number-of-instances: 1
+ vdu-id: mgmtVM
+ vdu-profile:
+ - id: mgmtVM
+ min-number-of-instances: 1
+ vdu-configuration-id: mgmtVM-vdu-configuration
+ ext-cpd:
+ - id: vnf-mgmt-ext
+ int-cpd:
+ cpd: mgmtVM-eth0-int
+ vdu-id: mgmtVM
+ - id: vnf-data-ext
+ int-cpd:
+ cpd: dataVM-xe0-int
+ vdu-id: mgmtVM
+ id: native_charm-vnf
+ mgmt-cp: vnf-mgmt-ext
+ product-name: native_charm-vnf
+ sw-image-desc:
+ - id: ubuntu18.04
+ image: ubuntu18.04
+ name: ubuntu18.04
+ vdu:
+ - cloud-init-file: cloud-config.txt
+ id: mgmtVM
+ int-cpd:
+ - id: mgmtVM-eth0-int
+ virtual-network-interface-requirement:
+ - name: mgmtVM-eth0
+ position: 1
+ virtual-interface:
+ type: PARAVIRT
+ - id: dataVM-xe0-int
+ virtual-network-interface-requirement:
+ - name: dataVM-xe0
+ position: 2
+ virtual-interface:
+ type: PARAVIRT
+ name: mgmtVM
+ sw-image-desc: ubuntu18.04
+ virtual-compute-desc: mgmtVM-compute
+ virtual-storage-desc:
+ - mgmtVM-storage
+ vdu-configuration:
+ - config-access:
+ ssh-access:
+ default-user: ubuntu
+ required: true
+ config-primitive:
+ - name: touch
+ parameter:
+ - data-type: STRING
+ default-value: /home/ubuntu/touched
+ name: filename
+ id: mgmtVM-vdu-configuration
+ initial-config-primitive:
+ - name: touch
+ parameter:
+ - data-type: STRING
+ name: filename
+ value: /home/ubuntu/first-touch
+ seq: 1
+ juju:
+ charm: simple
+ proxy: false
+ version: 1.0
+ virtual-compute-desc:
+ - id: mgmtVM-compute
+ virtual-cpu:
+ num-virtual-cpu: 1
+ virtual-memory:
+ size: 1.0
+ virtual-storage-desc:
+ - id: mgmtVM-storage
+ size-of-storage: 10
--- /dev/null
+#!/usr/bin/env python3
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+import subprocess
+import sys
+
+from ops.charm import CharmBase
+from ops.main import main
+from ops.model import ActiveStatus
+
+sys.path.append("lib")
+
+
+class MyNativeCharm(CharmBase):
+ def __init__(self, framework, key):
+ super().__init__(framework, key)
+
+ # Listen to charm events
+ self.framework.observe(self.on.config_changed, self.on_config_changed)
+ self.framework.observe(self.on.install, self.on_install)
+ self.framework.observe(self.on.start, self.on_start)
+
+ # Listen to the touch action event
+ self.framework.observe(self.on.touch_action, self.on_touch_action)
+
+ def on_config_changed(self, event):
+ """Handle changes in configuration"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_install(self, event):
+ """Called when the charm is being installed"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_start(self, event):
+ """Called when the charm is being started"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_touch_action(self, event):
+ """Touch a file."""
+
+ filename = event.params["filename"]
+ try:
+ subprocess.run(["touch", filename], check=True)
+ event.set_results({"created": True, "filename": filename})
+ except subprocess.CalledProcessError as e:
+ event.fail("Action failed: {}".format(e))
+ self.model.unit.status = ActiveStatus()
+
+
+if __name__ == "__main__":
+ main(MyNativeCharm)
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+#cloud-config
+chpasswd: { expire: False }
+ssh_pwauth: True
+
+write_files:
+- content: |
+ # My new helloworld file
+
+ owner: root:root
+ permissions: '0644'
+ path: /root/helloworld.txt
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: agarcia@whitestack.com
+##
+
+TOSCA-Meta-Version: 1.0
+CSAR-Version: 1.0
+Created-By: Diego Armando Maradona
+Entry-Definitions: Definitions/native_charm_vnfd.yaml # Points to the main descriptor of the package
+ETSI-Entry-Manifest: manifest.mf # Points to the ETSI manifest file
+ETSI-Entry-Change-Log: Files/Changelog.txt # Points to package changelog
+ETSI-Entry-Licenses: Files/Licenses # Points to package licenses folder
+
+# In principle, we could add one block per package file to specify MIME types
+Name: Definitions/native_charm_vnfd.yaml # path to file within package
+Content-Type: application/yaml # MIME type of file
+
+Name: Scripts/cloud_init/cloud-config.txt
+Content-Type: application/yaml
+
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: agarcia@whitestack.com
+##
+
+# General definitions of the package
+vnfd_id: native_charm-vnf
+vnf_product_name: native_charm-vnf
+vnf_provider_id: AFA
+vnf_software_version: 1.0
+vnf_package_version: 1.0
+vnf_release_date_time: 2021.12.01T11:36-03:00
+compatible_specification_versions: 3.3.1
+vnfm_info: OSM
+
+Source: Definitions/native_charm_vnfd.yaml
+Algorithm: SHA-256
+Hash: ede8daf9748ac4849e1a1aac955d6c84cafef9ea34067eaef76ee4e5996974c2
+
+Source: Scripts/cloud_init/cloud-config.txt
+Algorithm: SHA-256
+Hash: 7455ca868843cc5da1f0a2255cdedb64a69df3b618c344b83b82848a94540eda
+
+
+# The below sections are all wrong on purpose as they are intended for testing
+
+# Invalid hash algorithm
+Source: Scripts/charms/simple/src/charm.py
+Algorithm: SHA-733
+Hash: ea72f897a966e6174ed9164fabc3c500df5a2f712eb6b22ab2408afb07d04d14
+
+# Wrong hash
+Source: Scripts/charms/simple/hooks/start
+Algorithm: SHA-256
+Hash: 123456aaaaaa123456aaaaaae2bb9d0197f41619165dde6cf205c974f9aa86ae
+
+# Unspecified hash
+Source: Scripts/charms/simple/hooks/upgrade-charm
+Algorithm: SHA-256
--- /dev/null
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+vnfd:
+ description: A VNF consisting of 1 VDU connected to two external VL, and one for
+ data and another one for management
+ df:
+ - id: default-df
+ instantiation-level:
+ - id: default-instantiation-level
+ vdu-level:
+ - number-of-instances: 1
+ vdu-id: mgmtVM
+ vdu-profile:
+ - id: mgmtVM
+ min-number-of-instances: 1
+ vdu-configuration-id: mgmtVM-vdu-configuration
+ ext-cpd:
+ - id: vnf-mgmt-ext
+ int-cpd:
+ cpd: mgmtVM-eth0-int
+ vdu-id: mgmtVM
+ - id: vnf-data-ext
+ int-cpd:
+ cpd: dataVM-xe0-int
+ vdu-id: mgmtVM
+ id: native_charm-vnf
+ mgmt-cp: vnf-mgmt-ext
+ product-name: native_charm-vnf
+ sw-image-desc:
+ - id: ubuntu18.04
+ image: ubuntu18.04
+ name: ubuntu18.04
+ vdu:
+ - cloud-init-file: cloud-config.txt
+ id: mgmtVM
+ int-cpd:
+ - id: mgmtVM-eth0-int
+ virtual-network-interface-requirement:
+ - name: mgmtVM-eth0
+ position: 1
+ virtual-interface:
+ type: PARAVIRT
+ - id: dataVM-xe0-int
+ virtual-network-interface-requirement:
+ - name: dataVM-xe0
+ position: 2
+ virtual-interface:
+ type: PARAVIRT
+ name: mgmtVM
+ sw-image-desc: ubuntu18.04
+ virtual-compute-desc: mgmtVM-compute
+ virtual-storage-desc:
+ - mgmtVM-storage
+ vdu-configuration:
+ - config-access:
+ ssh-access:
+ default-user: ubuntu
+ required: true
+ config-primitive:
+ - name: touch
+ parameter:
+ - data-type: STRING
+ default-value: /home/ubuntu/touched
+ name: filename
+ id: mgmtVM-vdu-configuration
+ initial-config-primitive:
+ - name: touch
+ parameter:
+ - data-type: STRING
+ name: filename
+ value: /home/ubuntu/first-touch
+ seq: 1
+ juju:
+ charm: simple
+ proxy: false
+ version: 1.0
+ virtual-compute-desc:
+ - id: mgmtVM-compute
+ virtual-cpu:
+ num-virtual-cpu: 1
+ virtual-memory:
+ size: 1.0
+ virtual-storage-desc:
+ - id: mgmtVM-storage
+ size-of-storage: 10
--- /dev/null
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+1.0.0: First version
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
--- /dev/null
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+touch:
+ description: "Touch a file on the VNF."
+ params:
+ filename:
+ description: "The name of the file to touch."
+ type: string
+ default: ""
+ required:
+ - filename
--- /dev/null
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+options: {}
\ No newline at end of file
--- /dev/null
+#!/usr/bin/env python3
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+import sys
+import subprocess
+
+sys.path.append("lib")
+
+from ops.charm import CharmBase
+from ops.main import main
+from ops.model import ActiveStatus
+
+class MyNativeCharm(CharmBase):
+
+ def __init__(self, framework, key):
+ super().__init__(framework, key)
+
+ # Listen to charm events
+ self.framework.observe(self.on.config_changed, self.on_config_changed)
+ self.framework.observe(self.on.install, self.on_install)
+ self.framework.observe(self.on.start, self.on_start)
+
+ # Listen to the touch action event
+ self.framework.observe(self.on.touch_action, self.on_touch_action)
+
+ def on_config_changed(self, event):
+ """Handle changes in configuration"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_install(self, event):
+ """Called when the charm is being installed"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_start(self, event):
+ """Called when the charm is being started"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_touch_action(self, event):
+ """Touch a file."""
+
+ filename = event.params["filename"]
+ try:
+ subprocess.run(["touch", filename], check=True)
+ event.set_results({"created": True, "filename": filename})
+ except subprocess.CalledProcessError as e:
+ event.fail("Action failed: {}".format(e))
+ self.model.unit.status = ActiveStatus()
+
+
+if __name__ == "__main__":
+ main(MyNativeCharm)
+
--- /dev/null
+#!/usr/bin/env python3
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+import sys
+import subprocess
+
+sys.path.append("lib")
+
+from ops.charm import CharmBase
+from ops.main import main
+from ops.model import ActiveStatus
+
+class MyNativeCharm(CharmBase):
+
+ def __init__(self, framework, key):
+ super().__init__(framework, key)
+
+ # Listen to charm events
+ self.framework.observe(self.on.config_changed, self.on_config_changed)
+ self.framework.observe(self.on.install, self.on_install)
+ self.framework.observe(self.on.start, self.on_start)
+
+ # Listen to the touch action event
+ self.framework.observe(self.on.touch_action, self.on_touch_action)
+
+ def on_config_changed(self, event):
+ """Handle changes in configuration"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_install(self, event):
+ """Called when the charm is being installed"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_start(self, event):
+ """Called when the charm is being started"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_touch_action(self, event):
+ """Touch a file."""
+
+ filename = event.params["filename"]
+ try:
+ subprocess.run(["touch", filename], check=True)
+ event.set_results({"created": True, "filename": filename})
+ except subprocess.CalledProcessError as e:
+ event.fail("Action failed: {}".format(e))
+ self.model.unit.status = ActiveStatus()
+
+
+if __name__ == "__main__":
+ main(MyNativeCharm)
+
--- /dev/null
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+name: simple-native
+summary: A simple native charm
+description: |
+ Simple native charm
+series:
+ - bionic
+ - xenial
+ - focal
\ No newline at end of file
--- /dev/null
+#!/usr/bin/env python3
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+import subprocess
+import sys
+
+from ops.charm import CharmBase
+from ops.main import main
+from ops.model import ActiveStatus
+
+sys.path.append("lib")
+
+
+class MyNativeCharm(CharmBase):
+ def __init__(self, framework, key):
+ super().__init__(framework, key)
+
+ # Listen to charm events
+ self.framework.observe(self.on.config_changed, self.on_config_changed)
+ self.framework.observe(self.on.install, self.on_install)
+ self.framework.observe(self.on.start, self.on_start)
+
+ # Listen to the touch action event
+ self.framework.observe(self.on.touch_action, self.on_touch_action)
+
+ def on_config_changed(self, event):
+ """Handle changes in configuration"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_install(self, event):
+ """Called when the charm is being installed"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_start(self, event):
+ """Called when the charm is being started"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_touch_action(self, event):
+ """Touch a file."""
+
+ filename = event.params["filename"]
+ try:
+ subprocess.run(["touch", filename], check=True)
+ event.set_results({"created": True, "filename": filename})
+ except subprocess.CalledProcessError as e:
+ event.fail("Action failed: {}".format(e))
+ self.model.unit.status = ActiveStatus()
+
+
+if __name__ == "__main__":
+ main(MyNativeCharm)
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+
+#cloud-config
+chpasswd: { expire: False }
+ssh_pwauth: True
+
+write_files:
+- content: |
+ # My new helloworld file
+
+ owner: root:root
+ permissions: '0644'
+ path: /root/helloworld.txt
--- /dev/null
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: agarcia@whitestack.com
+##
+
+TOSCA-Meta-Version: 1.0
+CSAR-Version: 1.0
+Created-By: Diego Armando Maradona
+Entry-Definitions: Definitions/native_charm_vnfd.yaml # Points to the main descriptor of the package
+ETSI-Entry-Manifest: manifest.mf # Points to the ETSI manifest file
+ETSI-Entry-Change-Log: Files/Changelog.txt # Points to package changelog
+ETSI-Entry-Licenses: Files/Licenses # Points to package licenses folder
+
+# In principle, we could add one block per package file to specify MIME types
+Name: Definitions/native_charm_vnfd.yaml # path to file within package
+Content-Type: application/yaml # MIME type of file
+
+Name: Scripts/cloud_init/cloud-config.txt
+Content-Type: application/yaml
\ No newline at end of file
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: agarcia@whitestack.com
+##
+
+# General definitions of the package
+vnfd_id: native_charm-vnf
+vnf_product_name: native_charm-vnf
+vnf_provider_id: AFA
+vnf_software_version: 1.0
+vnf_package_version: 1.0
+vnf_release_date_time: 2021.12.01T11:36-03:00
+compatible_specification_versions: 3.3.1
+vnfm_info: OSM
+
+# One block for every file in the package
+Source: Definitions/native_charm_vnfd.yaml
+Algorithm: SHA-256
+Hash: ede8daf9748ac4849e1a1aac955d6c84cafef9ea34067eaef76ee4e5996974c2
+
+
+
+Source: Scripts/cloud_init/cloud-config.txt
+Algorithm: SHA-256
+Hash: 0eef3f1a642339e2053af48a7e370dac1952f9cb81166e439e8f72afd6f03621
+
+# Charms files
+
+Source: Scripts/charms/simple/src/charm.py
+Algorithm: SHA-256
+Hash: ea72f897a966e6174ed9164fabc3c500df5a2f712eb6b22ab2408afb07d04d14
+
+Source: Scripts/charms/simple/hooks/start
+Algorithm: SHA-256
+Hash: 312490afd82cc86ad823e4d9e2bb9d0197f41619165dde6cf205c974f9aa86ae
+
+Source: Scripts/charms/simple/hooks/install
+Algorithm: SHA-256
+Hash: 312490afd82cc86ad823e4d9e2bb9d0197f41619165dde6cf205c974f9aa86ae
+
+Source: Scripts/charms/simple/actions.yaml
+Algorithm: SHA-256
+Hash: 988ca2653ae6a3977149faaebd664a12858e0025f226b27d2cee1fa954c9462d
+
+Source: Scripts/charms/simple/metadata.yaml
+Algorithm: SHA-256
+Hash: e00cfaf41a518aef0f486e4ae04a5ae19feffa774abfbdb68379bb5b5b102479
+
+Source: Scripts/charms/simple/config.yaml
+Algorithm: SHA-256
+Hash: f5cbf31b9c299504f3b577417b6c82bde5e3eafd74ee11fdeecf8c8bff6cf3e2
+
+
+# And on and on
--- /dev/null
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+1.0.0: First version
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
--- /dev/null
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+touch:
+ description: "Touch a file on the VNF."
+ params:
+ filename:
+ description: "The name of the file to touch."
+ type: string
+ default: ""
+ required:
+ - filename
--- /dev/null
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+options: {}
\ No newline at end of file
--- /dev/null
+#!/usr/bin/env python3
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+import sys
+import subprocess
+
+sys.path.append("lib")
+
+from ops.charm import CharmBase
+from ops.main import main
+from ops.model import ActiveStatus
+
+class MyNativeCharm(CharmBase):
+
+ def __init__(self, framework, key):
+ super().__init__(framework, key)
+
+ # Listen to charm events
+ self.framework.observe(self.on.config_changed, self.on_config_changed)
+ self.framework.observe(self.on.install, self.on_install)
+ self.framework.observe(self.on.start, self.on_start)
+
+ # Listen to the touch action event
+ self.framework.observe(self.on.touch_action, self.on_touch_action)
+
+ def on_config_changed(self, event):
+ """Handle changes in configuration"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_install(self, event):
+ """Called when the charm is being installed"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_start(self, event):
+ """Called when the charm is being started"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_touch_action(self, event):
+ """Touch a file."""
+
+ filename = event.params["filename"]
+ try:
+ subprocess.run(["touch", filename], check=True)
+ event.set_results({"created": True, "filename": filename})
+ except subprocess.CalledProcessError as e:
+ event.fail("Action failed: {}".format(e))
+ self.model.unit.status = ActiveStatus()
+
+
+if __name__ == "__main__":
+ main(MyNativeCharm)
+
--- /dev/null
+#!/usr/bin/env python3
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+import sys
+import subprocess
+
+sys.path.append("lib")
+
+from ops.charm import CharmBase
+from ops.main import main
+from ops.model import ActiveStatus
+
+class MyNativeCharm(CharmBase):
+
+ def __init__(self, framework, key):
+ super().__init__(framework, key)
+
+ # Listen to charm events
+ self.framework.observe(self.on.config_changed, self.on_config_changed)
+ self.framework.observe(self.on.install, self.on_install)
+ self.framework.observe(self.on.start, self.on_start)
+
+ # Listen to the touch action event
+ self.framework.observe(self.on.touch_action, self.on_touch_action)
+
+ def on_config_changed(self, event):
+ """Handle changes in configuration"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_install(self, event):
+ """Called when the charm is being installed"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_start(self, event):
+ """Called when the charm is being started"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_touch_action(self, event):
+ """Touch a file."""
+
+ filename = event.params["filename"]
+ try:
+ subprocess.run(["touch", filename], check=True)
+ event.set_results({"created": True, "filename": filename})
+ except subprocess.CalledProcessError as e:
+ event.fail("Action failed: {}".format(e))
+ self.model.unit.status = ActiveStatus()
+
+
+if __name__ == "__main__":
+ main(MyNativeCharm)
+
--- /dev/null
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+name: simple-native
+summary: A simple native charm
+description: |
+ Simple native charm
+series:
+ - bionic
+ - xenial
+ - focal
\ No newline at end of file
--- /dev/null
+#!/usr/bin/env python3
+##
+# Copyright 2020 Canonical Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+##
+
+import subprocess
+import sys
+
+from ops.charm import CharmBase
+from ops.main import main
+from ops.model import ActiveStatus
+
+sys.path.append("lib")
+
+
+class MyNativeCharm(CharmBase):
+ def __init__(self, framework, key):
+ super().__init__(framework, key)
+
+ # Listen to charm events
+ self.framework.observe(self.on.config_changed, self.on_config_changed)
+ self.framework.observe(self.on.install, self.on_install)
+ self.framework.observe(self.on.start, self.on_start)
+
+ # Listen to the touch action event
+ self.framework.observe(self.on.touch_action, self.on_touch_action)
+
+ def on_config_changed(self, event):
+ """Handle changes in configuration"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_install(self, event):
+ """Called when the charm is being installed"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_start(self, event):
+ """Called when the charm is being started"""
+ self.model.unit.status = ActiveStatus()
+
+ def on_touch_action(self, event):
+ """Touch a file."""
+
+ filename = event.params["filename"]
+ try:
+ subprocess.run(["touch", filename], check=True)
+ event.set_results({"created": True, "filename": filename})
+ except subprocess.CalledProcessError as e:
+ event.fail("Action failed: {}".format(e))
+ self.model.unit.status = ActiveStatus()
+
+
+if __name__ == "__main__":
+ main(MyNativeCharm)
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+
+#cloud-config
+chpasswd: { expire: False }
+ssh_pwauth: True
+
+write_files:
+- content: |
+ # My new helloworld file
+
+ owner: root:root
+ permissions: '0644'
+ path: /root/helloworld.txt
--- /dev/null
+#
+# Copyright 2020 Whitestack, LLC
+# *************************************************************
+#
+# This file is part of OSM common repository.
+# All Rights Reserved to Whitestack, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: agarcia@whitestack.com
+##
+
+# General definitions of the package
+vnfd_id: native_charm-vnf
+vnf_product_name: native_charm-vnf
+vnf_provider_id: AFA
+vnf_software_version: 1.0
+vnf_package_version: 1.0
+vnf_release_date_time: 2021-12-01T11:36-03:00
+compatible_specification_versions: 3.3.1
+vnfm_info: OSM
+
+# One block for every file in the package
+Source: native_charm_vnfd.yaml
+Algorithm: SHA-256
+Hash: ae06780c082041676df4ca4130ef223548eee6389007ba259416f59044450a7c
+
+
+
+Source: Scripts/cloud_init/cloud-config.txt
+Algorithm: SHA-256
+Hash: 0eef3f1a642339e2053af48a7e370dac1952f9cb81166e439e8f72afd6f03621
+
+# Charm files
+
+Source: Scripts/charms/simple/src/charm.py
+Algorithm: SHA-256
+Hash: ea72f897a966e6174ed9164fabc3c500df5a2f712eb6b22ab2408afb07d04d14
+
+Source: Scripts/charms/simple/hooks/start
+Algorithm: SHA-256
+Hash: 312490afd82cc86ad823e4d9e2bb9d0197f41619165dde6cf205c974f9aa86ae
+
+Source: Scripts/charms/simple/hooks/install
+Algorithm: SHA-256
+Hash: 312490afd82cc86ad823e4d9e2bb9d0197f41619165dde6cf205c974f9aa86ae
+
+Source: Scripts/charms/simple/actions.yaml
+Algorithm: SHA-256
+Hash: 988ca2653ae6a3977149faaebd664a12858e0025f226b27d2cee1fa954c9462d
+
+Source: Scripts/charms/simple/metadata.yaml
+Algorithm: SHA-256
+Hash: e00cfaf41a518aef0f486e4ae04a5ae19feffa774abfbdb68379bb5b5b102479
+
+Source: Scripts/charms/simple/config.yaml
+Algorithm: SHA-256
+Hash: f5cbf31b9c299504f3b577417b6c82bde5e3eafd74ee11fdeecf8c8bff6cf3e2
+
+
+# Remaining package files follow the same Source/Algorithm/Hash block pattern
--- /dev/null
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+metadata:
+ template_name: native_charm-vnf
+ template_author: AFA
+ template_version: 1.1
+
+vnfd:
+ description: A VNF consisting of 1 VDU connected to two external VL, and one for
+ data and another one for management
+ df:
+ - id: default-df
+ instantiation-level:
+ - id: default-instantiation-level
+ vdu-level:
+ - number-of-instances: 1
+ vdu-id: mgmtVM
+ vdu-profile:
+ - id: mgmtVM
+ min-number-of-instances: 1
+ vdu-configuration-id: mgmtVM-vdu-configuration
+ ext-cpd:
+ - id: vnf-mgmt-ext
+ int-cpd:
+ cpd: mgmtVM-eth0-int
+ vdu-id: mgmtVM
+ - id: vnf-data-ext
+ int-cpd:
+ cpd: dataVM-xe0-int
+ vdu-id: mgmtVM
+ id: native_charm-vnf
+ mgmt-cp: vnf-mgmt-ext
+ product-name: native_charm-vnf
+ provider: AFA
+ sw-image-desc:
+ - id: ubuntu18.04
+ image: ubuntu18.04
+ name: ubuntu18.04
+ vdu:
+ - cloud-init-file: cloud-config.txt
+ id: mgmtVM
+ int-cpd:
+ - id: mgmtVM-eth0-int
+ virtual-network-interface-requirement:
+ - name: mgmtVM-eth0
+ position: 1
+ virtual-interface:
+ type: PARAVIRT
+ - id: dataVM-xe0-int
+ virtual-network-interface-requirement:
+ - name: dataVM-xe0
+ position: 2
+ virtual-interface:
+ type: PARAVIRT
+ name: mgmtVM
+ sw-image-desc: ubuntu18.04
+ virtual-compute-desc: mgmtVM-compute
+ virtual-storage-desc:
+ - mgmtVM-storage
+ vdu-configuration:
+ - config-access:
+ ssh-access:
+ default-user: ubuntu
+ required: true
+ config-primitive:
+ - name: touch
+ parameter:
+ - data-type: STRING
+ default-value: /home/ubuntu/touched
+ name: filename
+ id: mgmtVM-vdu-configuration
+ initial-config-primitive:
+ - name: touch
+ parameter:
+ - data-type: STRING
+ name: filename
+ value: /home/ubuntu/first-touch
+ seq: 1
+ juju:
+ charm: simple
+ proxy: false
+ version: 1.0
+ virtual-compute-desc:
+ - id: mgmtVM-compute
+ virtual-cpu:
+ num-virtual-cpu: 1
+ virtual-memory:
+ size: 1.0
+ virtual-storage-desc:
+ - id: mgmtVM-storage
+ size-of-storage: 10
--- /dev/null
+# Copyright 2018 Whitestack, LLC
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
+##
+import asyncio
+import copy
+from copy import deepcopy
+import http
+from http import HTTPStatus
+import logging
+from os import urandom
+import unittest
+from unittest.mock import MagicMock, Mock, patch
+
+from Crypto.Cipher import AES
+from osm_common.dbbase import DbBase, DbException, deep_update, Encryption
+import pytest
+
+
+# Variables used in TestBaseEncryption and TestAsyncEncryption
+salt = "1afd5d1a-4a7e-4d9c-8c65-251290183106"
+value = "private key txt"
+padded_value = b"private key txt\0"
+padded_encoded_value = b"private key txt\x00"
+encoding_type = "ascii"
+encyrpt_mode = AES.MODE_ECB
+secret_key = b"\xeev\xc2\xb8\xb2#;Ek\xd0\xb5['\x04\xed\x1f\xb9?\xc5Ig\x80\xd5\x8d\x8aT\xd7\xf8Q\xe2u!"
+encyrpted_value = "ZW5jcnlwdGVkIGRhdGE="
+encyrpted_bytes = b"ZW5jcnlwdGVkIGRhdGE="
+data_to_b4_encode = b"encrypted data"
+b64_decoded = b"decrypted data"
+schema_version = "1.1"
+joined_key = b"\x9d\x17\xaf\xc8\xdeF\x1b.\x0e\xa9\xb5['\x04\xed\x1f\xb9?\xc5Ig\x80\xd5\x8d\x8aT\xd7\xf8Q\xe2u!"
+serial_bytes = b"\xf8\x96Z\x1c:}\xb5\xdf\x94\x8d\x0f\x807\xe6)\x8f\xf5!\xee}\xc2\xfa\xb3\t\xb9\xe4\r7\x19\x08\xa5b"
+base64_decoded_serial = b"g\xbe\xdb"
+decrypted_val1 = "BiV9YZEuSRAudqvz7Gs+bg=="
+decrypted_val2 = "q4LwnFdoryzbZJM5mCAnpA=="
+item = {
+ "secret": "mysecret",
+ "cacert": "mycacert",
+ "path": "/var",
+ "ip": "192.168.12.23",
+}
+
+
+def exception_message(message):
+ return "database exception " + message
+
+
+@pytest.fixture
+def db_base():
+ return DbBase()
+
+
+def test_constructor():
+ db_base = DbBase()
+ assert db_base is not None
+ assert isinstance(db_base, DbBase)
+
+
+def test_db_connect(db_base):
+ with pytest.raises(DbException) as excinfo:
+ db_base.db_connect(None)
+ assert str(excinfo.value).startswith(
+ exception_message("Method 'db_connect' not implemented")
+ )
+
+
+def test_db_disconnect(db_base):
+ db_base.db_disconnect()
+
+
+def test_get_list(db_base):
+ with pytest.raises(DbException) as excinfo:
+ db_base.get_list(None, None)
+ assert str(excinfo.value).startswith(
+ exception_message("Method 'get_list' not implemented")
+ )
+ assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+def test_get_one(db_base):
+ with pytest.raises(DbException) as excinfo:
+ db_base.get_one(None, None, None, None)
+ assert str(excinfo.value).startswith(
+ exception_message("Method 'get_one' not implemented")
+ )
+ assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+def test_create(db_base):
+ with pytest.raises(DbException) as excinfo:
+ db_base.create(None, None)
+ assert str(excinfo.value).startswith(
+ exception_message("Method 'create' not implemented")
+ )
+ assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+def test_create_list(db_base):
+ with pytest.raises(DbException) as excinfo:
+ db_base.create_list(None, None)
+ assert str(excinfo.value).startswith(
+ exception_message("Method 'create_list' not implemented")
+ )
+ assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+def test_del_list(db_base):
+ with pytest.raises(DbException) as excinfo:
+ db_base.del_list(None, None)
+ assert str(excinfo.value).startswith(
+ exception_message("Method 'del_list' not implemented")
+ )
+ assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+def test_del_one(db_base):
+ with pytest.raises(DbException) as excinfo:
+ db_base.del_one(None, None, None)
+ assert str(excinfo.value).startswith(
+ exception_message("Method 'del_one' not implemented")
+ )
+ assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+class TestEncryption(unittest.TestCase):
+ def setUp(self):
+ master_key = "Setting a long master key with numbers 123 and capitals AGHBNHD and symbols %&8)!'"
+ db_base1 = DbBase()
+ db_base2 = DbBase()
+ db_base3 = DbBase()
+ # set self.secret_key obtained when connect
+ db_base1.set_secret_key(master_key, replace=True)
+ db_base1.set_secret_key(urandom(32))
+ db_base2.set_secret_key(None, replace=True)
+ db_base2.set_secret_key(urandom(30))
+ db_base3.set_secret_key(master_key)
+ self.db_bases = [db_base1, db_base2, db_base3]
+
+ def test_encrypt_decrypt(self):
+ TEST = (
+ ("plain text 1 ! ", None),
+ ("plain text 2 with salt ! ", "1afd5d1a-4a7e-4d9c-8c65-251290183106"),
+ )
+ for db_base in self.db_bases:
+ for value, salt in TEST:
+ # no encryption
+ encrypted = db_base.encrypt(value, schema_version="1.0", salt=salt)
+ self.assertEqual(
+ encrypted, value, "value '{}' has been encrypted".format(value)
+ )
+ decrypted = db_base.decrypt(encrypted, schema_version="1.0", salt=salt)
+ self.assertEqual(
+ decrypted, value, "value '{}' has been decrypted".format(value)
+ )
+
+ # encrypt/decrypt
+ encrypted = db_base.encrypt(value, schema_version="1.1", salt=salt)
+ self.assertNotEqual(
+ encrypted, value, "value '{}' has not been encrypted".format(value)
+ )
+ self.assertIsInstance(encrypted, str, "Encrypted is not ascii text")
+ decrypted = db_base.decrypt(encrypted, schema_version="1.1", salt=salt)
+ self.assertEqual(
+ decrypted, value, "value is not equal after encryption/decryption"
+ )
+
+ def test_encrypt_decrypt_salt(self):
+ value = "value to be encrypted!"
+ encrypted = []
+ for db_base in self.db_bases:
+ for salt in (None, "salt 1", "1afd5d1a-4a7e-4d9c-8c65-251290183106"):
+ # encrypt/decrypt
+ encrypted.append(
+ db_base.encrypt(value, schema_version="1.1", salt=salt)
+ )
+ self.assertNotEqual(
+ encrypted[-1],
+ value,
+ "value '{}' has not been encrypted".format(value),
+ )
+ self.assertIsInstance(encrypted[-1], str, "Encrypted is not ascii text")
+ decrypted = db_base.decrypt(
+ encrypted[-1], schema_version="1.1", salt=salt
+ )
+ self.assertEqual(
+ decrypted, value, "value is not equal after encryption/decryption"
+ )
+ for i in range(0, len(encrypted)):
+ for j in range(i + 1, len(encrypted)):
+ self.assertNotEqual(
+ encrypted[i],
+ encrypted[j],
+ "encryption with different salt must contain different result",
+ )
+ # decrypt with a different master key
+ try:
+ decrypted = self.db_bases[-1].decrypt(
+ encrypted[0], schema_version="1.1", salt=None
+ )
+ self.assertNotEqual(
+ encrypted[0],
+ decrypted,
+ "Decryption with different KEY must generate different result",
+ )
+ except DbException as e:
+ self.assertEqual(
+ e.http_code,
+ HTTPStatus.INTERNAL_SERVER_ERROR,
+ "Decryption with different KEY does not provide expected http_code",
+ )
+
+
+class AsyncMock(MagicMock):
+ async def __call__(self, *args, **kwargs):
+ args = deepcopy(args)
+ kwargs = deepcopy(kwargs)
+ return super(AsyncMock, self).__call__(*args, **kwargs)
+
+
+class CopyingMock(MagicMock):
+ def __call__(self, *args, **kwargs):
+ args = deepcopy(args)
+ kwargs = deepcopy(kwargs)
+ return super(CopyingMock, self).__call__(*args, **kwargs)
+
+
+def check_if_assert_not_called(mocks: list):
+ for mocking in mocks:
+ mocking.assert_not_called()
+
+
+class TestBaseEncryption(unittest.TestCase):
+ @patch("logging.getLogger", autospec=True)
+ def setUp(self, mock_logger):
+ mock_logger = logging.getLogger()
+ mock_logger.disabled = True
+ self.db_base = DbBase()
+ self.mock_cipher = CopyingMock()
+ self.db_base.encoding_type = encoding_type
+ self.db_base.encrypt_mode = encyrpt_mode
+ self.db_base.secret_key = secret_key
+ self.mock_padded_msg = CopyingMock()
+
+ def test_pad_data_len_not_multiplication_of_16(self):
+ data = "hello word hello hello word hello word"
+ data_len = len(data)
+ expected_len = 48
+ padded = self.db_base.pad_data(data)
+ self.assertEqual(len(padded), expected_len)
+ self.assertTrue("\0" * (expected_len - data_len) in padded)
+
+ def test_pad_data_len_multiplication_of_16(self):
+ data = "hello word!!!!!!"
+ padded = self.db_base.pad_data(data)
+ self.assertEqual(padded, data)
+ self.assertFalse("\0" in padded)
+
+ def test_pad_data_empty_string(self):
+ data = ""
+ expected_len = 0
+ padded = self.db_base.pad_data(data)
+ self.assertEqual(len(padded), expected_len)
+ self.assertFalse("\0" in padded)
+
+ def test_pad_data_not_string(self):
+ data = None
+ with self.assertRaises(Exception) as err:
+ self.db_base.pad_data(data)
+ self.assertEqual(
+ str(err.exception),
+ "database exception Incorrect data type: type(None), string is expected.",
+ )
+
+ def test_unpad_data_null_char_at_right(self):
+ null_padded_data = "hell0word\0\0"
+ expected_length = len(null_padded_data) - 2
+ unpadded = self.db_base.unpad_data(null_padded_data)
+ self.assertEqual(len(unpadded), expected_length)
+ self.assertFalse("\0" in unpadded)
+ self.assertTrue("0" in unpadded)
+
+ def test_unpad_data_null_char_is_not_rightest(self):
+ null_padded_data = "hell0word\r\t\0\n"
+ expected_length = len(null_padded_data)
+ unpadded = self.db_base.unpad_data(null_padded_data)
+ self.assertEqual(len(unpadded), expected_length)
+ self.assertTrue("\0" in unpadded)
+
+ def test_unpad_data_with_spaces_at_right(self):
+ null_padded_data = " hell0word\0 "
+ expected_length = len(null_padded_data)
+ unpadded = self.db_base.unpad_data(null_padded_data)
+ self.assertEqual(len(unpadded), expected_length)
+ self.assertTrue("\0" in unpadded)
+
+ def test_unpad_data_empty_string(self):
+ data = ""
+ unpadded = self.db_base.unpad_data(data)
+ self.assertEqual(unpadded, "")
+ self.assertFalse("\0" in unpadded)
+
+ def test_unpad_data_not_string(self):
+ data = None
+ with self.assertRaises(Exception) as err:
+ self.db_base.unpad_data(data)
+ self.assertEqual(
+ str(err.exception),
+ "database exception Incorrect data type: type(None), string is expected.",
+ )
+
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_schema_version_1_0_none_secret_key_none_salt(
+ self, mock_pad_data, mock_join_secret_key
+ ):
+ """schema_version 1.0, secret_key is None and salt is None."""
+ schema_version = "1.0"
+ salt = None
+ self.db_base.secret_key = None
+ result = self.db_base._encrypt_value(value, schema_version, salt)
+ self.assertEqual(result, value)
+ check_if_assert_not_called([mock_pad_data, mock_join_secret_key])
+
+ @patch("osm_common.dbbase.b64encode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_schema_version_1_1_with_secret_key_exists_with_salt(
+ self,
+ mock_pad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_encode,
+ ):
+ """schema_version 1.1, secret_key exists, salt exists."""
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.encrypt.return_value = data_to_b4_encode
+ self.mock_padded_msg.return_value = padded_value
+ mock_pad_data.return_value = self.mock_padded_msg
+ self.mock_padded_msg.encode.return_value = padded_encoded_value
+
+ mock_b64_encode.return_value = encyrpted_bytes
+
+ result = self.db_base._encrypt_value(value, schema_version, salt)
+
+ self.assertTrue(isinstance(result, str))
+ self.assertEqual(result, encyrpted_value)
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_pad_data.assert_called_once_with(value)
+ mock_b64_encode.assert_called_once_with(data_to_b4_encode)
+ self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value)
+ self.mock_padded_msg.encode.assert_called_with(encoding_type)
+
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_schema_version_1_0_secret_key_not_exists(
+ self, mock_pad_data, mock_join_secret_key
+ ):
+ """schema_version 1.0, secret_key is None, salt exists."""
+ schema_version = "1.0"
+ self.db_base.secret_key = None
+ result = self.db_base._encrypt_value(value, schema_version, salt)
+ self.assertEqual(result, value)
+ check_if_assert_not_called([mock_pad_data, mock_join_secret_key])
+
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_schema_version_1_1_secret_key_not_exists(
+ self, mock_pad_data, mock_join_secret_key
+ ):
+ """schema_version 1.1, secret_key is None, salt exists."""
+ self.db_base.secret_key = None
+ result = self.db_base._encrypt_value(value, schema_version, salt)
+ self.assertEqual(result, value)
+ check_if_assert_not_called([mock_pad_data, mock_join_secret_key])
+
+ @patch("osm_common.dbbase.b64encode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_schema_version_1_1_secret_key_exists_without_salt(
+ self,
+ mock_pad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_encode,
+ ):
+ """schema_version 1.1, secret_key exists, salt is None."""
+ salt = None
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.encrypt.return_value = data_to_b4_encode
+
+ self.mock_padded_msg.return_value = padded_value
+ mock_pad_data.return_value = self.mock_padded_msg
+ self.mock_padded_msg.encode.return_value = padded_encoded_value
+
+ mock_b64_encode.return_value = encyrpted_bytes
+
+ result = self.db_base._encrypt_value(value, schema_version, salt)
+
+ self.assertEqual(result, encyrpted_value)
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_pad_data.assert_called_once_with(value)
+ mock_b64_encode.assert_called_once_with(data_to_b4_encode)
+ self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value)
+ self.mock_padded_msg.encode.assert_called_with(encoding_type)
+
+ @patch("osm_common.dbbase.b64encode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_invalid_encrpt_mode(
+ self,
+ mock_pad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_encode,
+ ):
+ """encrypt_mode is invalid."""
+ mock_aes.new.side_effect = Exception("Invalid ciphering mode.")
+ self.db_base.encrypt_mode = "AES.MODE_XXX"
+
+ with self.assertRaises(Exception) as err:
+ self.db_base._encrypt_value(value, schema_version, salt)
+
+ self.assertEqual(str(err.exception), "Invalid ciphering mode.")
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], "AES.MODE_XXX")
+ check_if_assert_not_called([mock_pad_data, mock_b64_encode])
+
+ @patch("osm_common.dbbase.b64encode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_schema_version_1_1_secret_key_exists_value_none(
+ self,
+ mock_pad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_encode,
+ ):
+ """schema_version 1.1, secret_key exists, value is None."""
+ value = None
+ mock_aes.new.return_value = self.mock_cipher
+ mock_pad_data.side_effect = DbException(
+ "Incorrect data type: type(None), string is expected."
+ )
+
+ with self.assertRaises(Exception) as err:
+ self.db_base._encrypt_value(value, schema_version, salt)
+ self.assertEqual(
+ str(err.exception),
+ "database exception Incorrect data type: type(None), string is expected.",
+ )
+
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_pad_data.assert_called_once_with(value)
+ check_if_assert_not_called(
+ [mock_b64_encode, self.mock_cipher.encrypt, mock_b64_encode]
+ )
+
+ @patch("osm_common.dbbase.b64encode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_join_secret_key_raises(
+ self,
+ mock_pad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_encode,
+ ):
+ """Method join_secret_key raises DbException."""
+ salt = b"3434o34-3wewrwr-222424-2242dwew"
+
+ mock_join_secret_key.side_effect = DbException("Unexpected type")
+
+ mock_aes.new.return_value = self.mock_cipher
+
+ with self.assertRaises(Exception) as err:
+ self.db_base._encrypt_value(value, schema_version, salt)
+
+ self.assertEqual(str(err.exception), "database exception Unexpected type")
+ check_if_assert_not_called(
+ [mock_pad_data, mock_aes.new, mock_b64_encode, self.mock_cipher.encrypt]
+ )
+ mock_join_secret_key.assert_called_once_with(salt)
+
+ @patch("osm_common.dbbase.b64encode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_schema_version_1_1_secret_key_exists_b64_encode_raises(
+ self,
+ mock_pad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_encode,
+ ):
+ """schema_version 1.1, secret_key exists, b64encode raises TypeError."""
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.encrypt.return_value = "encrypted data"
+
+ self.mock_padded_msg.return_value = padded_value
+ mock_pad_data.return_value = self.mock_padded_msg
+ self.mock_padded_msg.encode.return_value = padded_encoded_value
+
+ mock_b64_encode.side_effect = TypeError(
+ "A bytes-like object is required, not 'str'"
+ )
+
+ with self.assertRaises(Exception) as error:
+ self.db_base._encrypt_value(value, schema_version, salt)
+ self.assertEqual(
+ str(error.exception), "A bytes-like object is required, not 'str'"
+ )
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_pad_data.assert_called_once_with(value)
+ self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value)
+ self.mock_padded_msg.encode.assert_called_with(encoding_type)
+ mock_b64_encode.assert_called_once_with("encrypted data")
+
+ @patch("osm_common.dbbase.b64encode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "pad_data")
+ def test__encrypt_value_cipher_encrypt_raises(
+ self,
+ mock_pad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_encode,
+ ):
+ """AES encrypt method raises Exception."""
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.encrypt.side_effect = Exception("Invalid data type.")
+
+ self.mock_padded_msg.return_value = padded_value
+ mock_pad_data.return_value = self.mock_padded_msg
+ self.mock_padded_msg.encode.return_value = padded_encoded_value
+
+ with self.assertRaises(Exception) as error:
+ self.db_base._encrypt_value(value, schema_version, salt)
+
+ self.assertEqual(str(error.exception), "Invalid data type.")
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_pad_data.assert_called_once_with(value)
+ self.mock_cipher.encrypt.assert_called_once_with(padded_encoded_value)
+ self.mock_padded_msg.encode.assert_called_with(encoding_type)
+ mock_b64_encode.assert_not_called()
+
+ @patch.object(DbBase, "get_secret_key")
+ @patch.object(DbBase, "_encrypt_value")
+ def test_encrypt_without_schema_version_without_salt(
+ self, mock_encrypt_value, mock_get_secret_key
+ ):
+ """schema and salt is None."""
+ mock_encrypt_value.return_value = encyrpted_value
+ result = self.db_base.encrypt(value)
+ mock_encrypt_value.assert_called_once_with(value, None, None)
+ mock_get_secret_key.assert_called_once()
+ self.assertEqual(result, encyrpted_value)
+
+ @patch.object(DbBase, "get_secret_key")
+ @patch.object(DbBase, "_encrypt_value")
+ def test_encrypt_with_schema_version_with_salt(
+ self, mock_encrypt_value, mock_get_secret_key
+ ):
+ """schema version exists, salt is None."""
+ mock_encrypt_value.return_value = encyrpted_value
+ result = self.db_base.encrypt(value, schema_version, salt)
+ mock_encrypt_value.assert_called_once_with(value, schema_version, salt)
+ mock_get_secret_key.assert_called_once()
+ self.assertEqual(result, encyrpted_value)
+
+ @patch.object(DbBase, "get_secret_key")
+ @patch.object(DbBase, "_encrypt_value")
+ def test_encrypt_get_secret_key_raises(
+ self, mock_encrypt_value, mock_get_secret_key
+ ):
+ """get_secret_key method raises DbException."""
+ mock_get_secret_key.side_effect = DbException("KeyError")
+ with self.assertRaises(Exception) as error:
+ self.db_base.encrypt(value)
+ self.assertEqual(str(error.exception), "database exception KeyError")
+ mock_encrypt_value.assert_not_called()
+ mock_get_secret_key.assert_called_once()
+
+ @patch.object(DbBase, "get_secret_key")
+ @patch.object(DbBase, "_encrypt_value")
+ def test_encrypt_encrypt_raises(self, mock_encrypt_value, mock_get_secret_key):
+ """Method _encrypt_value raises DbException."""
+ mock_encrypt_value.side_effect = DbException(
+ "Incorrect data type: type(None), string is expected."
+ )
+ with self.assertRaises(Exception) as error:
+ self.db_base.encrypt(value, schema_version, salt)
+ self.assertEqual(
+ str(error.exception),
+ "database exception Incorrect data type: type(None), string is expected.",
+ )
+ mock_encrypt_value.assert_called_once_with(value, schema_version, salt)
+ mock_get_secret_key.assert_called_once()
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_schema_version_1_1_secret_key_exists_without_salt(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """schema_version 1.1, secret_key exists, salt is None."""
+ salt = None
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.decrypt.return_value = padded_encoded_value
+
+ mock_b64_decode.return_value = b64_decoded
+
+ mock_unpad_data.return_value = value
+
+ result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+ self.assertEqual(result, value)
+
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_unpad_data.assert_called_once_with("private key txt\0")
+ mock_b64_decode.assert_called_once_with(encyrpted_value)
+ self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_schema_version_1_1_secret_key_exists_with_salt(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """schema_version 1.1, secret_key exists, salt exists."""
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.decrypt.return_value = padded_encoded_value
+
+ mock_b64_decode.return_value = b64_decoded
+
+ mock_unpad_data.return_value = value
+
+ result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+ self.assertEqual(result, value)
+
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_unpad_data.assert_called_once_with("private key txt\0")
+ mock_b64_decode.assert_called_once_with(encyrpted_value)
+ self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_schema_version_1_1_without_secret_key(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """schema_version 1.1, secret_key is None, salt exists."""
+ self.db_base.secret_key = None
+
+ result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+
+ self.assertEqual(result, encyrpted_value)
+ check_if_assert_not_called(
+ [
+ mock_join_secret_key,
+ mock_aes.new,
+ mock_unpad_data,
+ mock_b64_decode,
+ self.mock_cipher.decrypt,
+ ]
+ )
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_schema_version_1_0_with_secret_key(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """schema_version 1.0, secret_key exists, salt exists."""
+ schema_version = "1.0"
+ result = self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+
+ self.assertEqual(result, encyrpted_value)
+ check_if_assert_not_called(
+ [
+ mock_join_secret_key,
+ mock_aes.new,
+ mock_unpad_data,
+ mock_b64_decode,
+ self.mock_cipher.decrypt,
+ ]
+ )
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_join_secret_key_raises(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """_join_secret_key raises TypeError."""
+ salt = object()
+ mock_join_secret_key.side_effect = TypeError("'type' object is not iterable")
+
+ with self.assertRaises(Exception) as error:
+ self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+ self.assertEqual(str(error.exception), "'type' object is not iterable")
+
+ mock_join_secret_key.assert_called_once_with(salt)
+ check_if_assert_not_called(
+ [mock_aes.new, mock_unpad_data, mock_b64_decode, self.mock_cipher.decrypt]
+ )
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_b64decode_raises(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """b64decode raises TypeError."""
+ mock_b64_decode.side_effect = TypeError(
+ "A str-like object is required, not 'bytes'"
+ )
+ with self.assertRaises(Exception) as error:
+ self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+ self.assertEqual(
+ str(error.exception), "A str-like object is required, not 'bytes'"
+ )
+
+ mock_b64_decode.assert_called_once_with(encyrpted_value)
+ mock_join_secret_key.assert_called_once_with(salt)
+ check_if_assert_not_called(
+ [mock_aes.new, self.mock_cipher.decrypt, mock_unpad_data]
+ )
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_invalid_encrypt_mode(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """Invalid AES encrypt mode."""
+ mock_aes.new.side_effect = Exception("Invalid ciphering mode.")
+ self.db_base.encrypt_mode = "AES.MODE_XXX"
+
+ mock_b64_decode.return_value = b64_decoded
+ with self.assertRaises(Exception) as error:
+ self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+
+ self.assertEqual(str(error.exception), "Invalid ciphering mode.")
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], "AES.MODE_XXX")
+ mock_b64_decode.assert_called_once_with(encyrpted_value)
+ check_if_assert_not_called([mock_unpad_data, self.mock_cipher.decrypt])
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_cipher_decrypt_raises(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """AES decrypt raises Exception."""
+ mock_b64_decode.return_value = b64_decoded
+
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.decrypt.side_effect = Exception("Invalid data type.")
+
+ with self.assertRaises(Exception) as error:
+ self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+ self.assertEqual(str(error.exception), "Invalid data type.")
+
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_b64_decode.assert_called_once_with(encyrpted_value)
+ self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
+ mock_unpad_data.assert_not_called()
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_decode_raises(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """Decode raises UnicodeDecodeError."""
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.decrypt.return_value = b"\xd0\x000091"
+
+ mock_b64_decode.return_value = b64_decoded
+
+ mock_unpad_data.return_value = value
+ with self.assertRaises(Exception) as error:
+ self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+ self.assertEqual(
+ str(error.exception),
+ "database exception Cannot decrypt information. Are you using same COMMONKEY in all OSM components?",
+ )
+ self.assertEqual(type(error.exception), DbException)
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_b64_decode.assert_called_once_with(encyrpted_value)
+ self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
+ mock_unpad_data.assert_not_called()
+
+ @patch("osm_common.dbbase.b64decode")
+ @patch("osm_common.dbbase.AES")
+ @patch.object(DbBase, "_join_secret_key")
+ @patch.object(DbBase, "unpad_data")
+ def test__decrypt_value_unpad_data_raises(
+ self,
+ mock_unpad_data,
+ mock_join_secret_key,
+ mock_aes,
+ mock_b64_decode,
+ ):
+ """Method unpad_data raises error."""
+ mock_decrypted_message = MagicMock()
+ mock_decrypted_message.decode.return_value = None
+ mock_aes.new.return_value = self.mock_cipher
+ self.mock_cipher.decrypt.return_value = mock_decrypted_message
+ mock_unpad_data.side_effect = DbException(
+ "Incorrect data type: type(None), string is expected."
+ )
+ mock_b64_decode.return_value = b64_decoded
+
+ with self.assertRaises(Exception) as error:
+ self.db_base._decrypt_value(encyrpted_value, schema_version, salt)
+ self.assertEqual(
+ str(error.exception),
+ "database exception Incorrect data type: type(None), string is expected.",
+ )
+ self.assertEqual(type(error.exception), DbException)
+ mock_join_secret_key.assert_called_once_with(salt)
+ _call_mock_aes_new = mock_aes.new.call_args_list[0].args
+ self.assertEqual(_call_mock_aes_new[1], AES.MODE_ECB)
+ mock_b64_decode.assert_called_once_with(encyrpted_value)
+ self.mock_cipher.decrypt.assert_called_once_with(b64_decoded)
+ mock_decrypted_message.decode.assert_called_once_with(
+ self.db_base.encoding_type
+ )
+ mock_unpad_data.assert_called_once_with(None)
+
+ @patch.object(DbBase, "get_secret_key")
+ @patch.object(DbBase, "_decrypt_value")
+ def test_decrypt_without_schema_version_without_salt(
+ self, mock_decrypt_value, mock_get_secret_key
+ ):
+ """schema_version is None, salt is None."""
+ mock_decrypt_value.return_value = encyrpted_value
+ result = self.db_base.decrypt(value)
+ mock_decrypt_value.assert_called_once_with(value, None, None)
+ mock_get_secret_key.assert_called_once()
+ self.assertEqual(result, encyrpted_value)
+
+ @patch.object(DbBase, "get_secret_key")
+ @patch.object(DbBase, "_decrypt_value")
+ def test_decrypt_with_schema_version_with_salt(
+ self, mock_decrypt_value, mock_get_secret_key
+ ):
+ """schema_version and salt exist."""
+ mock_decrypt_value.return_value = encyrpted_value
+ result = self.db_base.decrypt(value, schema_version, salt)
+ mock_decrypt_value.assert_called_once_with(value, schema_version, salt)
+ mock_get_secret_key.assert_called_once()
+ self.assertEqual(result, encyrpted_value)
+
+ @patch.object(DbBase, "get_secret_key")
+ @patch.object(DbBase, "_decrypt_value")
+ def test_decrypt_get_secret_key_raises(
+ self, mock_decrypt_value, mock_get_secret_key
+ ):
+ """Method get_secret_key raises DbException."""
+ mock_get_secret_key.side_effect = DbException("KeyError")
+ with self.assertRaises(Exception) as error:
+ self.db_base.decrypt(value)
+ self.assertEqual(str(error.exception), "database exception KeyError")
+ mock_decrypt_value.assert_not_called()
+ mock_get_secret_key.assert_called_once()
+
+ @patch.object(DbBase, "get_secret_key")
+ @patch.object(DbBase, "_decrypt_value")
+ def test_decrypt_decrypt_value_raises(
+ self, mock_decrypt_value, mock_get_secret_key
+ ):
+ """Method _decrypt_value raises error."""
+ mock_decrypt_value.side_effect = DbException(
+ "Incorrect data type: type(None), string is expected."
+ )
+ with self.assertRaises(Exception) as error:
+ self.db_base.decrypt(value, schema_version, salt)
+ self.assertEqual(
+ str(error.exception),
+ "database exception Incorrect data type: type(None), string is expected.",
+ )
+ mock_decrypt_value.assert_called_once_with(value, schema_version, salt)
+ mock_get_secret_key.assert_called_once()
+
+ def test_encrypt_decrypt_with_schema_version_1_1_with_salt(self):
+ """Encrypt and decrypt with schema version 1.1, salt exists."""
+ encrypted_msg = self.db_base.encrypt(value, schema_version, salt)
+ decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt)
+ self.assertEqual(value, decrypted_msg)
+
+ def test_encrypt_decrypt_with_schema_version_1_0_with_salt(self):
+ """Encrypt and decrypt with schema version 1.0, salt exists."""
+ schema_version = "1.0"
+ encrypted_msg = self.db_base.encrypt(value, schema_version, salt)
+ decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt)
+ self.assertEqual(value, decrypted_msg)
+
+ def test_encrypt_decrypt_with_schema_version_1_1_without_salt(self):
+ """Encrypt and decrypt with schema version 1.1 and without salt."""
+ salt = None
+ encrypted_msg = self.db_base.encrypt(value, schema_version, salt)
+ decrypted_msg = self.db_base.decrypt(encrypted_msg, schema_version, salt)
+ self.assertEqual(value, decrypted_msg)
+
+
+class TestAsyncEncryption(unittest.TestCase):
+ @patch("logging.getLogger", autospec=True)
+ def setUp(self, mock_logger):
+ mock_logger = logging.getLogger()
+ mock_logger.disabled = True
+ self.encryption = Encryption(uri="uri", config={})
+ self.encryption.encoding_type = encoding_type
+ self.encryption.encrypt_mode = encyrpt_mode
+ self.encryption._secret_key = secret_key
+ self.admin_collection = Mock()
+ self.admin_collection.find_one = AsyncMock()
+ self.encryption._client = {
+ "osm": {
+ "admin": self.admin_collection,
+ }
+ }
+
+ @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
+ def test_decrypt_fields_with_item_with_fields(self, mock_decrypt):
+ """item and fields exist."""
+ mock_decrypt.side_effect = [decrypted_val1, decrypted_val2]
+ input_item = copy.deepcopy(item)
+ expected_item = {
+ "secret": decrypted_val1,
+ "cacert": decrypted_val2,
+ "path": "/var",
+ "ip": "192.168.12.23",
+ }
+ fields = ["secret", "cacert"]
+
+ asyncio.run(
+ self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
+ )
+ self.assertEqual(input_item, expected_item)
+ _call_mock_decrypt = mock_decrypt.call_args_list
+ self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt))
+ self.assertEqual(_call_mock_decrypt[1].args, ("mycacert", "1.1", salt))
+
+ @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
+ def test_decrypt_fields_empty_item_with_fields(self, mock_decrypt):
+ """item is empty and fields exists."""
+ input_item = {}
+ fields = ["secret", "cacert"]
+ asyncio.run(
+ self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
+ )
+ self.assertEqual(input_item, {})
+ mock_decrypt.assert_not_called()
+
+ @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
+ def test_decrypt_fields_with_item_without_fields(self, mock_decrypt):
+ """item exists and fields is empty."""
+ input_item = copy.deepcopy(item)
+ fields = []
+ asyncio.run(
+ self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
+ )
+ self.assertEqual(input_item, item)
+ mock_decrypt.assert_not_called()
+
+ @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
+ def test_decrypt_fields_with_item_with_single_field(self, mock_decrypt):
+ """item exists and field has single value."""
+ mock_decrypt.return_value = decrypted_val1
+ fields = ["secret"]
+ input_item = copy.deepcopy(item)
+ expected_item = {
+ "secret": decrypted_val1,
+ "cacert": "mycacert",
+ "path": "/var",
+ "ip": "192.168.12.23",
+ }
+ asyncio.run(
+ self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
+ )
+ self.assertEqual(input_item, expected_item)
+ _call_mock_decrypt = mock_decrypt.call_args_list
+ self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt))
+
+ @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
+ def test_decrypt_fields_with_item_with_field_none_salt_1_0_schema_version(
+ self, mock_decrypt
+ ):
+ """item exists and field has single value, salt is None, schema version is 1.0."""
+ schema_version = "1.0"
+ salt = None
+ mock_decrypt.return_value = "mysecret"
+ input_item = copy.deepcopy(item)
+ fields = ["secret"]
+ asyncio.run(
+ self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
+ )
+ self.assertEqual(input_item, item)
+ _call_mock_decrypt = mock_decrypt.call_args_list
+ self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.0", None))
+
+ @patch.object(Encryption, "decrypt", new_callable=AsyncMock)
+ def test_decrypt_fields_decrypt_raises(self, mock_decrypt):
+ """Method decrypt raises error."""
+ mock_decrypt.side_effect = DbException(
+ "Incorrect data type: type(None), string is expected."
+ )
+ fields = ["secret"]
+ input_item = copy.deepcopy(item)
+ with self.assertRaises(Exception) as error:
+ asyncio.run(
+ self.encryption.decrypt_fields(input_item, fields, schema_version, salt)
+ )
+ self.assertEqual(
+ str(error.exception),
+ "database exception Incorrect data type: type(None), string is expected.",
+ )
+ self.assertEqual(input_item, item)
+ _call_mock_decrypt = mock_decrypt.call_args_list
+ self.assertEqual(_call_mock_decrypt[0].args, ("mysecret", "1.1", salt))
+
+ @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
+ @patch.object(Encryption, "_encrypt_value")
+ def test_encrypt(self, mock_encrypt_value, mock_get_secret_key):
+ """Encrypted successfully."""
+ mock_encrypt_value.return_value = encyrpted_value
+ result = asyncio.run(self.encryption.encrypt(value, schema_version, salt))
+ self.assertEqual(result, encyrpted_value)
+ mock_get_secret_key.assert_called_once()
+ mock_encrypt_value.assert_called_once_with(value, schema_version, salt)
+
+ @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
+ @patch.object(Encryption, "_encrypt_value")
+ def test_encrypt_get_secret_key_raises(
+ self, mock_encrypt_value, mock_get_secret_key
+ ):
+ """Method get_secret_key raises error."""
+ mock_get_secret_key.side_effect = DbException("Unexpected type.")
+ with self.assertRaises(Exception) as error:
+ asyncio.run(self.encryption.encrypt(value, schema_version, salt))
+ self.assertEqual(str(error.exception), "database exception Unexpected type.")
+ mock_get_secret_key.assert_called_once()
+ mock_encrypt_value.assert_not_called()
+
+ @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
+ @patch.object(Encryption, "_encrypt_value")
+ def test_encrypt_get_encrypt_raises(self, mock_encrypt_value, mock_get_secret_key):
+ """Method _encrypt_value raises error."""
+ mock_encrypt_value.side_effect = TypeError(
+ "A bytes-like object is required, not 'str'"
+ )
+ with self.assertRaises(Exception) as error:
+ asyncio.run(self.encryption.encrypt(value, schema_version, salt))
+ self.assertEqual(
+ str(error.exception), "A bytes-like object is required, not 'str'"
+ )
+ mock_get_secret_key.assert_called_once()
+ mock_encrypt_value.assert_called_once_with(value, schema_version, salt)
+
+ @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
+ @patch.object(Encryption, "_decrypt_value")
+ def test_decrypt(self, mock_decrypt_value, mock_get_secret_key):
+ """Decrypted successfully."""
+ mock_decrypt_value.return_value = value
+ result = asyncio.run(
+ self.encryption.decrypt(encyrpted_value, schema_version, salt)
+ )
+ self.assertEqual(result, value)
+ mock_get_secret_key.assert_called_once()
+ mock_decrypt_value.assert_called_once_with(
+ encyrpted_value, schema_version, salt
+ )
+
+ @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
+ @patch.object(Encryption, "_decrypt_value")
+ def test_decrypt_get_secret_key_raises(
+ self, mock_decrypt_value, mock_get_secret_key
+ ):
+ """Method get_secret_key raises error."""
+ mock_get_secret_key.side_effect = DbException("Unexpected type.")
+ with self.assertRaises(Exception) as error:
+ asyncio.run(self.encryption.decrypt(encyrpted_value, schema_version, salt))
+ self.assertEqual(str(error.exception), "database exception Unexpected type.")
+ mock_get_secret_key.assert_called_once()
+ mock_decrypt_value.assert_not_called()
+
+ @patch.object(Encryption, "get_secret_key", new_callable=AsyncMock)
+ @patch.object(Encryption, "_decrypt_value")
+ def test_decrypt_decrypt_value_raises(
+ self, mock_decrypt_value, mock_get_secret_key
+ ):
+ """Method _decrypt_value raises error."""
+ mock_decrypt_value.side_effect = TypeError(
+ "A bytes-like object is required, not 'str'"
+ )
+ with self.assertRaises(Exception) as error:
+ asyncio.run(self.encryption.decrypt(encyrpted_value, schema_version, salt))
+ self.assertEqual(
+ str(error.exception), "A bytes-like object is required, not 'str'"
+ )
+ mock_get_secret_key.assert_called_once()
+ mock_decrypt_value.assert_called_once_with(
+ encyrpted_value, schema_version, salt
+ )
+
+ def test_join_keys_string_key(self):
+ """key is string."""
+ string_key = "sample key"
+ result = self.encryption._join_keys(string_key, secret_key)
+ self.assertEqual(result, joined_key)
+ self.assertTrue(isinstance(result, bytes))
+
+ def test_join_keys_bytes_key(self):
+ """key is bytes."""
+ bytes_key = b"sample key"
+ result = self.encryption._join_keys(bytes_key, secret_key)
+ self.assertEqual(result, joined_key)
+ self.assertTrue(isinstance(result, bytes))
+ self.assertEqual(len(result.decode("unicode_escape")), 32)
+
+ def test_join_keys_int_key(self):
+ """key is int."""
+ int_key = 923
+ with self.assertRaises(Exception) as error:
+ self.encryption._join_keys(int_key, None)
+ self.assertEqual(str(error.exception), "'int' object is not iterable")
+
+ def test_join_keys_none_secret_key(self):
+ """key is bytes and secret key is None."""
+ bytes_key = b"sample key"
+ result = self.encryption._join_keys(bytes_key, None)
+ self.assertEqual(
+ result,
+ b"sample key\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ )
+ self.assertTrue(isinstance(result, bytes))
+ self.assertEqual(len(result.decode("unicode_escape")), 32)
+
+ def test_join_keys_none_key_none_secret_key(self):
+ """key is None and secret key is None."""
+ with self.assertRaises(Exception) as error:
+ self.encryption._join_keys(None, None)
+ self.assertEqual(str(error.exception), "'NoneType' object is not iterable")
+
+ def test_join_keys_none_key(self):
+ """key is None and secret key exists."""
+ with self.assertRaises(Exception) as error:
+ self.encryption._join_keys(None, secret_key)
+ self.assertEqual(str(error.exception), "'NoneType' object is not iterable")
+
+ @patch.object(Encryption, "_join_keys")
+ def test_join_secret_key_string_sample_key(self, mock_join_keys):
+ """update_key is a string and secret key exists."""
+ update_key = "sample key"
+ mock_join_keys.return_value = joined_key
+ result = self.encryption._join_secret_key(update_key)
+ self.assertEqual(result, joined_key)
+ self.assertTrue(isinstance(result, bytes))
+ mock_join_keys.assert_called_once_with(update_key, secret_key)
+
+ @patch.object(Encryption, "_join_keys")
+ def test_join_secret_key_byte_sample_key(self, mock_join_keys):
+ """update_key is bytes and secret key exists."""
+ update_key = b"sample key"
+ mock_join_keys.return_value = joined_key
+ result = self.encryption._join_secret_key(update_key)
+ self.assertEqual(result, joined_key)
+ self.assertTrue(isinstance(result, bytes))
+ mock_join_keys.assert_called_once_with(update_key, secret_key)
+
+ @patch.object(Encryption, "_join_keys")
+ def test_join_secret_key_join_keys_raises(self, mock_join_keys):
+ """Method _join_keys raises TypeError."""
+ update_key = 3434
+ mock_join_keys.side_effect = TypeError("'int' object is not iterable")
+ with self.assertRaises(Exception) as error:
+ self.encryption._join_secret_key(update_key)
+ self.assertEqual(str(error.exception), "'int' object is not iterable")
+ mock_join_keys.assert_called_once_with(update_key, secret_key)
+
+ @patch.object(Encryption, "_join_keys")
+ def test_get_secret_key_exists(self, mock_join_keys):
+ """secret_key exists."""
+ self.encryption._secret_key = secret_key
+ asyncio.run(self.encryption.get_secret_key())
+ self.assertEqual(self.encryption.secret_key, secret_key)
+ mock_join_keys.assert_not_called()
+
+ @patch.object(Encryption, "_join_keys")
+ @patch("osm_common.dbbase.b64decode")
+ def test_get_secret_key_not_exist_database_key_exist(
+ self, mock_b64decode, mock_join_keys
+ ):
+ """secret_key does not exist, database key exists, no version data in DB."""
+ self.encryption._secret_key = None
+ self.encryption._admin_collection.find_one.return_value = None
+ self.encryption._config = {"database_commonkey": "osm_new_key"}
+ mock_join_keys.return_value = joined_key
+ asyncio.run(self.encryption.get_secret_key())
+ self.assertEqual(self.encryption.secret_key, joined_key)
+ self.assertEqual(mock_join_keys.call_count, 1)
+ mock_b64decode.assert_not_called()
+
+ @patch.object(Encryption, "_join_keys")
+ @patch("osm_common.dbbase.b64decode")
+ def test_get_secret_key_not_exist_with_database_key_version_data_exist_without_serial(
+ self, mock_b64decode, mock_join_keys
+ ):
+ """secret_key does not exist, database key exists, version data has no serial."""
+ self.encryption._secret_key = None
+ self.encryption._admin_collection.find_one.return_value = {"version": "1.0"}
+ self.encryption._config = {"database_commonkey": "osm_new_key"}
+ mock_join_keys.return_value = joined_key
+ asyncio.run(self.encryption.get_secret_key())
+ self.assertEqual(self.encryption.secret_key, joined_key)
+ self.assertEqual(mock_join_keys.call_count, 1)
+ mock_b64decode.assert_not_called()
+ self.encryption._admin_collection.find_one.assert_called_once_with(
+ {"_id": "version"}
+ )
+ _call_mock_join_keys = mock_join_keys.call_args_list
+ self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
+
+ @patch.object(Encryption, "_join_keys")
+ @patch("osm_common.dbbase.b64decode")
+ def test_get_secret_key_not_exist_with_database_key_version_data_exist_with_serial(
+ self, mock_b64decode, mock_join_keys
+ ):
+ """secret_key does not exist, database key exists, version and serial exist
+ in admin collection."""
+ self.encryption._secret_key = None
+ self.encryption._admin_collection.find_one.return_value = {
+ "version": "1.0",
+ "serial": serial_bytes,
+ }
+ self.encryption._config = {"database_commonkey": "osm_new_key"}
+ mock_join_keys.side_effect = [secret_key, joined_key]
+ mock_b64decode.return_value = base64_decoded_serial
+ asyncio.run(self.encryption.get_secret_key())
+ self.assertEqual(self.encryption.secret_key, joined_key)
+ self.assertEqual(mock_join_keys.call_count, 2)
+ mock_b64decode.assert_called_once_with(serial_bytes)
+ self.encryption._admin_collection.find_one.assert_called_once_with(
+ {"_id": "version"}
+ )
+ _call_mock_join_keys = mock_join_keys.call_args_list
+ self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
+ self.assertEqual(
+ _call_mock_join_keys[1].args, (base64_decoded_serial, secret_key)
+ )
+
+ @patch.object(Encryption, "_join_keys")
+ @patch("osm_common.dbbase.b64decode")
+ def test_get_secret_key_join_keys_raises(self, mock_b64decode, mock_join_keys):
+ """Method _join_keys raises."""
+ self.encryption._secret_key = None
+ self.encryption._admin_collection.find_one.return_value = {
+ "version": "1.0",
+ "serial": serial_bytes,
+ }
+ self.encryption._config = {"database_commonkey": "osm_new_key"}
+ mock_join_keys.side_effect = DbException("Invalid data type.")
+ with self.assertRaises(Exception) as error:
+ asyncio.run(self.encryption.get_secret_key())
+ self.assertEqual(str(error.exception), "database exception Invalid data type.")
+ self.assertEqual(mock_join_keys.call_count, 1)
+ check_if_assert_not_called(
+ [mock_b64decode, self.encryption._admin_collection.find_one]
+ )
+ _call_mock_join_keys = mock_join_keys.call_args_list
+ self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
+
+ @patch.object(Encryption, "_join_keys")
+ @patch("osm_common.dbbase.b64decode")
+ def test_get_secret_key_b64decode_raises(self, mock_b64decode, mock_join_keys):
+ """Method b64decode raises."""
+ self.encryption._secret_key = None
+ self.encryption._admin_collection.find_one.return_value = {
+ "version": "1.0",
+ "serial": base64_decoded_serial,
+ }
+ self.encryption._config = {"database_commonkey": "osm_new_key"}
+ mock_join_keys.return_value = secret_key
+ mock_b64decode.side_effect = TypeError(
+ "A bytes-like object is required, not 'str'"
+ )
+ with self.assertRaises(Exception) as error:
+ asyncio.run(self.encryption.get_secret_key())
+ self.assertEqual(
+ str(error.exception), "A bytes-like object is required, not 'str'"
+ )
+ self.assertEqual(self.encryption.secret_key, None)
+ self.assertEqual(mock_join_keys.call_count, 1)
+ mock_b64decode.assert_called_once_with(base64_decoded_serial)
+ self.encryption._admin_collection.find_one.assert_called_once_with(
+ {"_id": "version"}
+ )
+ _call_mock_join_keys = mock_join_keys.call_args_list
+ self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
+
+ @patch.object(Encryption, "_join_keys")
+ @patch("osm_common.dbbase.b64decode")
+ def test_get_secret_key_admin_collection_find_one_raises(
+ self, mock_b64decode, mock_join_keys
+ ):
+ """admin_collection find_one raises."""
+ self.encryption._secret_key = None
+ self.encryption._admin_collection.find_one.side_effect = DbException(
+ "Connection failed."
+ )
+ self.encryption._config = {"database_commonkey": "osm_new_key"}
+ mock_join_keys.return_value = secret_key
+ with self.assertRaises(Exception) as error:
+ asyncio.run(self.encryption.get_secret_key())
+ self.assertEqual(str(error.exception), "database exception Connection failed.")
+ self.assertEqual(self.encryption.secret_key, None)
+ self.assertEqual(mock_join_keys.call_count, 1)
+ mock_b64decode.assert_not_called()
+ self.encryption._admin_collection.find_one.assert_called_once_with(
+ {"_id": "version"}
+ )
+ _call_mock_join_keys = mock_join_keys.call_args_list
+ self.assertEqual(_call_mock_join_keys[0].args, ("osm_new_key", None))
+
+ def test_encrypt_decrypt_with_schema_version_1_1_with_salt(self):
+ """Encrypt and decrypt with schema version 1.1, salt exists."""
+ encrypted_msg = asyncio.run(
+ self.encryption.encrypt(value, schema_version, salt)
+ )
+ decrypted_msg = asyncio.run(
+ self.encryption.decrypt(encrypted_msg, schema_version, salt)
+ )
+ self.assertEqual(value, decrypted_msg)
+
+ def test_encrypt_decrypt_with_schema_version_1_0_with_salt(self):
+ """Encrypt and decrypt with schema version 1.0, salt exists."""
+ schema_version = "1.0"
+ encrypted_msg = asyncio.run(
+ self.encryption.encrypt(value, schema_version, salt)
+ )
+ decrypted_msg = asyncio.run(
+ self.encryption.decrypt(encrypted_msg, schema_version, salt)
+ )
+ self.assertEqual(value, decrypted_msg)
+
+ def test_encrypt_decrypt_with_schema_version_1_1_without_salt(self):
+ """Encrypt and decrypt with schema version 1.1, without salt."""
+ salt = None
+ with self.assertRaises(Exception) as error:
+ asyncio.run(self.encryption.encrypt(value, schema_version, salt))
+ self.assertEqual(str(error.exception), "'NoneType' object is not iterable")
+
+
+class TestDeepUpdate(unittest.TestCase):
+    """Tests for deep_update: dict merging, '$' array-patch syntax and
+    malformed-patch error handling."""
+
+    def test_update_dict(self):
+        """Dict patches merge recursively; a None value deletes the key."""
+        # Each tuple: (original, patch, expected result)
+        TEST = (
+            ({"a": "b"}, {"a": "c"}, {"a": "c"}),
+            ({"a": "b"}, {"b": "c"}, {"a": "b", "b": "c"}),
+            ({"a": "b"}, {"a": None}, {}),
+            ({"a": "b", "b": "c"}, {"a": None}, {"b": "c"}),
+            ({"a": ["b"]}, {"a": "c"}, {"a": "c"}),
+            ({"a": "c"}, {"a": ["b"]}, {"a": ["b"]}),
+            ({"a": {"b": "c"}}, {"a": {"b": "d", "c": None}}, {"a": {"b": "d"}}),
+            ({"a": [{"b": "c"}]}, {"a": [1]}, {"a": [1]}),
+            ({1: ["a", "b"]}, {1: ["c", "d"]}, {1: ["c", "d"]}),
+            ({1: {"a": "b"}}, {1: ["c"]}, {1: ["c"]}),
+            ({1: {"a": "foo"}}, {1: None}, {}),
+            ({1: {"a": "foo"}}, {1: "bar"}, {1: "bar"}),
+            ({"e": None}, {"a": 1}, {"e": None, "a": 1}),
+            ({1: [1, 2]}, {1: {"a": "b", "c": None}}, {1: {"a": "b"}}),
+            ({}, {"a": {"bb": {"ccc": None}}}, {"a": {"bb": {}}}),
+        )
+        for t in TEST:
+            deep_update(t[0], t[1])
+            self.assertEqual(t[0], t[2])
+        # Verify a deepcopy is made, so the updated original does not share
+        # references with the patch: editing the patch afterwards must not
+        # change the already-updated original.
+        test_original = {1: {"a": "b"}}
+        test_patch = {1: {"c": {"d": "e"}}}
+        test_result = {1: {"a": "b", "c": {"d": "e"}}}
+        deep_update(test_original, test_patch)
+        self.assertEqual(test_original, test_result)
+        test_patch[1]["c"]["f"] = "edition of patch, must not modify original"
+        self.assertEqual(test_original, test_result)
+
+    def test_update_array(self):
+        """'$' patch keys edit lists: delete by value, insert by index,
+        insert-if-missing, edit by filter, and nested list/dict patches."""
+        # Each tuple: (original, patch, expected result)
+        TEST = (
+            # delete all instances of "a"/"d"
+            ({"A": ["a", "b", "a"]}, {"A": {"$a": None}}, {"A": ["b"]}),
+            ({"A": ["a", "b", "a"]}, {"A": {"$d": None}}, {"A": ["a", "b", "a"]}),
+            # delete and insert at 0
+            (
+                {"A": ["a", "b", "c"]},
+                {"A": {"$b": None, "$+[0]": "b"}},
+                {"A": ["b", "a", "c"]},
+            ),
+            # delete and edit
+            (
+                {"A": ["a", "b", "a"]},
+                {"A": {"$a": None, "$[1]": {"c": "d"}}},
+                {"A": [{"c": "d"}]},
+            ),
+            # insert if not exist
+            ({"A": ["a", "b", "c"]}, {"A": {"$+b": "b"}}, {"A": ["a", "b", "c"]}),
+            ({"A": ["a", "b", "c"]}, {"A": {"$+d": "f"}}, {"A": ["a", "b", "c", "f"]}),
+            # edit by filter
+            (
+                {"A": ["a", "b", "a"]},
+                {"A": {"$b": {"c": "d"}}},
+                {"A": ["a", {"c": "d"}, "a"]},
+            ),
+            (
+                {"A": ["a", "b", "a"]},
+                {"A": {"$b": None, "$+[0]": "b", "$+": "c"}},
+                {"A": ["b", "a", "a", "c"]},
+            ),
+            ({"A": ["a", "b", "a"]}, {"A": {"$c": None}}, {"A": ["a", "b", "a"]}),
+            # index deletion out of range
+            ({"A": ["a", "b", "a"]}, {"A": {"$[5]": None}}, {"A": ["a", "b", "a"]}),
+            # nested array->dict
+            (
+                {"A": ["a", "b", {"id": "1", "c": {"d": 2}}]},
+                {"A": {"$id: '1'": {"h": None, "c": {"d": "e", "f": "g"}}}},
+                {"A": ["a", "b", {"id": "1", "c": {"d": "e", "f": "g"}}]},
+            ),
+            (
+                {"A": [{"id": 1, "c": {"d": 2}}, {"id": 1, "c": {"f": []}}]},
+                {"A": {"$id: 1": {"h": None, "c": {"d": "e", "f": "g"}}}},
+                {
+                    "A": [
+                        {"id": 1, "c": {"d": "e", "f": "g"}},
+                        {"id": 1, "c": {"d": "e", "f": "g"}},
+                    ]
+                },
+            ),
+            # nested array->array
+            (
+                {"A": ["a", "b", ["a", "b"]]},
+                {"A": {"$b": None, "$[2]": {"$b": {}, "$+": "c"}}},
+                {"A": ["a", ["a", {}, "c"]]},
+            ),
+            # types str and int different, so not found
+            (
+                {"A": ["a", {"id": "1", "c": "d"}]},
+                {"A": {"$id: 1": {"c": "e"}}},
+                {"A": ["a", {"id": "1", "c": "d"}]},
+            ),
+        )
+        for t in TEST:
+            print(t)
+            deep_update(t[0], t[1])
+            self.assertEqual(t[0], t[2])
+
+    def test_update_badformat(self):
+        """Malformed array patches must raise DbException.

+        Each tuple holds an original dict and an invalid patch.
+        TODO: also assert on the exception text, not only the type.
+        """
+        TEST = (
+            # conflict, index 0 is edited twice
+            ({"A": ["a", "b", "a"]}, {"A": {"$a": None, "$[0]": {"c": "d"}}}),
+            # conflict, two insertions at same index
+            ({"A": ["a", "b", "a"]}, {"A": {"$[1]": "c", "$[-2]": "d"}}),
+            ({"A": ["a", "b", "a"]}, {"A": {"$[1]": "c", "$[+1]": "d"}}),
+            # bad format keys with and without $
+            ({"A": ["a", "b", "a"]}, {"A": {"$b": {"c": "d"}, "c": 3}}),
+            # bad format empty $ and yaml incorrect
+            ({"A": ["a", "b", "a"]}, {"A": {"$": 3}}),
+            ({"A": ["a", "b", "a"]}, {"A": {"$a: b: c": 3}}),
+            ({"A": ["a", "b", "a"]}, {"A": {"$a: b, c: d": 3}}),
+            # insertion of None
+            ({"A": ["a", "b", "a"]}, {"A": {"$+": None}}),
+            # Not found, insertion of None
+            ({"A": ["a", "b", "a"]}, {"A": {"$+c": None}}),
+            # index edition out of range
+            ({"A": ["a", "b", "a"]}, {"A": {"$[5]": 6}}),
+            # conflict, two editions on index 2
+            (
+                {"A": ["a", {"id": "1", "c": "d"}]},
+                {"A": {"$id: '1'": {"c": "e"}, "$c: d": {"c": "f"}}},
+            ),
+        )
+        for t in TEST:
+            print(t)
+            self.assertRaises(DbException, deep_update, t[0], t[1])
+            # Re-run to print the exception text for manual inspection.
+            try:
+                deep_update(t[0], t[1])
+            except DbException as e:
+                print(e)
+
+
+# Allow running this test module directly as a script.
+if __name__ == "__main__":
+    unittest.main()
--- /dev/null
+# Copyright 2018 Whitestack, LLC
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
+##
+from copy import deepcopy
+import http
+import logging
+import unittest
+from unittest.mock import MagicMock, Mock
+
+from osm_common.dbbase import DbException
+from osm_common.dbmemory import DbMemory
+import pytest
+
+__author__ = "Eduardo Sousa <eduardosousa@av.it.pt>"
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def db_memory(request):
+    """Empty DbMemory instance, exercised both with and without locking."""
+    db = DbMemory(lock=request.param)
+    return db
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def db_memory_with_data(request):
+    """DbMemory pre-populated with three simple documents in table 'test'."""
+    db = DbMemory(lock=request.param)
+
+    db.create("test", {"_id": 1, "data": 1})
+    db.create("test", {"_id": 2, "data": 2})
+    db.create("test", {"_id": 3, "data": 3})
+
+    return db
+
+
+@pytest.fixture(scope="function")
+def db_memory_with_many_data(request):
+    """DbMemory with eight documents in table 'test' covering nested dicts,
+    lists of dicts, plain lists and numeric-string keys, used by the
+    filter-operator tests."""
+    db = DbMemory(lock=False)
+
+    db.create_list(
+        "test",
+        [
+            {
+                "_id": 1,
+                "data": {"data2": {"data3": 1}},
+                "list": [{"a": 1}],
+                "text": "sometext",
+            },
+            {
+                "_id": 2,
+                "data": {"data2": {"data3": 2}},
+                "list": [{"a": 2}],
+                "list2": [1, 2, 3],
+            },
+            {"_id": 3, "data": {"data2": {"data3": 3}}, "list": [{"a": 3}]},
+            {"_id": 4, "data": {"data2": {"data3": 4}}, "list": [{"a": 4}, {"a": 0}]},
+            {"_id": 5, "data": {"data2": {"data3": 5}}, "list": [{"a": 5}]},
+            {"_id": 6, "data": {"data2": {"data3": 6}}, "list": [{"0": {"a": 1}}]},
+            {"_id": 7, "data": {"data2": {"data3": 7}}, "0": {"a": 0}},
+            {
+                "_id": 8,
+                "list": [
+                    {"a": 3, "b": 0, "c": [{"a": 3, "b": 1}, {"a": 0, "b": "v"}]},
+                    {"a": 0, "b": 1},
+                ],
+            },
+        ],
+    )
+    return db
+
+
+def empty_exception_message():
+    """Message of a DbException raised with no detail text."""
+    return "database exception "
+
+
+def get_one_exception_message(db_filter):
+    """Expected message when get_one finds no entry for db_filter."""
+    return "database exception Not found entry with filter='{}'".format(db_filter)
+
+
+def get_one_multiple_exception_message(db_filter):
+    """Expected message when get_one finds more than one entry for db_filter."""
+    return "database exception Found more than one entry with filter='{}'".format(
+        db_filter
+    )
+
+
+def del_one_exception_message(db_filter):
+    """Expected message when del_one finds no entry for db_filter."""
+    return "database exception Not found entry with filter='{}'".format(db_filter)
+
+
+def replace_exception_message(value):
+    """Expected message when replace finds no entry with the given _id."""
+    return "database exception Not found entry with _id='{}'".format(value)
+
+
+def test_constructor():
+    """DbMemory() defaults to the 'db' logger and an empty store."""
+    db = DbMemory()
+    assert db.logger == logging.getLogger("db")
+    assert db.db == {}
+
+
+def test_constructor_with_logger():
+    """DbMemory(logger_name=...) uses the requested logger."""
+    logger_name = "db_local"
+    db = DbMemory(logger_name=logger_name)
+    assert db.logger == logging.getLogger(logger_name)
+    assert db.db == {}
+
+
+def test_db_connect():
+    """db_connect() reconfigures the logger from the config dict."""
+    logger_name = "db_local"
+    config = {"logger_name": logger_name}
+    db = DbMemory()
+    db.db_connect(config)
+    assert db.logger == logging.getLogger(logger_name)
+    assert db.db == {}
+
+
+def test_db_disconnect(db_memory):
+    """db_disconnect() must not raise (no-op for the in-memory backend)."""
+    db_memory.db_disconnect()
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {}),
+        ("test", {"_id": 1}),
+        ("test", {"data": 1}),
+        ("test", {"_id": 1, "data": 1}),
+    ],
+)
+def test_get_list_with_empty_db(db_memory, table, db_filter):
+    """get_list on an empty DB returns an empty list for any filter."""
+    result = db_memory.get_list(table, db_filter)
+    assert len(result) == 0
+
+
+@pytest.mark.parametrize(
+    "table, db_filter, expected_data",
+    [
+        (
+            "test",
+            {},
+            [{"_id": 1, "data": 1}, {"_id": 2, "data": 2}, {"_id": 3, "data": 3}],
+        ),
+        ("test", {"_id": 1}, [{"_id": 1, "data": 1}]),
+        ("test", {"data": 1}, [{"_id": 1, "data": 1}]),
+        ("test", {"_id": 1, "data": 1}, [{"_id": 1, "data": 1}]),
+        ("test", {"_id": 2}, [{"_id": 2, "data": 2}]),
+        ("test", {"data": 2}, [{"_id": 2, "data": 2}]),
+        ("test", {"_id": 2, "data": 2}, [{"_id": 2, "data": 2}]),
+        ("test", {"_id": 4}, []),
+        ("test", {"data": 4}, []),
+        ("test", {"_id": 4, "data": 4}, []),
+        ("test_table", {}, []),
+        ("test_table", {"_id": 1}, []),
+        ("test_table", {"data": 1}, []),
+        ("test_table", {"_id": 1, "data": 1}, []),
+    ],
+)
+def test_get_list_with_non_empty_db(
+    db_memory_with_data, table, db_filter, expected_data
+):
+    """get_list returns exactly the documents matching the filter."""
+    result = db_memory_with_data.get_list(table, db_filter)
+    assert len(result) == len(expected_data)
+    for data in expected_data:
+        assert data in result
+
+
+def test_get_list_exception(db_memory_with_data):
+    """Internal errors in _find are wrapped in DbException (404)."""
+    table = "test"
+    db_filter = {}
+    db_memory_with_data._find = MagicMock(side_effect=Exception())
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.get_list(table, db_filter)
+    assert str(excinfo.value) == empty_exception_message()
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, db_filter, expected_data",
+    [
+        ("test", {"_id": 1}, {"_id": 1, "data": 1}),
+        ("test", {"_id": 2}, {"_id": 2, "data": 2}),
+        ("test", {"_id": 3}, {"_id": 3, "data": 3}),
+        ("test", {"data": 1}, {"_id": 1, "data": 1}),
+        ("test", {"data": 2}, {"_id": 2, "data": 2}),
+        ("test", {"data": 3}, {"_id": 3, "data": 3}),
+        ("test", {"_id": 1, "data": 1}, {"_id": 1, "data": 1}),
+        ("test", {"_id": 2, "data": 2}, {"_id": 2, "data": 2}),
+        ("test", {"_id": 3, "data": 3}, {"_id": 3, "data": 3}),
+    ],
+)
+def test_get_one(db_memory_with_data, table, db_filter, expected_data):
+    """get_one returns the single matching document without removing it."""
+    result = db_memory_with_data.get_one(table, db_filter)
+    assert result == expected_data
+    # The store itself must be untouched by the query.
+    assert len(db_memory_with_data.db) == 1
+    assert table in db_memory_with_data.db
+    assert len(db_memory_with_data.db[table]) == 3
+    assert result in db_memory_with_data.db[table]
+
+
+@pytest.mark.parametrize(
+ "db_filter, expected_ids",
+ [
+ ({}, [1, 2, 3, 4, 5, 6, 7, 8]),
+ ({"_id": 1}, [1]),
+ ({"data.data2.data3": 2}, [2]),
+ ({"data.data2.data3.eq": 2}, [2]),
+ ({"data.data2.data3": [2]}, [2]),
+ ({"data.data2.data3.cont": [2]}, [2]),
+ ({"data.data2.data3.neq": 2}, [1, 3, 4, 5, 6, 7, 8]),
+ ({"data.data2.data3.neq": [2]}, [1, 3, 4, 5, 6, 7, 8]),
+ ({"data.data2.data3.ncont": [2]}, [1, 3, 4, 5, 6, 7, 8]),
+ ({"data.data2.data3": [2, 3]}, [2, 3]),
+ ({"data.data2.data3.gt": 4}, [5, 6, 7]),
+ ({"data.data2.data3.gte": 4}, [4, 5, 6, 7]),
+ ({"data.data2.data3.lt": 4}, [1, 2, 3]),
+ ({"data.data2.data3.lte": 4}, [1, 2, 3, 4]),
+ ({"data.data2.data3.lte": 4.5}, [1, 2, 3, 4]),
+ ({"data.data2.data3.gt": "text"}, []),
+ ({"nonexist.nonexist": "4"}, []),
+ ({"nonexist.nonexist": None}, [1, 2, 3, 4, 5, 6, 7, 8]),
+ ({"nonexist.nonexist.neq": "4"}, [1, 2, 3, 4, 5, 6, 7, 8]),
+ ({"nonexist.nonexist.neq": None}, []),
+ ({"text.eq": "sometext"}, [1]),
+ ({"text.neq": "sometext"}, [2, 3, 4, 5, 6, 7, 8]),
+ ({"text.eq": "somet"}, []),
+ ({"text.gte": "a"}, [1]),
+ ({"text.gte": "somet"}, [1]),
+ ({"text.gte": "sometext"}, [1]),
+ ({"text.lt": "somet"}, []),
+ ({"data.data2.data3": 2, "data.data2.data4": None}, [2]),
+ ({"data.data2.data3": 2, "data.data2.data4": 5}, []),
+ ({"data.data2.data3": 4}, [4]),
+ ({"data.data2.data3": [3, 4, "e"]}, [3, 4]),
+ ({"data.data2.data3": None}, [8]),
+ ({"data.data2": "4"}, []),
+ ({"list.0.a": 1}, [1, 6]),
+ ({"list2": 1}, [2]),
+ ({"list2": [1, 5]}, [2]),
+ ({"list2": [1, 2]}, [2]),
+ ({"list2": [5, 7]}, []),
+ ({"list.ANYINDEX.a": 1}, [1]),
+ ({"list.a": 3, "list.b": 1}, [8]),
+ ({"list.ANYINDEX.a": 3, "list.ANYINDEX.b": 1}, []),
+ ({"list.ANYINDEX.a": 3, "list.ANYINDEX.c.a": 3}, [8]),
+ ({"list.ANYINDEX.a": 3, "list.ANYINDEX.b": 0}, [8]),
+ (
+ {
+ "list.ANYINDEX.a": 3,
+ "list.ANYINDEX.c.ANYINDEX.a": 0,
+ "list.ANYINDEX.c.ANYINDEX.b": "v",
+ },
+ [8],
+ ),
+ (
+ {
+ "list.ANYINDEX.a": 3,
+ "list.ANYINDEX.c.ANYINDEX.a": 0,
+ "list.ANYINDEX.c.ANYINDEX.b": 1,
+ },
+ [],
+ ),
+ ({"list.c.b": 1}, [8]),
+ ({"list.c.b": None}, [1, 2, 3, 4, 5, 6, 7]),
+ # ({"data.data2.data3": 4}, []),
+ # ({"data.data2.data3": 4}, []),
+ ],
+)
+def test_get_list(db_memory_with_many_data, db_filter, expected_ids):
+ result = db_memory_with_many_data.get_list("test", db_filter)
+ assert isinstance(result, list)
+ result_ids = [item["_id"] for item in result]
+ assert len(result) == len(
+ expected_ids
+ ), "for db_filter={} result={} expected_ids={}".format(
+ db_filter, result, result_ids
+ )
+ assert result_ids == expected_ids
+ for i in range(len(result)):
+ assert result[i] in db_memory_with_many_data.db["test"]
+
+ assert len(db_memory_with_many_data.db) == 1
+ assert "test" in db_memory_with_many_data.db
+ assert len(db_memory_with_many_data.db["test"]) == 8
+ result = db_memory_with_many_data.count("test", db_filter)
+ assert result == len(expected_ids)
+
+
+@pytest.mark.parametrize(
+    "table, db_filter, expected_data", [("test", {}, {"_id": 1, "data": 1})]
+)
+def test_get_one_with_multiple_results(
+    db_memory_with_data, table, db_filter, expected_data
+):
+    """With fail_on_more=False, get_one returns the first match instead of raising."""
+    result = db_memory_with_data.get_one(table, db_filter, fail_on_more=False)
+    assert result == expected_data
+    assert len(db_memory_with_data.db) == 1
+    assert table in db_memory_with_data.db
+    assert len(db_memory_with_data.db[table]) == 3
+    assert result in db_memory_with_data.db[table]
+
+
+def test_get_one_with_multiple_results_exception(db_memory_with_data):
+    """get_one raises when the filter matches more than one entry."""
+    table = "test"
+    db_filter = {}
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.get_one(table, db_filter)
+    assert str(excinfo.value) == (
+        empty_exception_message() + get_one_multiple_exception_message(db_filter)
+    )
+    # assert excinfo.value.http_code == http.HTTPStatus.CONFLICT
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {"_id": 4}),
+        ("test", {"data": 4}),
+        ("test", {"_id": 4, "data": 4}),
+        ("test_table", {"_id": 4}),
+        ("test_table", {"data": 4}),
+        ("test_table", {"_id": 4, "data": 4}),
+    ],
+)
+def test_get_one_with_non_empty_db_exception(db_memory_with_data, table, db_filter):
+    """get_one raises 404 when nothing in a populated DB matches the filter."""
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.get_one(table, db_filter)
+    assert str(excinfo.value) == (
+        empty_exception_message() + get_one_exception_message(db_filter)
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {"_id": 4}),
+        ("test", {"data": 4}),
+        ("test", {"_id": 4, "data": 4}),
+        ("test_table", {"_id": 4}),
+        ("test_table", {"data": 4}),
+        ("test_table", {"_id": 4, "data": 4}),
+    ],
+)
+def test_get_one_with_non_empty_db_none(db_memory_with_data, table, db_filter):
+    """With fail_on_empty=False, get_one returns None when nothing matches."""
+    result = db_memory_with_data.get_one(table, db_filter, fail_on_empty=False)
+    assert result is None
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {"_id": 4}),
+        ("test", {"data": 4}),
+        ("test", {"_id": 4, "data": 4}),
+        ("test_table", {"_id": 4}),
+        ("test_table", {"data": 4}),
+        ("test_table", {"_id": 4, "data": 4}),
+    ],
+)
+def test_get_one_with_empty_db_exception(db_memory, table, db_filter):
+    """get_one on an empty DB raises 404 for any filter."""
+    with pytest.raises(DbException) as excinfo:
+        db_memory.get_one(table, db_filter)
+    assert str(excinfo.value) == (
+        empty_exception_message() + get_one_exception_message(db_filter)
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {"_id": 4}),
+        ("test", {"data": 4}),
+        ("test", {"_id": 4, "data": 4}),
+        ("test_table", {"_id": 4}),
+        ("test_table", {"data": 4}),
+        ("test_table", {"_id": 4, "data": 4}),
+    ],
+)
+def test_get_one_with_empty_db_none(db_memory, table, db_filter):
+    """get_one on an empty DB returns None when fail_on_empty=False."""
+    result = db_memory.get_one(table, db_filter, fail_on_empty=False)
+    assert result is None
+
+
+def test_get_one_generic_exception(db_memory_with_data):
+    """Unexpected internal errors in get_one are wrapped in DbException (404)."""
+    table = "test"
+    db_filter = {}
+    db_memory_with_data._find = MagicMock(side_effect=Exception())
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.get_one(table, db_filter)
+    assert str(excinfo.value) == empty_exception_message()
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, db_filter, expected_data",
+    [
+        ("test", {}, []),
+        ("test", {"_id": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
+        ("test", {"_id": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}]),
+        ("test", {"_id": 1, "data": 1}, [{"_id": 2, "data": 2}, {"_id": 3, "data": 3}]),
+        ("test", {"_id": 2, "data": 2}, [{"_id": 1, "data": 1}, {"_id": 3, "data": 3}]),
+    ],
+)
+def test_del_list_with_non_empty_db(
+    db_memory_with_data, table, db_filter, expected_data
+):
+    """del_list removes every matching document and reports the deleted count."""
+    result = db_memory_with_data.del_list(table, db_filter)
+    assert result["deleted"] == (3 - len(expected_data))
+    assert len(db_memory_with_data.db) == 1
+    assert table in db_memory_with_data.db
+    assert len(db_memory_with_data.db[table]) == len(expected_data)
+    for data in expected_data:
+        assert data in db_memory_with_data.db[table]
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {}),
+        ("test", {"_id": 1}),
+        ("test", {"_id": 2}),
+        ("test", {"data": 1}),
+        ("test", {"data": 2}),
+        ("test", {"_id": 1, "data": 1}),
+        ("test", {"_id": 2, "data": 2}),
+    ],
+)
+def test_del_list_with_empty_db(db_memory, table, db_filter):
+    """del_list on an empty DB deletes nothing and does not raise."""
+    result = db_memory.del_list(table, db_filter)
+    assert result["deleted"] == 0
+
+
+def test_del_list_generic_exception(db_memory_with_data):
+    """Unexpected internal errors in del_list are wrapped in DbException (404)."""
+    table = "test"
+    db_filter = {}
+    db_memory_with_data._find = MagicMock(side_effect=Exception())
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.del_list(table, db_filter)
+    assert str(excinfo.value) == empty_exception_message()
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, db_filter, data",
+    [
+        ("test", {}, {"_id": 1, "data": 1}),
+        ("test", {"_id": 1}, {"_id": 1, "data": 1}),
+        ("test", {"data": 1}, {"_id": 1, "data": 1}),
+        ("test", {"_id": 1, "data": 1}, {"_id": 1, "data": 1}),
+        ("test", {"_id": 2}, {"_id": 2, "data": 2}),
+        ("test", {"data": 2}, {"_id": 2, "data": 2}),
+        ("test", {"_id": 2, "data": 2}, {"_id": 2, "data": 2}),
+    ],
+)
+def test_del_one(db_memory_with_data, table, db_filter, data):
+    """del_one removes exactly one matching document from the table."""
+    result = db_memory_with_data.del_one(table, db_filter)
+    assert result == {"deleted": 1}
+    assert len(db_memory_with_data.db) == 1
+    assert table in db_memory_with_data.db
+    assert len(db_memory_with_data.db[table]) == 2
+    assert data not in db_memory_with_data.db[table]
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {}),
+        ("test", {"_id": 1}),
+        ("test", {"_id": 2}),
+        ("test", {"data": 1}),
+        ("test", {"data": 2}),
+        ("test", {"_id": 1, "data": 1}),
+        ("test", {"_id": 2, "data": 2}),
+        ("test_table", {}),
+        ("test_table", {"_id": 1}),
+        ("test_table", {"_id": 2}),
+        ("test_table", {"data": 1}),
+        ("test_table", {"data": 2}),
+        ("test_table", {"_id": 1, "data": 1}),
+        ("test_table", {"_id": 2, "data": 2}),
+    ],
+)
+def test_del_one_with_empty_db_exception(db_memory, table, db_filter):
+    """del_one on an empty DB raises 404 for any filter."""
+    with pytest.raises(DbException) as excinfo:
+        db_memory.del_one(table, db_filter)
+    assert str(excinfo.value) == (
+        empty_exception_message() + del_one_exception_message(db_filter)
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {}),
+        ("test", {"_id": 1}),
+        ("test", {"_id": 2}),
+        ("test", {"data": 1}),
+        ("test", {"data": 2}),
+        ("test", {"_id": 1, "data": 1}),
+        ("test", {"_id": 2, "data": 2}),
+        ("test_table", {}),
+        ("test_table", {"_id": 1}),
+        ("test_table", {"_id": 2}),
+        ("test_table", {"data": 1}),
+        ("test_table", {"data": 2}),
+        ("test_table", {"_id": 1, "data": 1}),
+        ("test_table", {"_id": 2, "data": 2}),
+    ],
+)
+def test_del_one_with_empty_db_none(db_memory, table, db_filter):
+    """del_one on an empty DB returns None when fail_on_empty=False."""
+    result = db_memory.del_one(table, db_filter, fail_on_empty=False)
+    assert result is None
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {"_id": 4}),
+        ("test", {"_id": 5}),
+        ("test", {"data": 4}),
+        ("test", {"data": 5}),
+        ("test", {"_id": 1, "data": 2}),
+        ("test", {"_id": 2, "data": 3}),
+        ("test_table", {}),
+        ("test_table", {"_id": 1}),
+        ("test_table", {"_id": 2}),
+        ("test_table", {"data": 1}),
+        ("test_table", {"data": 2}),
+        ("test_table", {"_id": 1, "data": 1}),
+        ("test_table", {"_id": 2, "data": 2}),
+    ],
+)
+def test_del_one_with_non_empty_db_exception(db_memory_with_data, table, db_filter):
+    """del_one raises 404 when nothing in a populated DB matches the filter."""
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.del_one(table, db_filter)
+    assert str(excinfo.value) == (
+        empty_exception_message() + del_one_exception_message(db_filter)
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, db_filter",
+    [
+        ("test", {"_id": 4}),
+        ("test", {"_id": 5}),
+        ("test", {"data": 4}),
+        ("test", {"data": 5}),
+        ("test", {"_id": 1, "data": 2}),
+        ("test", {"_id": 2, "data": 3}),
+        ("test_table", {}),
+        ("test_table", {"_id": 1}),
+        ("test_table", {"_id": 2}),
+        ("test_table", {"data": 1}),
+        ("test_table", {"data": 2}),
+        ("test_table", {"_id": 1, "data": 1}),
+        ("test_table", {"_id": 2, "data": 2}),
+    ],
+)
+def test_del_one_with_non_empty_db_none(db_memory_with_data, table, db_filter):
+    """del_one returns None on a non-matching filter when fail_on_empty=False."""
+    result = db_memory_with_data.del_one(table, db_filter, fail_on_empty=False)
+    assert result is None
+
+
+@pytest.mark.parametrize("fail_on_empty", [(True), (False)])
+def test_del_one_generic_exception(db_memory_with_data, fail_on_empty):
+    """Unexpected internal errors in del_one are wrapped in DbException
+    regardless of the fail_on_empty flag."""
+    table = "test"
+    db_filter = {}
+    db_memory_with_data._find = MagicMock(side_effect=Exception())
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.del_one(table, db_filter, fail_on_empty=fail_on_empty)
+    assert str(excinfo.value) == empty_exception_message()
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, _id, indata",
+    [
+        ("test", 1, {"_id": 1, "data": 42}),
+        ("test", 1, {"_id": 1, "data": 42, "kk": 34}),
+        ("test", 1, {"_id": 1}),
+        ("test", 2, {"_id": 2, "data": 42}),
+        ("test", 2, {"_id": 2, "data": 42, "kk": 34}),
+        ("test", 2, {"_id": 2}),
+        ("test", 3, {"_id": 3, "data": 42}),
+        ("test", 3, {"_id": 3, "data": 42, "kk": 34}),
+        ("test", 3, {"_id": 3}),
+    ],
+)
+def test_replace(db_memory_with_data, table, _id, indata):
+    """replace swaps the whole document with the given _id for indata."""
+    result = db_memory_with_data.replace(table, _id, indata)
+    assert result == {"updated": 1}
+    assert len(db_memory_with_data.db) == 1
+    assert table in db_memory_with_data.db
+    assert len(db_memory_with_data.db[table]) == 3
+    assert indata in db_memory_with_data.db[table]
+
+
+@pytest.mark.parametrize(
+    "table, _id, indata",
+    [
+        ("test", 1, {"_id": 1, "data": 42}),
+        ("test", 2, {"_id": 2}),
+        ("test", 3, {"_id": 3}),
+    ],
+)
+def test_replace_without_data_exception(db_memory, table, _id, indata):
+    """replace on an empty DB raises 404 when fail_on_empty=True."""
+    with pytest.raises(DbException) as excinfo:
+        db_memory.replace(table, _id, indata, fail_on_empty=True)
+    assert str(excinfo.value) == (replace_exception_message(_id))
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, _id, indata",
+    [
+        ("test", 1, {"_id": 1, "data": 42}),
+        ("test", 2, {"_id": 2}),
+        ("test", 3, {"_id": 3}),
+    ],
+)
+def test_replace_without_data_none(db_memory, table, _id, indata):
+    """replace on an empty DB returns None when fail_on_empty=False."""
+    result = db_memory.replace(table, _id, indata, fail_on_empty=False)
+    assert result is None
+
+
+@pytest.mark.parametrize(
+    "table, _id, indata",
+    [
+        ("test", 11, {"_id": 11, "data": 42}),
+        ("test", 12, {"_id": 12}),
+        ("test", 33, {"_id": 33}),
+    ],
+)
+def test_replace_with_data_exception(db_memory_with_data, table, _id, indata):
+    """replace raises 404 when the _id does not exist and fail_on_empty=True."""
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.replace(table, _id, indata, fail_on_empty=True)
+    assert str(excinfo.value) == (replace_exception_message(_id))
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, _id, indata",
+    [
+        ("test", 11, {"_id": 11, "data": 42}),
+        ("test", 12, {"_id": 12}),
+        ("test", 33, {"_id": 33}),
+    ],
+)
+def test_replace_with_data_none(db_memory_with_data, table, _id, indata):
+    """replace returns None when the _id does not exist and fail_on_empty=False."""
+    result = db_memory_with_data.replace(table, _id, indata, fail_on_empty=False)
+    assert result is None
+
+
+@pytest.mark.parametrize("fail_on_empty", [True, False])
+def test_replace_generic_exception(db_memory_with_data, fail_on_empty):
+    """Unexpected internal errors in replace are wrapped in DbException
+    regardless of the fail_on_empty flag."""
+    table = "test"
+    _id = {}
+    indata = {"_id": 1, "data": 1}
+    db_memory_with_data._find = MagicMock(side_effect=Exception())
+    with pytest.raises(DbException) as excinfo:
+        db_memory_with_data.replace(table, _id, indata, fail_on_empty=fail_on_empty)
+    assert str(excinfo.value) == empty_exception_message()
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "table, id, data",
+    [
+        ("test", "1", {"data": 1}),
+        ("test", "1", {"data": 2}),
+        ("test", "2", {"data": 1}),
+        ("test", "2", {"data": 2}),
+        ("test_table", "1", {"data": 1}),
+        ("test_table", "1", {"data": 2}),
+        ("test_table", "2", {"data": 1}),
+        ("test_table", "2", {"data": 2}),
+        ("test", "1", {"data_1": 1, "data_2": 2}),
+        ("test", "1", {"data_1": 2, "data_2": 1}),
+        ("test", "2", {"data_1": 1, "data_2": 2}),
+        ("test", "2", {"data_1": 2, "data_2": 1}),
+        ("test_table", "1", {"data_1": 1, "data_2": 2}),
+        ("test_table", "1", {"data_1": 2, "data_2": 1}),
+        ("test_table", "2", {"data_1": 1, "data_2": 2}),
+        ("test_table", "2", {"data_1": 2, "data_2": 1}),
+    ],
+)
+def test_create_with_empty_db_with_id(db_memory, table, id, data):
+    """create stores the document and echoes back the caller-supplied _id."""
+    data_to_insert = data
+    data_to_insert["_id"] = id
+    returned_id = db_memory.create(table, data_to_insert)
+    assert returned_id == id
+    assert len(db_memory.db) == 1
+    assert table in db_memory.db
+    assert len(db_memory.db[table]) == 1
+    assert data_to_insert in db_memory.db[table]
+
+
+@pytest.mark.parametrize(
+    "table, id, data",
+    [
+        ("test", "4", {"data": 1}),
+        ("test", "5", {"data": 2}),
+        ("test", "4", {"data": 1}),
+        ("test", "5", {"data": 2}),
+        ("test_table", "4", {"data": 1}),
+        ("test_table", "5", {"data": 2}),
+        ("test_table", "4", {"data": 1}),
+        ("test_table", "5", {"data": 2}),
+        ("test", "4", {"data_1": 1, "data_2": 2}),
+        ("test", "5", {"data_1": 2, "data_2": 1}),
+        ("test", "4", {"data_1": 1, "data_2": 2}),
+        ("test", "5", {"data_1": 2, "data_2": 1}),
+        ("test_table", "4", {"data_1": 1, "data_2": 2}),
+        ("test_table", "5", {"data_1": 2, "data_2": 1}),
+        ("test_table", "4", {"data_1": 1, "data_2": 2}),
+        ("test_table", "5", {"data_1": 2, "data_2": 1}),
+    ],
+)
+def test_create_with_non_empty_db_with_id(db_memory_with_data, table, id, data):
+    """create with an explicit _id appends to an existing table or makes a new one."""
+    data_to_insert = data
+    data_to_insert["_id"] = id
+    returned_id = db_memory_with_data.create(table, data_to_insert)
+    assert returned_id == id
+    # "test" already holds 3 documents; any other table is created fresh.
+    assert len(db_memory_with_data.db) == (1 if table == "test" else 2)
+    assert table in db_memory_with_data.db
+    assert len(db_memory_with_data.db[table]) == (4 if table == "test" else 1)
+    assert data_to_insert in db_memory_with_data.db[table]
+
+
+@pytest.mark.parametrize(
+    "table, data",
+    [
+        ("test", {"data": 1}),
+        ("test", {"data": 2}),
+        ("test", {"data": 1}),
+        ("test", {"data": 2}),
+        ("test_table", {"data": 1}),
+        ("test_table", {"data": 2}),
+        ("test_table", {"data": 1}),
+        ("test_table", {"data": 2}),
+        ("test", {"data_1": 1, "data_2": 2}),
+        ("test", {"data_1": 2, "data_2": 1}),
+        ("test", {"data_1": 1, "data_2": 2}),
+        ("test", {"data_1": 2, "data_2": 1}),
+        ("test_table", {"data_1": 1, "data_2": 2}),
+        ("test_table", {"data_1": 2, "data_2": 1}),
+        ("test_table", {"data_1": 1, "data_2": 2}),
+        ("test_table", {"data_1": 2, "data_2": 1}),
+    ],
+)
+def test_create_with_empty_db_without_id(db_memory, table, data):
+    """create without an _id generates one and stores it in the document."""
+    returned_id = db_memory.create(table, data)
+    assert len(db_memory.db) == 1
+    assert table in db_memory.db
+    assert len(db_memory.db[table]) == 1
+    data_inserted = data
+    data_inserted["_id"] = returned_id
+    assert data_inserted in db_memory.db[table]
+
+
+@pytest.mark.parametrize(
+    "table, data",
+    [
+        ("test", {"data": 1}),
+        ("test", {"data": 2}),
+        ("test", {"data": 1}),
+        ("test", {"data": 2}),
+        ("test_table", {"data": 1}),
+        ("test_table", {"data": 2}),
+        ("test_table", {"data": 1}),
+        ("test_table", {"data": 2}),
+        ("test", {"data_1": 1, "data_2": 2}),
+        ("test", {"data_1": 2, "data_2": 1}),
+        ("test", {"data_1": 1, "data_2": 2}),
+        ("test", {"data_1": 2, "data_2": 1}),
+        ("test_table", {"data_1": 1, "data_2": 2}),
+        ("test_table", {"data_1": 2, "data_2": 1}),
+        ("test_table", {"data_1": 1, "data_2": 2}),
+        ("test_table", {"data_1": 2, "data_2": 1}),
+    ],
+)
+def test_create_with_non_empty_db_without_id(db_memory_with_data, table, data):
+    """create without an _id on a populated DB generates an id and appends."""
+    returned_id = db_memory_with_data.create(table, data)
+    # "test" already holds 3 documents; any other table is created fresh.
+    assert len(db_memory_with_data.db) == (1 if table == "test" else 2)
+    assert table in db_memory_with_data.db
+    assert len(db_memory_with_data.db[table]) == (4 if table == "test" else 1)
+    data_inserted = data
+    data_inserted["_id"] = returned_id
+    assert data_inserted in db_memory_with_data.db[table]
+
+
+def test_create_with_exception(db_memory):
+    """Unexpected internal errors in create are wrapped in DbException (404)."""
+    table = "test"
+    data = {"_id": 1, "data": 1}
+    # Make any membership test on the store blow up.
+    db_memory.db = MagicMock()
+    db_memory.db.__contains__.side_effect = Exception()
+    with pytest.raises(DbException) as excinfo:
+        db_memory.create(table, data)
+    assert str(excinfo.value) == empty_exception_message()
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "db_content, update_dict, expected, message",
+    [
+        (
+            {"a": {"none": None}},
+            {"a.b.num": "v"},
+            {"a": {"none": None, "b": {"num": "v"}}},
+            "create dict",
+        ),
+        (
+            {"a": {"none": None}},
+            {"a.none.num": "v"},
+            {"a": {"none": {"num": "v"}}},
+            "create dict over none",
+        ),
+        (
+            {"a": {"b": {"num": 4}}},
+            {"a.b.num": "v"},
+            {"a": {"b": {"num": "v"}}},
+            "replace_number",
+        ),
+        (
+            {"a": {"b": {"num": 4}}},
+            {"a.b.num.c.d": "v"},
+            None,
+            "create dict over number should fail",
+        ),
+        (
+            {"a": {"b": {"num": 4}}},
+            {"a.b": "v"},
+            {"a": {"b": "v"}},
+            "replace dict with a string",
+        ),
+        (
+            {"a": {"b": {"num": 4}}},
+            {"a.b": None},
+            {"a": {"b": None}},
+            "replace dict with None",
+        ),
+        (
+            {"a": [{"b": {"num": 4}}]},
+            {"a.b.num": "v"},
+            None,
+            "create dict over list should fail",
+        ),
+        (
+            {"a": [{"b": {"num": 4}}]},
+            {"a.0.b.num": "v"},
+            {"a": [{"b": {"num": "v"}}]},
+            "set list",
+        ),
+        (
+            {"a": [{"b": {"num": 4}}]},
+            {"a.3.b.num": "v"},
+            {"a": [{"b": {"num": 4}}, None, None, {"b": {"num": "v"}}]},
+            "expand list",
+        ),
+        ({"a": [[4]]}, {"a.0.0": "v"}, {"a": [["v"]]}, "set nested list"),
+        ({"a": [[4]]}, {"a.0.2": "v"}, {"a": [[4, None, "v"]]}, "expand nested list"),
+        (
+            {"a": [[4]]},
+            {"a.2.2": "v"},
+            {"a": [[4], None, {"2": "v"}]},
+            "expand list and add number key",
+        ),
+    ],
+)
+def test_set_one(db_memory, db_content, update_dict, expected, message):
+    """set_one applies dotted-path updates in place; expected=None marks
+    cases that must raise DbException (404)."""
+    # Bypass filtering: _find always yields the single prepared document.
+    db_memory._find = Mock(return_value=((0, db_content),))
+    if expected is None:
+        with pytest.raises(DbException) as excinfo:
+            db_memory.set_one("table", {}, update_dict)
+        assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND, message
+    else:
+        db_memory.set_one("table", {}, update_dict)
+        assert db_content == expected, message
+
+
+class TestDbMemory(unittest.TestCase):
+ # TODO to delete. This is cover with pytest test_set_one.
+    def test_set_one(self):
+        """Legacy unittest version of the dotted-path set_one cases.

+        Kept only for coverage parity; the pytest test_set_one above exercises
+        the same table (see the class-level TODO about removing it).
+        """
+        test_set = (
+            # (database content, set-content, expected database content (None=fails), message)
+            (
+                {"a": {"none": None}},
+                {"a.b.num": "v"},
+                {"a": {"none": None, "b": {"num": "v"}}},
+                "create dict",
+            ),
+            (
+                {"a": {"none": None}},
+                {"a.none.num": "v"},
+                {"a": {"none": {"num": "v"}}},
+                "create dict over none",
+            ),
+            (
+                {"a": {"b": {"num": 4}}},
+                {"a.b.num": "v"},
+                {"a": {"b": {"num": "v"}}},
+                "replace_number",
+            ),
+            (
+                {"a": {"b": {"num": 4}}},
+                {"a.b.num.c.d": "v"},
+                None,
+                "create dict over number should fail",
+            ),
+            (
+                {"a": {"b": {"num": 4}}},
+                {"a.b": "v"},
+                {"a": {"b": "v"}},
+                "replace dict with a string",
+            ),
+            (
+                {"a": {"b": {"num": 4}}},
+                {"a.b": None},
+                {"a": {"b": None}},
+                "replace dict with None",
+            ),
+            (
+                {"a": [{"b": {"num": 4}}]},
+                {"a.b.num": "v"},
+                None,
+                "create dict over list should fail",
+            ),
+            (
+                {"a": [{"b": {"num": 4}}]},
+                {"a.0.b.num": "v"},
+                {"a": [{"b": {"num": "v"}}]},
+                "set list",
+            ),
+            (
+                {"a": [{"b": {"num": 4}}]},
+                {"a.3.b.num": "v"},
+                {"a": [{"b": {"num": 4}}, None, None, {"b": {"num": "v"}}]},
+                "expand list",
+            ),
+            ({"a": [[4]]}, {"a.0.0": "v"}, {"a": [["v"]]}, "set nested list"),
+            (
+                {"a": [[4]]},
+                {"a.0.2": "v"},
+                {"a": [[4, None, "v"]]},
+                "expand nested list",
+            ),
+            (
+                {"a": [[4]]},
+                {"a.2.2": "v"},
+                {"a": [[4], None, {"2": "v"}]},
+                "expand list and add number key",
+            ),
+            ({"a": None}, {"b.c": "v"}, {"a": None, "b": {"c": "v"}}, "expand at root"),
+        )
+        db_men = DbMemory()
+        db_men._find = Mock()
+        for db_content, update_dict, expected, message in test_set:
+            # _find always yields the single prepared document.
+            db_men._find.return_value = ((0, db_content),)
+            if expected is None:
+                self.assertRaises(DbException, db_men.set_one, "table", {}, update_dict)
+            else:
+                db_men.set_one("table", {}, update_dict)
+                self.assertEqual(db_content, expected, message)
+
+ def test_set_one_pull(self):
+ example = {"a": [1, "1", 1], "d": {}, "n": None}
+ test_set = (
+ # (database content, set-content, expected database content (None=fails), message)
+ (example, {"a": "1"}, {"a": [1, 1], "d": {}, "n": None}, "pull one item"),
+ (example, {"a": 1}, {"a": ["1"], "d": {}, "n": None}, "pull two items"),
+ (example, {"a": "v"}, example, "pull non existing item"),
+ (example, {"a.6": 1}, example, "pull non existing arrray"),
+ (example, {"d.b.c": 1}, example, "pull non existing arrray2"),
+ (example, {"b": 1}, example, "pull non existing arrray3"),
+ (example, {"d": 1}, None, "pull over dict"),
+ (example, {"n": 1}, None, "pull over None"),
+ )
+ db_men = DbMemory()
+ db_men._find = Mock()
+ for db_content, pull_dict, expected, message in test_set:
+ db_content = deepcopy(db_content)
+ db_men._find.return_value = ((0, db_content),)
+ if expected is None:
+ self.assertRaises(
+ DbException,
+ db_men.set_one,
+ "table",
+ {},
+ None,
+ fail_on_empty=False,
+ pull=pull_dict,
+ )
+ else:
+ db_men.set_one("table", {}, None, pull=pull_dict)
+ self.assertEqual(db_content, expected, message)
+
+ def test_set_one_push(self):
+ example = {"a": [1, "1", 1], "d": {}, "n": None}
+ test_set = (
+ # (database content, set-content, expected database content (None=fails), message)
+ (
+ example,
+ {"d.b.c": 1},
+ {"a": [1, "1", 1], "d": {"b": {"c": [1]}}, "n": None},
+ "push non existing arrray2",
+ ),
+ (
+ example,
+ {"b": 1},
+ {"a": [1, "1", 1], "d": {}, "b": [1], "n": None},
+ "push non existing arrray3",
+ ),
+ (
+ example,
+ {"a.6": 1},
+ {"a": [1, "1", 1, None, None, None, [1]], "d": {}, "n": None},
+ "push non existing arrray",
+ ),
+ (
+ example,
+ {"a": 2},
+ {"a": [1, "1", 1, 2], "d": {}, "n": None},
+ "push one item",
+ ),
+ (
+ example,
+ {"a": {1: 1}},
+ {"a": [1, "1", 1, {1: 1}], "d": {}, "n": None},
+ "push a dict",
+ ),
+ (example, {"d": 1}, None, "push over dict"),
+ (example, {"n": 1}, None, "push over None"),
+ )
+ db_men = DbMemory()
+ db_men._find = Mock()
+ for db_content, push_dict, expected, message in test_set:
+ db_content = deepcopy(db_content)
+ db_men._find.return_value = ((0, db_content),)
+ if expected is None:
+ self.assertRaises(
+ DbException,
+ db_men.set_one,
+ "table",
+ {},
+ None,
+ fail_on_empty=False,
+ push=push_dict,
+ )
+ else:
+ db_men.set_one("table", {}, None, push=push_dict)
+ self.assertEqual(db_content, expected, message)
+
+ def test_set_one_push_list(self):
+ example = {"a": [1, "1", 1], "d": {}, "n": None}
+ test_set = (
+ # (database content, set-content, expected database content (None=fails), message)
+ (
+ example,
+ {"d.b.c": [1]},
+ {"a": [1, "1", 1], "d": {"b": {"c": [1]}}, "n": None},
+ "push non existing arrray2",
+ ),
+ (
+ example,
+ {"b": [1]},
+ {"a": [1, "1", 1], "d": {}, "b": [1], "n": None},
+ "push non existing arrray3",
+ ),
+ (
+ example,
+ {"a.6": [1]},
+ {"a": [1, "1", 1, None, None, None, [1]], "d": {}, "n": None},
+ "push non existing arrray",
+ ),
+ (
+ example,
+ {"a": [2, 3]},
+ {"a": [1, "1", 1, 2, 3], "d": {}, "n": None},
+ "push two item",
+ ),
+ (
+ example,
+ {"a": [{1: 1}]},
+ {"a": [1, "1", 1, {1: 1}], "d": {}, "n": None},
+ "push a dict",
+ ),
+ (example, {"d": [1]}, None, "push over dict"),
+ (example, {"n": [1]}, None, "push over None"),
+ (example, {"a": 1}, None, "invalid push list non an array"),
+ )
+ db_men = DbMemory()
+ db_men._find = Mock()
+ for db_content, push_list, expected, message in test_set:
+ db_content = deepcopy(db_content)
+ db_men._find.return_value = ((0, db_content),)
+ if expected is None:
+ self.assertRaises(
+ DbException,
+ db_men.set_one,
+ "table",
+ {},
+ None,
+ fail_on_empty=False,
+ push_list=push_list,
+ )
+ else:
+ db_men.set_one("table", {}, None, push_list=push_list)
+ self.assertEqual(db_content, expected, message)
+
+ def test_unset_one(self):
+ example = {"a": [1, "1", 1], "d": {}, "n": None}
+ test_set = (
+ # (database content, set-content, expected database content (None=fails), message)
+ (example, {"d.b.c": 1}, example, "unset non existing"),
+ (example, {"b": 1}, example, "unset non existing"),
+ (example, {"a.6": 1}, example, "unset non existing arrray"),
+ (example, {"a": 2}, {"d": {}, "n": None}, "unset array"),
+ (example, {"d": 1}, {"a": [1, "1", 1], "n": None}, "unset dict"),
+ (example, {"n": 1}, {"a": [1, "1", 1], "d": {}}, "unset None"),
+ )
+ db_men = DbMemory()
+ db_men._find = Mock()
+ for db_content, unset_dict, expected, message in test_set:
+ db_content = deepcopy(db_content)
+ db_men._find.return_value = ((0, db_content),)
+ if expected is None:
+ self.assertRaises(
+ DbException,
+ db_men.set_one,
+ "table",
+ {},
+ None,
+ fail_on_empty=False,
+ unset=unset_dict,
+ )
+ else:
+ db_men.set_one("table", {}, None, unset=unset_dict)
+ self.assertEqual(db_content, expected, message)
--- /dev/null
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#######################################################################################
+
+import logging
+from urllib.parse import quote
+
+from osm_common.dbbase import DbException, FakeLock
+from osm_common.dbmongo import DbMongo
+from pymongo import MongoClient
+import pytest
+
+
+# Helpers: expected exception-message prefixes and get_one stand-ins used by
+# the DbMongo connection tests below.
+def db_status_exception_message():
+    return "database exception Wrong database status"
+
+
+def db_version_exception_message():
+    return "database exception Invalid database version"
+
+
+def mock_get_one_status_not_enabled(a, b, c, fail_on_empty=False, fail_on_more=True):
+    # Version document whose status is not "ENABLED" -> db_connect must fail.
+    return {"status": "ERROR", "version": "", "serial": ""}
+
+
+def mock_get_one_wrong_db_version(a, b, c, fail_on_empty=False, fail_on_more=True):
+    # Valid status but a version that mismatches the requested target_version.
+    return {"status": "ENABLED", "version": "4.0", "serial": "MDY4OA=="}
+
+
+def db_generic_exception(exception):
+    # Identity helper; used in parametrize tables for readability.
+    return exception
+
+
+def db_generic_exception_message(message):
+    # DbException prefixes its message with "database exception ".
+    return f"database exception {message}"
+
+
+def test_constructor():
+    """Default DbMongo: "db" logger, no client/db yet, a real (acquirable) lock."""
+    db = DbMongo(lock=True)
+    assert db.logger == logging.getLogger("db")
+    assert db.db is None
+    assert db.client is None
+    assert db.database_key is None
+    assert db.secret_obtained is False
+    assert db.lock.acquire() is True
+
+
+def test_constructor_with_logger():
+    """lock=False substitutes the no-op FakeLock; logger_name is honoured."""
+    logger_name = "db_mongo"
+    db = DbMongo(logger_name=logger_name, lock=False)
+    assert db.logger == logging.getLogger(logger_name)
+    assert db.db is None
+    assert db.client is None
+    assert db.database_key is None
+    assert db.secret_obtained is False
+    assert type(db.lock) == FakeLock
+
+
+# Each case: (db_connect config, target DB version, serial (base64), lock flag).
+@pytest.mark.parametrize(
+    "config, target_version, serial, lock",
+    [
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "mongo:27017",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "5.0",
+            "MDY=",
+            True,
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "masterpassword": "master",
+                "uri": "mongo:27017",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "5.0",
+            "MDY=",
+            False,
+        ),
+        (
+            {
+                "logger_name": "logger",
+                "uri": "mongo:27017",
+                "name": "newdb",
+                "commonkey": "common",
+            },
+            "3.6",
+            "",
+            True,
+        ),
+        (
+            {
+                "uri": "mongo:27017",
+                "commonkey": "common",
+                "name": "newdb",
+            },
+            "5.0",
+            "MDIy",
+            False,
+        ),
+        (
+            {
+                "uri": "mongo:27017",
+                "masterpassword": "common",
+                "name": "newdb",
+                "loglevel": "CRITICAL",
+            },
+            "4.4",
+            "OTA=",
+            False,
+        ),
+        (
+            {
+                "uri": "mongo",
+                "masterpassword": "common",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "4.4",
+            "OTA=",
+            True,
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": quote("user4:password4@mongo"),
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "5.0",
+            "NTM=",
+            True,
+        ),
+        (
+            {
+                "logger_name": "logger",
+                "uri": quote("user3:password3@mongo:27017"),
+                "name": "newdb",
+                "commonkey": "common",
+            },
+            "4.0",
+            "NjEx",
+            False,
+        ),
+        (
+            {
+                "uri": quote("user2:password2@mongo:27017"),
+                "commonkey": "common",
+                "name": "newdb",
+            },
+            "5.0",
+            "cmV0MzI=",
+            False,
+        ),
+        (
+            {
+                "uri": quote("user1:password1@mongo:27017"),
+                "commonkey": "common",
+                "masterpassword": "master",
+                "name": "newdb",
+                "loglevel": "CRITICAL",
+            },
+            "4.0",
+            "MjMyNQ==",
+            False,
+        ),
+        (
+            {
+                "uri": quote("user1:password1@mongo"),
+                "masterpassword": "common",
+                "name": "newdb",
+                "loglevel": "CRITICAL",
+            },
+            "4.0",
+            "MjMyNQ==",
+            True,
+        ),
+    ],
+)
+def test_db_connection_with_valid_config(
+    config, target_version, serial, lock, monkeypatch
+):
+    """db_connect succeeds and sets up logger, MongoClient and database key."""
+
+    def mock_get_one(a, b, c, fail_on_empty=False, fail_on_more=True):
+        # Stand-in for the admin/version document read during db_connect.
+        return {"status": "ENABLED", "version": target_version, "serial": serial}
+
+    monkeypatch.setattr(DbMongo, "get_one", mock_get_one)
+    db = DbMongo(lock=lock)
+    db.db_connect(config, target_version)
+    # BUGFIX: the previous assertions were written as
+    #   assert A == B if cond else C
+    # which parses as assert (A == B if cond else C); the else branch asserted
+    # a truthy object (a Logger instance / the int 20) and could never fail.
+    # Assert each branch explicitly instead.
+    assert db.logger == logging.getLogger(config.get("logger_name", "db"))
+    assert type(db.client) == MongoClient
+    assert db.database_key == "common"
+    if config.get("loglevel"):
+        # All parametrized loglevels are "CRITICAL", numeric level 50.
+        assert db.logger.getEffectiveLevel() == 50
+
+
+# Each case: (config, target version, get_one stand-in, expected message prefix).
+@pytest.mark.parametrize(
+    "config, target_version, version_data, expected_exception_message",
+    [
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "mongo:27017",
+                "replicaset": "rs2",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "4.0",
+            mock_get_one_status_not_enabled,
+            db_status_exception_message(),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "mongo:27017",
+                "replicaset": "rs4",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "5.0",
+            mock_get_one_wrong_db_version,
+            db_version_exception_message(),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": quote("user2:pa@word2@mongo:27017"),
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "DEBUG",
+            },
+            "4.0",
+            mock_get_one_status_not_enabled,
+            db_status_exception_message(),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": quote("username:pass1rd@mongo:27017"),
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "DEBUG",
+            },
+            "5.0",
+            mock_get_one_wrong_db_version,
+            db_version_exception_message(),
+        ),
+    ],
+)
+def test_db_connection_db_status_error(
+    config, target_version, version_data, expected_exception_message, monkeypatch
+):
+    """A bad status or version in the DB version document aborts db_connect."""
+    monkeypatch.setattr(DbMongo, "get_one", version_data)
+    db = DbMongo(lock=False)
+    with pytest.raises(DbException) as exception_info:
+        db.db_connect(config, target_version)
+    # Only the prefix is checked; the full message carries case-specific detail.
+    assert str(exception_info.value).startswith(expected_exception_message)
+
+
+# Malformed URIs / replicaset values: each case names the exception type that
+# db_connect is expected to surface (DbException, TypeError or ValueError).
+@pytest.mark.parametrize(
+    "config, target_version, lock, expected_exception",
+    [
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "27017@/:",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "4.0",
+            True,
+            db_generic_exception(DbException),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "user@pass",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "4.0",
+            False,
+            db_generic_exception(DbException),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "user@pass:27017",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "4.0",
+            True,
+            db_generic_exception(DbException),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "5.0",
+            False,
+            db_generic_exception(TypeError),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "user2::@mon:27017",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "DEBUG",
+            },
+            "4.0",
+            True,
+            db_generic_exception(ValueError),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "replicaset": 33,
+                "uri": "user2@@mongo:27017",
+                "name": "osmdb",
+                "loglevel": "DEBUG",
+            },
+            "5.0",
+            False,
+            db_generic_exception(TypeError),
+        ),
+    ],
+)
+def test_db_connection_with_invalid_uri(
+    config, target_version, lock, expected_exception, monkeypatch
+):
+    """db_connect propagates the parser/driver error for malformed URIs."""
+    def mock_get_one(a, b, c, fail_on_empty=False, fail_on_more=True):
+        # Never reached for these configs; stubbed to avoid real DB access.
+        pass
+
+    monkeypatch.setattr(DbMongo, "get_one", mock_get_one)
+    db = DbMongo(lock=lock)
+    with pytest.raises(expected_exception) as exception_info:
+        db.db_connect(config, target_version)
+    assert type(exception_info.value) == expected_exception
+
+
+# Configs missing mandatory entries ("uri", "name", ...) or with an empty
+# target version: db_connect must fail with KeyError or TypeError.
+@pytest.mark.parametrize(
+    "config, target_version, expected_exception",
+    [
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "",
+            db_generic_exception(TypeError),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "uri": "mongo:27017",
+                "replicaset": "rs0",
+                "loglevel": "CRITICAL",
+            },
+            "4.0",
+            db_generic_exception(KeyError),
+        ),
+        (
+            {
+                "replicaset": "rs0",
+                "loglevel": "CRITICAL",
+            },
+            None,
+            db_generic_exception(KeyError),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "",
+                "replicaset": "rs0",
+                "name": "osmdb",
+                "loglevel": "CRITICAL",
+            },
+            "5.0",
+            db_generic_exception(TypeError),
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "name": "osmdb",
+            },
+            "4.0",
+            db_generic_exception(TypeError),
+        ),
+        (
+            {
+                "logger_name": "logger",
+                "replicaset": "",
+                "uri": "user2@@mongo:27017",
+            },
+            "5.0",
+            db_generic_exception(KeyError),
+        ),
+    ],
+)
+def test_db_connection_with_missing_parameters(
+    config, target_version, expected_exception, monkeypatch
+):
+    """Incomplete configs surface the underlying KeyError/TypeError."""
+    def mock_get_one(a, b, c, fail_on_empty=False, fail_on_more=True):
+        # Stubbed out; these configs fail before any DB read.
+        return
+
+    monkeypatch.setattr(DbMongo, "get_one", mock_get_one)
+    db = DbMongo(lock=False)
+    with pytest.raises(expected_exception) as exception_info:
+        db.db_connect(config, target_version)
+    assert type(exception_info.value) == expected_exception
+
+
+@pytest.mark.parametrize(
+    "config, expected_exception_message",
+    [
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "mongo:27017",
+                "replicaset": "rs0",
+                "name": "osmdb1",
+                "loglevel": "CRITICAL",
+            },
+            "MongoClient crashed",
+        ),
+        (
+            {
+                "logger_name": "mongo_logger",
+                "commonkey": "common",
+                "uri": "username:pas1ed@mongo:27017",
+                "replicaset": "rs1",
+                "name": "osmdb2",
+                "loglevel": "DEBUG",
+            },
+            "MongoClient crashed",
+        ),
+    ],
+)
+def test_db_connection_with_invalid_mongoclient(
+    config, expected_exception_message, monkeypatch
+):
+    """A failure inside MongoClient() is wrapped/re-raised as DbException."""
+    def generate_exception(a, b, replicaSet=None):
+        # Replaces MongoClient.__init__ so construction always blows up.
+        raise DbException(expected_exception_message)
+
+    monkeypatch.setattr(MongoClient, "__init__", generate_exception)
+    db = DbMongo()
+    with pytest.raises(DbException) as exception_info:
+        db.db_connect(config)
+    # DbException prefixes messages with "database exception ".
+    assert str(exception_info.value) == db_generic_exception_message(
+        expected_exception_message
+    )
--- /dev/null
+# Copyright 2018 Whitestack, LLC
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
+##
+
+import http
+
+from osm_common.fsbase import FsBase, FsException
+import pytest
+
+
+def exception_message(message):
+    # FsException prefixes its message with "storage exception ".
+    return "storage exception " + message
+
+
+@pytest.fixture
+def fs_base():
+    # Fresh abstract-base instance per test.
+    return FsBase()
+
+
+def test_constructor():
+    """FsBase can be instantiated directly (it is a plain base class)."""
+    fs_base = FsBase()
+    assert fs_base is not None
+    assert isinstance(fs_base, FsBase)
+
+
+def test_get_params(fs_base):
+    """The base class exposes no connection parameters."""
+    params = fs_base.get_params()
+    assert isinstance(params, dict)
+    assert len(params) == 0
+
+
+def test_fs_connect(fs_base):
+    # Base-class no-op: must not raise.
+    fs_base.fs_connect(None)
+
+
+def test_fs_disconnect(fs_base):
+    # Base-class no-op: must not raise.
+    fs_base.fs_disconnect()
+
+
+# Every abstract operation on FsBase raises FsException("Method '<name>' not
+# implemented") with HTTP 500; one test per method.
+def test_mkdir(fs_base):
+    with pytest.raises(FsException) as excinfo:
+        fs_base.mkdir(None)
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'mkdir' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_file_exists(fs_base):
+    with pytest.raises(FsException) as excinfo:
+        fs_base.file_exists(None)
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'file_exists' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_file_size(fs_base):
+    with pytest.raises(FsException) as excinfo:
+        fs_base.file_size(None)
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'file_size' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_file_extract(fs_base):
+    with pytest.raises(FsException) as excinfo:
+        fs_base.file_extract(None, None)
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'file_extract' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_file_open(fs_base):
+    with pytest.raises(FsException) as excinfo:
+        fs_base.file_open(None, None)
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'file_open' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_file_delete(fs_base):
+    with pytest.raises(FsException) as excinfo:
+        fs_base.file_delete(None, None)
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'file_delete' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
--- /dev/null
+# Copyright 2018 Whitestack, LLC
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
+##
+
+import http
+import io
+import logging
+import os
+import shutil
+import tarfile
+import tempfile
+import uuid
+
+from osm_common.fsbase import FsException
+from osm_common.fslocal import FsLocal
+import pytest
+
+__author__ = "Eduardo Sousa <eduardosousa@av.it.pt>"
+
+
+def valid_path():
+    # System temp dir, always with a trailing slash so paths can be joined
+    # by simple concatenation throughout these tests.
+    return tempfile.gettempdir() + "/"
+
+
+def invalid_path():
+    # A path that cannot exist on the test host.
+    return "/#tweeter/"
+
+
+# Runs every test twice: once with a real lock, once with FakeLock.
+@pytest.fixture(scope="function", params=[True, False])
+def fs_local(request):
+    fs = FsLocal(lock=request.param)
+    fs.fs_connect({"path": valid_path()})
+    return fs
+
+
+# Expected-message builders; "storage" may be a str or a list of path parts.
+def fs_connect_exception_message(path):
+    return "storage exception Invalid configuration param at '[storage]': path '{}' does not exist".format(
+        path
+    )
+
+
+def file_open_file_not_found_exception(storage):
+    f = storage if isinstance(storage, str) else "/".join(storage)
+    return "storage exception File {} does not exist".format(f)
+
+
+def file_open_io_exception(storage):
+    f = storage if isinstance(storage, str) else "/".join(storage)
+    return "storage exception File {} cannot be opened".format(f)
+
+
+def dir_ls_not_a_directory_exception(storage):
+    f = storage if isinstance(storage, str) else "/".join(storage)
+    return "storage exception File {} does not exist".format(f)
+
+
+def dir_ls_io_exception(storage):
+    f = storage if isinstance(storage, str) else "/".join(storage)
+    return "storage exception File {} cannot be opened".format(f)
+
+
+def file_delete_exception_message(storage):
+    return "storage exception File {} does not exist".format(storage)
+
+
+def test_constructor_without_logger():
+    """Default FsLocal uses the "fs" logger and has no path until connect."""
+    fs = FsLocal()
+    assert fs.logger == logging.getLogger("fs")
+    assert fs.path is None
+
+
+def test_constructor_with_logger():
+    """A custom logger_name is honoured."""
+    logger_name = "fs_local"
+    fs = FsLocal(logger_name=logger_name)
+    assert fs.logger == logging.getLogger(logger_name)
+    assert fs.path is None
+
+
+def test_get_params(fs_local):
+    """get_params reports the backend kind ("local") and the connected path."""
+    params = fs_local.get_params()
+    assert len(params) == 2
+    assert "fs" in params
+    assert "path" in params
+    assert params["fs"] == "local"
+    assert params["path"] == valid_path()
+
+
+# fs_connect normalizes the configured path: a missing trailing slash
+# ([:-1] cases) is added, so fs.path always ends with "/".
+@pytest.mark.parametrize(
+    "config, exp_logger, exp_path",
+    [
+        ({"logger_name": "fs_local", "path": valid_path()}, "fs_local", valid_path()),
+        (
+            {"logger_name": "fs_local", "path": valid_path()[:-1]},
+            "fs_local",
+            valid_path(),
+        ),
+        ({"path": valid_path()}, "fs", valid_path()),
+        ({"path": valid_path()[:-1]}, "fs", valid_path()),
+    ],
+)
+def test_fs_connect_with_valid_config(config, exp_logger, exp_path):
+    fs = FsLocal()
+    fs.fs_connect(config)
+    assert fs.logger == logging.getLogger(exp_logger)
+    assert fs.path == exp_path
+
+
+@pytest.mark.parametrize(
+    "config, exp_exception_message",
+    [
+        (
+            {"logger_name": "fs_local", "path": invalid_path()},
+            fs_connect_exception_message(invalid_path()),
+        ),
+        (
+            {"logger_name": "fs_local", "path": invalid_path()[:-1]},
+            fs_connect_exception_message(invalid_path()[:-1]),
+        ),
+        ({"path": invalid_path()}, fs_connect_exception_message(invalid_path())),
+        (
+            {"path": invalid_path()[:-1]},
+            fs_connect_exception_message(invalid_path()[:-1]),
+        ),
+    ],
+)
+def test_fs_connect_with_invalid_path(config, exp_exception_message):
+    # Non-existent paths are rejected at connect time with FsException.
+    fs = FsLocal()
+    with pytest.raises(FsException) as excinfo:
+        fs.fs_connect(config)
+    assert str(excinfo.value) == exp_exception_message
+
+
+def test_fs_disconnect(fs_local):
+    # No-op for the local backend: must not raise.
+    fs_local.fs_disconnect()
+
+
+def test_mkdir_with_valid_path(fs_local):
+    """mkdir creates the folder under the connected path and is idempotent."""
+    folder_name = str(uuid.uuid4())
+    folder_path = valid_path() + folder_name
+    fs_local.mkdir(folder_name)
+    assert os.path.exists(folder_path)
+    # test idempotency
+    fs_local.mkdir(folder_name)
+    assert os.path.exists(folder_path)
+    os.rmdir(folder_path)
+
+
+def test_mkdir_with_exception(fs_local):
+    """Creating a nested folder whose parent does not exist raises HTTP 500."""
+    folder_name = str(uuid.uuid4())
+    with pytest.raises(FsException) as excinfo:
+        fs_local.mkdir(folder_name + "/" + folder_name)
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+# "storage" may be a plain string or a single-element list of path parts;
+# both spellings are exercised for files and directories.
+@pytest.mark.parametrize(
+    "storage, mode, expected",
+    [
+        (str(uuid.uuid4()), "file", False),
+        ([str(uuid.uuid4())], "file", False),
+        (str(uuid.uuid4()), "dir", False),
+        ([str(uuid.uuid4())], "dir", False),
+    ],
+)
+def test_file_exists_returns_false(fs_local, storage, mode, expected):
+    # Random UUID names are guaranteed not to exist yet.
+    assert fs_local.file_exists(storage, mode) == expected
+
+
+@pytest.mark.parametrize(
+    "storage, mode, expected",
+    [
+        (str(uuid.uuid4()), "file", True),
+        ([str(uuid.uuid4())], "file", True),
+        (str(uuid.uuid4()), "dir", True),
+        ([str(uuid.uuid4())], "dir", True),
+    ],
+)
+def test_file_exists_returns_true(fs_local, storage, mode, expected):
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    # Create the fixture file/dir, check, then clean up.
+    if mode == "file":
+        os.mknod(path)
+    elif mode == "dir":
+        os.mkdir(path)
+    assert fs_local.file_exists(storage, mode) == expected
+    if mode == "file":
+        os.remove(path)
+    elif mode == "dir":
+        os.rmdir(path)
+
+
+@pytest.mark.parametrize(
+    "storage, mode",
+    [
+        (str(uuid.uuid4()), "file"),
+        ([str(uuid.uuid4())], "file"),
+        (str(uuid.uuid4()), "dir"),
+        ([str(uuid.uuid4())], "dir"),
+    ],
+)
+def test_file_size(fs_local, storage, mode):
+    """file_size matches os.path.getsize for files and directories."""
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    if mode == "file":
+        os.mknod(path)
+    elif mode == "dir":
+        os.mkdir(path)
+    size = os.path.getsize(path)
+    assert fs_local.file_size(storage) == size
+    if mode == "file":
+        os.remove(path)
+    elif mode == "dir":
+        os.rmdir(path)
+
+
+@pytest.mark.parametrize(
+    "files, path",
+    [
+        (["foo", "bar", "foobar"], str(uuid.uuid4())),
+        (["foo", "bar", "foobar"], [str(uuid.uuid4())]),
+    ],
+)
+def test_file_extract(fs_local, files, path):
+    """file_extract unpacks an open tarfile into <path> with all members."""
+    # Build a tar containing the listed (empty) files.
+    for f in files:
+        os.mknod(valid_path() + f)
+    tar_path = valid_path() + str(uuid.uuid4()) + ".tar"
+    with tarfile.open(tar_path, "w") as tar:
+        for f in files:
+            tar.add(valid_path() + f, arcname=f)
+    # Extract and verify every member landed in the destination folder.
+    with tarfile.open(tar_path, "r") as tar:
+        fs_local.file_extract(tar, path)
+    extracted_path = valid_path() + (path if isinstance(path, str) else "/".join(path))
+    ls_dir = os.listdir(extracted_path)
+    assert len(ls_dir) == len(files)
+    for f in files:
+        assert f in ls_dir
+    # Clean up fixtures and extraction output.
+    os.remove(tar_path)
+    for f in files:
+        os.remove(valid_path() + f)
+    shutil.rmtree(extracted_path)
+
+
+# All supported open modes, for both str and list storage spellings.
+@pytest.mark.parametrize(
+    "storage, mode",
+    [
+        (str(uuid.uuid4()), "r"),
+        (str(uuid.uuid4()), "w"),
+        (str(uuid.uuid4()), "a"),
+        (str(uuid.uuid4()), "rb"),
+        (str(uuid.uuid4()), "wb"),
+        (str(uuid.uuid4()), "ab"),
+        ([str(uuid.uuid4())], "r"),
+        ([str(uuid.uuid4())], "w"),
+        ([str(uuid.uuid4())], "a"),
+        ([str(uuid.uuid4())], "rb"),
+        ([str(uuid.uuid4())], "wb"),
+        ([str(uuid.uuid4())], "ab"),
+    ],
+)
+def test_file_open(fs_local, storage, mode):
+    """file_open returns an open file object for an existing file."""
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    os.mknod(path)
+    file_obj = fs_local.file_open(storage, mode)
+    assert isinstance(file_obj, io.IOBase)
+    assert file_obj.closed is False
+    os.remove(path)
+
+
+# Only read modes here: a missing file can still be created by w/a modes.
+@pytest.mark.parametrize(
+    "storage, mode",
+    [
+        (str(uuid.uuid4()), "r"),
+        (str(uuid.uuid4()), "rb"),
+        ([str(uuid.uuid4())], "r"),
+        ([str(uuid.uuid4())], "rb"),
+    ],
+)
+def test_file_open_file_not_found_exception(fs_local, storage, mode):
+    with pytest.raises(FsException) as excinfo:
+        fs_local.file_open(storage, mode)
+    assert str(excinfo.value) == file_open_file_not_found_exception(storage)
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize(
+    "storage, mode",
+    [
+        (str(uuid.uuid4()), "r"),
+        (str(uuid.uuid4()), "w"),
+        (str(uuid.uuid4()), "a"),
+        (str(uuid.uuid4()), "rb"),
+        (str(uuid.uuid4()), "wb"),
+        (str(uuid.uuid4()), "ab"),
+        ([str(uuid.uuid4())], "r"),
+        ([str(uuid.uuid4())], "w"),
+        ([str(uuid.uuid4())], "a"),
+        ([str(uuid.uuid4())], "rb"),
+        ([str(uuid.uuid4())], "wb"),
+        ([str(uuid.uuid4())], "ab"),
+    ],
+)
+def test_file_open_io_error(fs_local, storage, mode):
+    """A file with all permissions stripped triggers HTTP 400.
+    NOTE(review): assumes the test does not run as root (root ignores mode 0)."""
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    os.mknod(path)
+    os.chmod(path, 0)
+    with pytest.raises(FsException) as excinfo:
+        fs_local.file_open(storage, mode)
+    assert str(excinfo.value) == file_open_io_exception(storage)
+    assert excinfo.value.http_code == http.HTTPStatus.BAD_REQUEST
+    os.remove(path)
+
+
+@pytest.mark.parametrize(
+    "storage, with_files",
+    [
+        (str(uuid.uuid4()), True),
+        (str(uuid.uuid4()), False),
+        ([str(uuid.uuid4())], True),
+        ([str(uuid.uuid4())], False),
+    ],
+)
+def test_dir_ls(fs_local, storage, with_files):
+    """dir_ls lists directory entries; empty dir yields an empty list."""
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    os.mkdir(path)
+    if with_files is True:
+        file_name = str(uuid.uuid4())
+        file_path = path + "/" + file_name
+        os.mknod(file_path)
+    result = fs_local.dir_ls(storage)
+
+    if with_files is True:
+        assert len(result) == 1
+        assert result[0] == file_name
+    else:
+        assert len(result) == 0
+    shutil.rmtree(path)
+
+
+@pytest.mark.parametrize("storage", [(str(uuid.uuid4())), ([str(uuid.uuid4())])])
+def test_dir_ls_with_not_a_directory_error(fs_local, storage):
+    """Listing a regular file raises "does not exist" with HTTP 404."""
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    os.mknod(path)
+    with pytest.raises(FsException) as excinfo:
+        fs_local.dir_ls(storage)
+    assert str(excinfo.value) == dir_ls_not_a_directory_exception(storage)
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+    os.remove(path)
+
+
+@pytest.mark.parametrize("storage", [(str(uuid.uuid4())), ([str(uuid.uuid4())])])
+def test_dir_ls_with_io_error(fs_local, storage):
+    """An unreadable directory (mode 0) raises HTTP 400.
+    NOTE(review): assumes the test does not run as root."""
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    os.mkdir(path)
+    os.chmod(path, 0)
+    with pytest.raises(FsException) as excinfo:
+        fs_local.dir_ls(storage)
+    assert str(excinfo.value) == dir_ls_io_exception(storage)
+    assert excinfo.value.http_code == http.HTTPStatus.BAD_REQUEST
+    os.rmdir(path)
+
+
+@pytest.mark.parametrize(
+    "storage, with_files, ignore_non_exist",
+    [
+        (str(uuid.uuid4()), True, True),
+        (str(uuid.uuid4()), False, True),
+        (str(uuid.uuid4()), True, False),
+        (str(uuid.uuid4()), False, False),
+        ([str(uuid.uuid4())], True, True),
+        ([str(uuid.uuid4())], False, True),
+        ([str(uuid.uuid4())], True, False),
+        ([str(uuid.uuid4())], False, False),
+    ],
+)
+def test_file_delete_with_dir(fs_local, storage, with_files, ignore_non_exist):
+    """file_delete removes an existing directory tree, files and all."""
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    os.mkdir(path)
+    if with_files is True:
+        file_path = path + "/" + str(uuid.uuid4())
+        os.mknod(file_path)
+    fs_local.file_delete(storage, ignore_non_exist)
+    assert os.path.exists(path) is False
+
+
+@pytest.mark.parametrize("storage", [(str(uuid.uuid4())), ([str(uuid.uuid4())])])
+def test_file_delete_expect_exception(fs_local, storage):
+    """Deleting a missing target without ignore_non_exist raises HTTP 404."""
+    with pytest.raises(FsException) as excinfo:
+        fs_local.file_delete(storage)
+    assert str(excinfo.value) == file_delete_exception_message(storage)
+    assert excinfo.value.http_code == http.HTTPStatus.NOT_FOUND
+
+
+@pytest.mark.parametrize("storage", [(str(uuid.uuid4())), ([str(uuid.uuid4())])])
+def test_file_delete_no_exception(fs_local, storage):
+    """ignore_non_exist=True makes deleting a missing target a no-op."""
+    path = (
+        valid_path() + storage
+        if isinstance(storage, str)
+        else valid_path() + storage[0]
+    )
+    fs_local.file_delete(storage, ignore_non_exist=True)
+    assert os.path.exists(path) is False
--- /dev/null
+# Copyright 2019 Canonical
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: eduardo.sousa@canonical.com
+##
+
+from io import BytesIO
+import logging
+import os
+from pathlib import Path
+import subprocess
+import tarfile
+import tempfile
+from unittest.mock import Mock
+
+from gridfs import GridFSBucket
+from osm_common.fsbase import FsException
+from osm_common.fsmongo import FsMongo
+from pymongo import MongoClient
+import pytest
+
+__author__ = "Eduardo Sousa <eduardo.sousa@canonical.com>"
+
+
+def valid_path():
+    # Writable scratch directory, always with a trailing slash.
+    return tempfile.gettempdir() + "/"
+
+
+def invalid_path():
+    # A path that is never created by the tests.
+    return "/#tweeter/"
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def fs_mongo(request, monkeypatch):
+    # FsMongo instance (exercised with and without locking) whose Mongo and
+    # GridFS constructors are stubbed out so no real database is required.
+    def mock_mongoclient_constructor(a, b):
+        pass
+
+    def mock_mongoclient_getitem(a, b):
+        pass
+
+    def mock_gridfs_constructor(a, b):
+        pass
+
+    monkeypatch.setattr(MongoClient, "__init__", mock_mongoclient_constructor)
+    monkeypatch.setattr(MongoClient, "__getitem__", mock_mongoclient_getitem)
+    monkeypatch.setattr(GridFSBucket, "__init__", mock_gridfs_constructor)
+    fs = FsMongo(lock=request.param)
+    fs.fs_connect({"path": valid_path(), "uri": "mongo:27017", "collection": "files"})
+    return fs
+
+
+# Expected exception-message builders; wording must mirror FsBase/FsMongo.
+def generic_fs_exception_message(message):
+    return "storage exception {}".format(message)
+
+
+def fs_connect_exception_message(path):
+    return "storage exception Invalid configuration param at '[storage]': path '{}' does not exist".format(
+        path
+    )
+
+
+def file_open_file_not_found_exception(storage):
+    # storage may be a plain string or a list of path components.
+    f = storage if isinstance(storage, str) else "/".join(storage)
+    return "storage exception File {} does not exist".format(f)
+
+
+def file_open_io_exception(storage):
+    f = storage if isinstance(storage, str) else "/".join(storage)
+    return "storage exception File {} cannot be opened".format(f)
+
+
+def dir_ls_not_a_directory_exception(storage):
+    f = storage if isinstance(storage, str) else "/".join(storage)
+    return "storage exception File {} does not exist".format(f)
+
+
+def dir_ls_io_exception(storage):
+    f = storage if isinstance(storage, str) else "/".join(storage)
+    return "storage exception File {} cannot be opened".format(f)
+
+
+def file_delete_exception_message(storage):
+    return "storage exception File {} does not exist".format(storage)
+
+
+def test_constructor_without_logger():
+    # Default logger name is "fs"; no connection state before fs_connect().
+    fs = FsMongo()
+    assert fs.logger == logging.getLogger("fs")
+    assert fs.path is None
+    assert fs.client is None
+    assert fs.fs is None
+
+
+def test_constructor_with_logger():
+    # A custom logger_name is honoured.
+    logger_name = "fs_mongo"
+    fs = FsMongo(logger_name=logger_name)
+    assert fs.logger == logging.getLogger(logger_name)
+    assert fs.path is None
+    assert fs.client is None
+    assert fs.fs is None
+
+
+def test_get_params(fs_mongo, monkeypatch):
+    # get_params() reports exactly the backend kind and the configured path.
+    def mock_gridfs_find(self, search_query, **kwargs):
+        return []
+
+    monkeypatch.setattr(GridFSBucket, "find", mock_gridfs_find)
+    params = fs_mongo.get_params()
+    assert len(params) == 2
+    assert "fs" in params
+    assert "path" in params
+    assert params["fs"] == "mongo"
+    assert params["path"] == valid_path()
+
+
+@pytest.mark.parametrize(
+    "config, exp_logger, exp_path",
+    [
+        (
+            {
+                "logger_name": "fs_mongo",
+                "path": valid_path(),
+                "uri": "mongo:27017",
+                "collection": "files",
+            },
+            "fs_mongo",
+            valid_path(),
+        ),
+        (
+            # Path without trailing slash must be normalised to end with one.
+            {
+                "logger_name": "fs_mongo",
+                "path": valid_path()[:-1],
+                "uri": "mongo:27017",
+                "collection": "files",
+            },
+            "fs_mongo",
+            valid_path(),
+        ),
+        (
+            {"path": valid_path(), "uri": "mongo:27017", "collection": "files"},
+            "fs",
+            valid_path(),
+        ),
+        (
+            {"path": valid_path()[:-1], "uri": "mongo:27017", "collection": "files"},
+            "fs",
+            valid_path(),
+        ),
+    ],
+)
+def test_fs_connect_with_valid_config(config, exp_logger, exp_path):
+    # Successful connect sets logger, normalised path and backend handles.
+    fs = FsMongo()
+    fs.fs_connect(config)
+    assert fs.logger == logging.getLogger(exp_logger)
+    assert fs.path == exp_path
+    # NOTE(review): `type(x) == Cls` rejects subclasses; isinstance() would be
+    # the idiomatic check — intentional here to pin the exact types?
+    assert type(fs.client) == MongoClient
+    assert type(fs.fs) == GridFSBucket
+
+
+@pytest.mark.parametrize(
+    "config, exp_exception_message",
+    [
+        (
+            {
+                "logger_name": "fs_mongo",
+                "path": invalid_path(),
+                "uri": "mongo:27017",
+                "collection": "files",
+            },
+            fs_connect_exception_message(invalid_path()),
+        ),
+        (
+            {
+                "logger_name": "fs_mongo",
+                "path": invalid_path()[:-1],
+                "uri": "mongo:27017",
+                "collection": "files",
+            },
+            fs_connect_exception_message(invalid_path()[:-1]),
+        ),
+        (
+            {"path": invalid_path(), "uri": "mongo:27017", "collection": "files"},
+            fs_connect_exception_message(invalid_path()),
+        ),
+        (
+            {"path": invalid_path()[:-1], "uri": "mongo:27017", "collection": "files"},
+            fs_connect_exception_message(invalid_path()[:-1]),
+        ),
+        (
+            # "/" exists but is not writable for the test user.
+            {"path": "/", "uri": "mongo:27017", "collection": "files"},
+            generic_fs_exception_message(
+                "Invalid configuration param at '[storage]': path '/' is not writable"
+            ),
+        ),
+    ],
+)
+def test_fs_connect_with_invalid_path(config, exp_exception_message):
+    fs = FsMongo()
+    with pytest.raises(FsException) as excinfo:
+        fs.fs_connect(config)
+    assert str(excinfo.value) == exp_exception_message
+
+
+@pytest.mark.parametrize(
+    "config, exp_exception_message",
+    [
+        (
+            {"logger_name": "fs_mongo", "uri": "mongo:27017", "collection": "files"},
+            'Missing parameter "path"',
+        ),
+        (
+            {"logger_name": "fs_mongo", "path": valid_path(), "collection": "files"},
+            'Missing parameters: "uri"',
+        ),
+        (
+            {"logger_name": "fs_mongo", "path": valid_path(), "uri": "mongo:27017"},
+            'Missing parameter "collection"',
+        ),
+    ],
+)
+def test_fs_connect_with_missing_parameters(config, exp_exception_message):
+    # Each mandatory config key is reported individually when absent.
+    fs = FsMongo()
+    with pytest.raises(FsException) as excinfo:
+        fs.fs_connect(config)
+    assert str(excinfo.value) == generic_fs_exception_message(exp_exception_message)
+
+
+@pytest.mark.parametrize(
+    "config, exp_exception_message",
+    [
+        (
+            {
+                "logger_name": "fs_mongo",
+                "path": valid_path(),
+                "uri": "mongo:27017",
+                "collection": "files",
+            },
+            "MongoClient crashed",
+        ),
+    ],
+)
+def test_fs_connect_with_invalid_mongoclient(
+    config, exp_exception_message, monkeypatch
+):
+    # An exception raised by MongoClient.__init__ is wrapped in FsException.
+    def generate_exception(a, b=None):
+        raise Exception(exp_exception_message)
+
+    monkeypatch.setattr(MongoClient, "__init__", generate_exception)
+
+    fs = FsMongo()
+    with pytest.raises(FsException) as excinfo:
+        fs.fs_connect(config)
+    assert str(excinfo.value) == generic_fs_exception_message(exp_exception_message)
+
+
+@pytest.mark.parametrize(
+    "config, exp_exception_message",
+    [
+        (
+            {
+                "logger_name": "fs_mongo",
+                "path": valid_path(),
+                "uri": "mongo:27017",
+                "collection": "files",
+            },
+            "Collection unavailable",
+        ),
+    ],
+)
+def test_fs_connect_with_invalid_mongo_collection(
+    config, exp_exception_message, monkeypatch
+):
+    # An exception raised while selecting the database/collection
+    # (MongoClient.__getitem__) is wrapped in FsException.
+    def mock_mongoclient_constructor(a, b=None):
+        pass
+
+    def generate_exception(a, b):
+        raise Exception(exp_exception_message)
+
+    monkeypatch.setattr(MongoClient, "__init__", mock_mongoclient_constructor)
+    monkeypatch.setattr(MongoClient, "__getitem__", generate_exception)
+
+    fs = FsMongo()
+    with pytest.raises(FsException) as excinfo:
+        fs.fs_connect(config)
+    assert str(excinfo.value) == generic_fs_exception_message(exp_exception_message)
+
+
+@pytest.mark.parametrize(
+    "config, exp_exception_message",
+    [
+        (
+            {
+                "logger_name": "fs_mongo",
+                "path": valid_path(),
+                "uri": "mongo:27017",
+                "collection": "files",
+            },
+            "GridFsBucket crashed",
+        ),
+    ],
+)
+def test_fs_connect_with_invalid_gridfsbucket(
+    config, exp_exception_message, monkeypatch
+):
+    # An exception raised by GridFSBucket.__init__ is wrapped in FsException.
+    def mock_mongoclient_constructor(a, b=None):
+        pass
+
+    def mock_mongoclient_getitem(a, b):
+        pass
+
+    def generate_exception(a, b):
+        raise Exception(exp_exception_message)
+
+    monkeypatch.setattr(MongoClient, "__init__", mock_mongoclient_constructor)
+    monkeypatch.setattr(MongoClient, "__getitem__", mock_mongoclient_getitem)
+    monkeypatch.setattr(GridFSBucket, "__init__", generate_exception)
+
+    fs = FsMongo()
+    with pytest.raises(FsException) as excinfo:
+        fs.fs_connect(config)
+    assert str(excinfo.value) == generic_fs_exception_message(exp_exception_message)
+
+
+def test_fs_disconnect(fs_mongo):
+    # Must be a no-op safe to call on a mocked connection.
+    fs_mongo.fs_disconnect()
+
+
+# Example.tar.gz
+# example_tar/
+# ├── directory
+# │   └── file
+# └── symlinks
+#     ├── directory_link -> ../directory/
+#     └── file_link -> ../directory/file
+class FakeCursor:
+    # Minimal stand-in for a GridFS find() cursor entry.
+    def __init__(self, id, filename, metadata):
+        self._id = id
+        self.filename = filename
+        self.metadata = metadata
+
+
+class FakeFS:
+    # In-memory stand-in for a GridFSBucket, pre-loaded with the entries of
+    # the tarball sketched above (dirs, a file and two symlinks).
+    directory_metadata = {"type": "dir", "permissions": 509}
+    file_metadata = {"type": "file", "permissions": 436}
+    symlink_metadata = {"type": "sym", "permissions": 511}
+
+    # Keyed by fake GridFS id; "stream_content_bad" is a deliberately wrong
+    # payload used to prove the inequality assertion below is meaningful.
+    tar_info = {
+        1: {
+            "cursor": FakeCursor(1, "example_tar", directory_metadata),
+            "metadata": directory_metadata,
+            "stream_content": b"",
+            "stream_content_bad": b"Something",
+            "path": "./tmp/example_tar",
+        },
+        2: {
+            "cursor": FakeCursor(2, "example_tar/directory", directory_metadata),
+            "metadata": directory_metadata,
+            "stream_content": b"",
+            "stream_content_bad": b"Something",
+            "path": "./tmp/example_tar/directory",
+        },
+        3: {
+            "cursor": FakeCursor(3, "example_tar/symlinks", directory_metadata),
+            "metadata": directory_metadata,
+            "stream_content": b"",
+            "stream_content_bad": b"Something",
+            "path": "./tmp/example_tar/symlinks",
+        },
+        4: {
+            "cursor": FakeCursor(4, "example_tar/directory/file", file_metadata),
+            "metadata": file_metadata,
+            "stream_content": b"Example test",
+            "stream_content_bad": b"Example test2",
+            "path": "./tmp/example_tar/directory/file",
+        },
+        5: {
+            "cursor": FakeCursor(5, "example_tar/symlinks/file_link", symlink_metadata),
+            "metadata": symlink_metadata,
+            "stream_content": b"../directory/file",
+            "stream_content_bad": b"",
+            "path": "./tmp/example_tar/symlinks/file_link",
+        },
+        6: {
+            "cursor": FakeCursor(
+                6, "example_tar/symlinks/directory_link", symlink_metadata
+            ),
+            "metadata": symlink_metadata,
+            "stream_content": b"../directory/",
+            "stream_content_bad": b"",
+            "path": "./tmp/example_tar/symlinks/directory_link",
+        },
+    }
+
+    def upload_from_stream(self, f, stream, metadata=None):
+        # Assert that the uploaded path/metadata/content match one of the
+        # expected tar_info entries; fails if the path is unknown.
+        # NOTE(review): `continue` after a match keeps scanning the remaining
+        # entries — a `break` would express the intent more directly.
+        found = False
+        for i, v in self.tar_info.items():
+            if f == v["path"]:
+                assert metadata["type"] == v["metadata"]["type"]
+                assert stream.read() == BytesIO(v["stream_content"]).read()
+                stream.seek(0)
+                assert stream.read() != BytesIO(v["stream_content_bad"]).read()
+                found = True
+                continue
+        assert found
+
+    def find(self, type, no_cursor_timeout=True, sort=None):
+        # Return cursors for directories when the query asks for
+        # metadata.type == "dir", otherwise everything that is not a dir.
+        # NOTE(review): parameter `type` and local `list` shadow builtins;
+        # kept as-is to preserve the patch, but worth renaming upstream.
+        list = []
+        for i, v in self.tar_info.items():
+            if type["metadata.type"] == "dir":
+                if v["metadata"] == self.directory_metadata:
+                    list.append(v["cursor"])
+            else:
+                if v["metadata"] != self.directory_metadata:
+                    list.append(v["cursor"])
+        return list
+
+    def download_to_stream(self, id, file_stream):
+        # Copy the canned content for the given fake id into the stream.
+        file_stream.write(BytesIO(self.tar_info[id]["stream_content"]).read())
+
+
+def test_file_extract():
+    # Build a real tarball on disk (dirs, file, symlinks), then extract it
+    # through FsMongo.file_extract with FakeFS verifying every upload.
+    tar_path = "tmp/Example.tar.gz"
+    folder_path = "tmp/example_tar"
+
+    # Generate package
+    subprocess.call(["rm", "-rf", "./tmp"])
+    subprocess.call(["mkdir", "-p", "{}/directory".format(folder_path)])
+    subprocess.call(["mkdir", "-p", "{}/symlinks".format(folder_path)])
+    p = Path("{}/directory/file".format(folder_path))
+    p.write_text("Example test")
+    os.symlink("../directory/file", "{}/symlinks/file_link".format(folder_path))
+    os.symlink("../directory/", "{}/symlinks/directory_link".format(folder_path))
+    if os.path.exists(tar_path):
+        os.remove(tar_path)
+    subprocess.call(["tar", "-czvf", tar_path, folder_path])
+
+    try:
+        tar = tarfile.open(tar_path, "r")
+        fs = FsMongo()
+        fs.fs = FakeFS()
+        fs.file_extract(compressed_object=tar, path=".")
+    finally:
+        # Always clean the scratch tree, even when extraction fails.
+        os.remove(tar_path)
+        subprocess.call(["rm", "-rf", "./tmp"])
+
+
+def test_upload_local_fs():
+    # sync() must materialise the FakeFS entries on the local filesystem,
+    # preserving dirs, files and symlinks.
+    path = "./tmp/"
+
+    subprocess.call(["rm", "-rf", path])
+    try:
+        fs = FsMongo()
+        fs.path = path
+        fs.fs = FakeFS()
+        fs.sync()
+        assert os.path.isdir("{}example_tar".format(path))
+        assert os.path.isdir("{}example_tar/directory".format(path))
+        assert os.path.isdir("{}example_tar/symlinks".format(path))
+        assert os.path.isfile("{}example_tar/directory/file".format(path))
+        assert os.path.islink("{}example_tar/symlinks/file_link".format(path))
+        assert os.path.islink("{}example_tar/symlinks/directory_link".format(path))
+    finally:
+        subprocess.call(["rm", "-rf", path])
+
+
+def test_upload_mongo_fs():
+    # reverse_sync() must upload one stream per local entry: first the
+    # directory (metadata.type == "dir"), then the file it contains.
+    path = "./tmp/"
+
+    subprocess.call(["rm", "-rf", path])
+    try:
+        fs = FsMongo()
+        fs.path = path
+        fs.fs = Mock()
+        fs.fs.find.return_value = {}
+
+        file_content = "Test file content"
+
+        # Create local dir and upload content to fakefs
+        os.mkdir(path)
+        os.mkdir("{}example_local".format(path))
+        os.mkdir("{}example_local/directory".format(path))
+        with open(
+            "{}example_local/directory/test_file".format(path), "w+"
+        ) as test_file:
+            test_file.write(file_content)
+        fs.reverse_sync("example_local")
+
+        assert fs.fs.upload_from_stream.call_count == 2
+
+        # first call to upload_from_stream, dir_name
+        dir_name = "example_local/directory"
+        call_args_0 = fs.fs.upload_from_stream.call_args_list[0]
+        assert call_args_0[0][0] == dir_name
+        assert call_args_0[1].get("metadata").get("type") == "dir"
+
+        # second call to upload_from_stream, dir_name
+        file_name = "example_local/directory/test_file"
+        call_args_1 = fs.fs.upload_from_stream.call_args_list[1]
+        assert call_args_1[0][0] == file_name
+        assert call_args_1[1].get("metadata").get("type") == "file"
+
+    finally:
+        subprocess.call(["rm", "-rf", path])
+        pass
--- /dev/null
+# Copyright 2018 Whitestack, LLC
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
+##
+import asyncio
+import http
+
+from osm_common.msgbase import MsgBase, MsgException
+import pytest
+
+
+def exception_message(message):
+    # MsgException prefixes every message with "messaging exception ".
+    return "messaging exception " + message
+
+
+@pytest.fixture
+def msg_base():
+    # Fresh abstract base instance per test.
+    return MsgBase()
+
+
+def test_constructor():
+    msgbase = MsgBase()
+    assert msgbase is not None
+    assert isinstance(msgbase, MsgBase)
+
+
+def test_connect(msg_base):
+    # Base-class connect/disconnect are no-ops and must not raise.
+    msg_base.connect(None)
+
+
+def test_disconnect(msg_base):
+    msg_base.disconnect()
+
+
+# Each abstract operation must raise MsgException("Method '<name>' not
+# implemented") with http_code INTERNAL_SERVER_ERROR.
+def test_write(msg_base):
+    with pytest.raises(MsgException) as excinfo:
+        msg_base.write("test", "test", "test")
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'write' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_read(msg_base):
+    with pytest.raises(MsgException) as excinfo:
+        msg_base.read("test")
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'read' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_aiowrite(msg_base):
+    with pytest.raises(MsgException) as excinfo:
+        asyncio.run(msg_base.aiowrite("test", "test", "test"))
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'aiowrite' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_aioread(msg_base):
+    with pytest.raises(MsgException) as excinfo:
+        asyncio.run(msg_base.aioread("test"))
+    assert str(excinfo.value).startswith(
+        exception_message("Method 'aioread' not implemented")
+    )
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
--- /dev/null
+# Copyright 2018 Whitestack, LLC
+# Copyright 2018 Telefonica S.A.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# For those usages not covered by the Apache License, Version 2.0 please
+# contact: esousa@whitestack.com or alfonso.tiernosepulveda@telefonica.com
+##
+import asyncio
+import http
+import logging
+import os
+import shutil
+import tempfile
+import threading
+import time
+from unittest.mock import MagicMock
+import uuid
+
+from osm_common.msgbase import MsgException
+from osm_common.msglocal import MsgLocal
+import pytest
+import yaml
+
+__author__ = "Eduardo Sousa <eduardosousa@av.it.pt>"
+
+
+def valid_path():
+    # Writable scratch directory, always with a trailing slash.
+    return tempfile.gettempdir() + "/"
+
+
+def invalid_path():
+    # A path that is never created by the tests.
+    return "/#tweeter/"
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def msg_local(request):
+    # Unconnected MsgLocal (with and without locking); cleans up any
+    # per-test directory a test may have created via connect().
+    msg = MsgLocal(lock=request.param)
+    yield msg
+
+    msg.disconnect()
+    if msg.path and msg.path != invalid_path() and msg.path != valid_path():
+        shutil.rmtree(msg.path)
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def msg_local_config(request):
+    # Connected MsgLocal rooted at a unique temp directory.
+    msg = MsgLocal(lock=request.param)
+    msg.connect({"path": valid_path() + str(uuid.uuid4())})
+    yield msg
+
+    msg.disconnect()
+    if msg.path != invalid_path():
+        shutil.rmtree(msg.path)
+
+
+@pytest.fixture(scope="function", params=[True, False])
+def msg_local_with_data(request):
+    # Connected MsgLocal pre-populated with two messages per key on
+    # topic1/topic2 so read paths have existing files to consume.
+    msg = MsgLocal(lock=request.param)
+    msg.connect({"path": valid_path() + str(uuid.uuid4())})
+
+    msg.write("topic1", "key1", "msg1")
+    msg.write("topic1", "key2", "msg1")
+    msg.write("topic2", "key1", "msg1")
+    msg.write("topic2", "key2", "msg1")
+    msg.write("topic1", "key1", "msg2")
+    msg.write("topic1", "key2", "msg2")
+    msg.write("topic2", "key1", "msg2")
+    msg.write("topic2", "key2", "msg2")
+    yield msg
+
+    msg.disconnect()
+    if msg.path != invalid_path():
+        shutil.rmtree(msg.path)
+
+
+def empty_exception_message():
+    # Common prefix of every MsgException string.
+    return "messaging exception "
+
+
+def test_constructor():
+    # Defaults: "msg" logger, no path, empty file/buffer maps.
+    msg = MsgLocal()
+    assert msg.logger == logging.getLogger("msg")
+    assert msg.path is None
+    assert len(msg.files_read) == 0
+    assert len(msg.files_write) == 0
+    assert len(msg.buffer) == 0
+
+
+def test_constructor_with_logger():
+    logger_name = "msg_local"
+    msg = MsgLocal(logger_name=logger_name)
+    assert msg.logger == logging.getLogger(logger_name)
+    assert msg.path is None
+    assert len(msg.files_read) == 0
+    assert len(msg.files_write) == 0
+    assert len(msg.buffer) == 0
+
+
+@pytest.mark.parametrize(
+    "config, logger_name, path",
+    [
+        ({"logger_name": "msg_local", "path": valid_path()}, "msg_local", valid_path()),
+        (
+            # Path without trailing slash is normalised to end with one.
+            {"logger_name": "msg_local", "path": valid_path()[:-1]},
+            "msg_local",
+            valid_path(),
+        ),
+        (
+            {"logger_name": "msg_local", "path": valid_path() + "test_it/"},
+            "msg_local",
+            valid_path() + "test_it/",
+        ),
+        (
+            {"logger_name": "msg_local", "path": valid_path() + "test_it"},
+            "msg_local",
+            valid_path() + "test_it/",
+        ),
+        ({"path": valid_path()}, "msg", valid_path()),
+        ({"path": valid_path()[:-1]}, "msg", valid_path()),
+        ({"path": valid_path() + "test_it/"}, "msg", valid_path() + "test_it/"),
+        ({"path": valid_path() + "test_it"}, "msg", valid_path() + "test_it/"),
+    ],
+)
+def test_connect(msg_local, config, logger_name, path):
+    # connect() sets the logger and normalised path and starts clean.
+    msg_local.connect(config)
+    assert msg_local.logger == logging.getLogger(logger_name)
+    assert msg_local.path == path
+    assert len(msg_local.files_read) == 0
+    assert len(msg_local.files_write) == 0
+    assert len(msg_local.buffer) == 0
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        ({"logger_name": "msg_local", "path": invalid_path()}),
+        ({"path": invalid_path()}),
+    ],
+)
+def test_connect_with_exception(msg_local, config):
+    # An uncreatable path is wrapped in MsgException / INTERNAL_SERVER_ERROR.
+    with pytest.raises(MsgException) as excinfo:
+        msg_local.connect(config)
+    assert str(excinfo.value).startswith(empty_exception_message())
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+# disconnect() must close every file handle opened for reading or writing,
+# regardless of which operations were performed beforehand.
+def test_disconnect(msg_local_config):
+    files_read = msg_local_config.files_read.copy()
+    files_write = msg_local_config.files_write.copy()
+    msg_local_config.disconnect()
+    for f in files_read.values():
+        assert f.closed
+    for f in files_write.values():
+        assert f.closed
+
+
+def test_disconnect_with_read(msg_local_config):
+    # Open read handles first so there is something to close.
+    msg_local_config.read("topic1", blocks=False)
+    msg_local_config.read("topic2", blocks=False)
+    files_read = msg_local_config.files_read.copy()
+    files_write = msg_local_config.files_write.copy()
+    msg_local_config.disconnect()
+    for f in files_read.values():
+        assert f.closed
+    for f in files_write.values():
+        assert f.closed
+
+
+def test_disconnect_with_write(msg_local_with_data):
+    files_read = msg_local_with_data.files_read.copy()
+    files_write = msg_local_with_data.files_write.copy()
+    msg_local_with_data.disconnect()
+
+    for f in files_read.values():
+        assert f.closed
+
+    for f in files_write.values():
+        assert f.closed
+
+
+def test_disconnect_with_read_and_write(msg_local_with_data):
+    msg_local_with_data.read("topic1", blocks=False)
+    msg_local_with_data.read("topic2", blocks=False)
+    files_read = msg_local_with_data.files_read.copy()
+    files_write = msg_local_with_data.files_write.copy()
+
+    msg_local_with_data.disconnect()
+    for f in files_read.values():
+        assert f.closed
+    for f in files_write.values():
+        assert f.closed
+
+
+@pytest.mark.parametrize(
+    "topic, key, msg",
+    [
+        ("test_topic", "test_key", "test_msg"),
+        ("test", "test_key", "test_msg"),
+        ("test_topic", "test", "test_msg"),
+        ("test_topic", "test_key", "test"),
+        ("test_topic", "test_list", ["a", "b", "c"]),
+        ("test_topic", "test_tuple", ("c", "b", "a")),
+        ("test_topic", "test_dict", {"a": 1, "b": 2, "c": 3}),
+        ("test_topic", "test_number", 123),
+        ("test_topic", "test_float", 1.23),
+        ("test_topic", "test_boolean", True),
+        ("test_topic", "test_none", None),
+    ],
+)
+def test_write(msg_local_config, topic, key, msg):
+    # write() appends one YAML document {key: msg} to the topic file.
+    # Tuples round-trip through YAML as lists, hence the list(msg) below.
+    file_path = msg_local_config.path + topic
+    msg_local_config.write(topic, key, msg)
+    assert os.path.exists(file_path)
+
+    with open(file_path, "r") as stream:
+        assert yaml.safe_load(stream) == {
+            key: msg if not isinstance(msg, tuple) else list(msg)
+        }
+
+
+@pytest.mark.parametrize(
+    "topic, key, msg, times",
+    [
+        ("test_topic", "test_key", "test_msg", 2),
+        ("test", "test_key", "test_msg", 3),
+        ("test_topic", "test", "test_msg", 4),
+        ("test_topic", "test_key", "test", 2),
+        ("test_topic", "test_list", ["a", "b", "c"], 3),
+        ("test_topic", "test_tuple", ("c", "b", "a"), 4),
+        ("test_topic", "test_dict", {"a": 1, "b": 2, "c": 3}, 2),
+        ("test_topic", "test_number", 123, 3),
+        ("test_topic", "test_float", 1.23, 4),
+        ("test_topic", "test_boolean", True, 2),
+        ("test_topic", "test_none", None, 3),
+    ],
+)
+def test_write_with_multiple_calls(msg_local_config, topic, key, msg, times):
+    # Repeated writes append: one single-line YAML document per call.
+    file_path = msg_local_config.path + topic
+
+    for _ in range(times):
+        msg_local_config.write(topic, key, msg)
+    assert os.path.exists(file_path)
+
+    with open(file_path, "r") as stream:
+        for _ in range(times):
+            data = stream.readline()
+            assert yaml.safe_load(data) == {
+                key: msg if not isinstance(msg, tuple) else list(msg)
+            }
+
+
+def test_write_exception(msg_local_config):
+    # Any internal failure during write() must surface as MsgException.
+    msg_local_config.files_write = MagicMock()
+    msg_local_config.files_write.__contains__.side_effect = Exception()
+
+    with pytest.raises(MsgException) as excinfo:
+        msg_local_config.write("test", "test", "test")
+    assert str(excinfo.value).startswith(empty_exception_message())
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+@pytest.mark.parametrize(
+    "topics, datas",
+    [
+        (["topic"], [{"key": "value"}]),
+        (["topic1"], [{"key": "value"}]),
+        (["topic2"], [{"key": "value"}]),
+        (["topic", "topic1"], [{"key": "value"}]),
+        (["topic", "topic2"], [{"key": "value"}]),
+        (["topic1", "topic2"], [{"key": "value"}]),
+        (["topic", "topic1", "topic2"], [{"key": "value"}]),
+        (["topic"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic1"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic1"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+    ],
+)
+def test_read(msg_local_with_data, topics, datas):
+    # Blocking read: a background thread appends YAML documents after a
+    # delay and read() must deliver each (topic, key, value) in order.
+    def write_to_topic(topics, datas):
+        # Allow msglocal to block while waiting
+        time.sleep(2)
+        for topic in topics:
+            for data in datas:
+                with open(msg_local_with_data.path + topic, "a+") as fp:
+                    yaml.safe_dump(data, fp, default_flow_style=True, width=20000)
+                    fp.flush()
+
+    # If file is not opened first, the messages written won't be seen
+    for topic in topics:
+        if topic not in msg_local_with_data.files_read:
+            msg_local_with_data.read(topic, blocks=False)
+
+    t = threading.Thread(target=write_to_topic, args=(topics, datas))
+    t.start()
+
+    for topic in topics:
+        for data in datas:
+            recv_topic, recv_key, recv_msg = msg_local_with_data.read(topic)
+            key = list(data.keys())[0]
+            val = data[key]
+            assert recv_topic == topic
+            assert recv_key == key
+            assert recv_msg == val
+    t.join()
+
+
+@pytest.mark.parametrize(
+    "topics, datas",
+    [
+        (["topic"], [{"key": "value"}]),
+        (["topic1"], [{"key": "value"}]),
+        (["topic2"], [{"key": "value"}]),
+        (["topic", "topic1"], [{"key": "value"}]),
+        (["topic", "topic2"], [{"key": "value"}]),
+        (["topic1", "topic2"], [{"key": "value"}]),
+        (["topic", "topic1", "topic2"], [{"key": "value"}]),
+        (["topic"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic1"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic1"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+    ],
+)
+def test_read_non_block(msg_local_with_data, topics, datas):
+    # Non-blocking read: data is fully written (t.join()) before reading,
+    # so every read(blocks=False) call must return a message.
+    def write_to_topic(topics, datas):
+        for topic in topics:
+            for data in datas:
+                with open(msg_local_with_data.path + topic, "a+") as fp:
+                    yaml.safe_dump(data, fp, default_flow_style=True, width=20000)
+                    fp.flush()
+
+    # If file is not opened first, the messages written won't be seen
+    for topic in topics:
+        if topic not in msg_local_with_data.files_read:
+            msg_local_with_data.read(topic, blocks=False)
+
+    t = threading.Thread(target=write_to_topic, args=(topics, datas))
+    t.start()
+    t.join()
+
+    for topic in topics:
+        for data in datas:
+            recv_topic, recv_key, recv_msg = msg_local_with_data.read(
+                topic, blocks=False
+            )
+            key = list(data.keys())[0]
+            val = data[key]
+            assert recv_topic == topic
+            assert recv_key == key
+            assert recv_msg == val
+
+
+@pytest.mark.parametrize(
+    "topics, datas",
+    [
+        (["topic"], [{"key": "value"}]),
+        (["topic1"], [{"key": "value"}]),
+        (["topic2"], [{"key": "value"}]),
+        (["topic", "topic1"], [{"key": "value"}]),
+        (["topic", "topic2"], [{"key": "value"}]),
+        (["topic1", "topic2"], [{"key": "value"}]),
+        (["topic", "topic1", "topic2"], [{"key": "value"}]),
+        (["topic"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic1"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic1"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+    ],
+)
+def test_read_non_block_none(msg_local_with_data, topics, datas):
+    # The writer thread sleeps 2s before producing, so an immediate
+    # non-blocking read must return None (nothing available yet).
+    def write_to_topic(topics, datas):
+        time.sleep(2)
+        for topic in topics:
+            for data in datas:
+                with open(msg_local_with_data.path + topic, "a+") as fp:
+                    yaml.safe_dump(data, fp, default_flow_style=True, width=20000)
+                    fp.flush()
+
+    # If file is not opened first, the messages written won't be seen
+    for topic in topics:
+        if topic not in msg_local_with_data.files_read:
+            msg_local_with_data.read(topic, blocks=False)
+    t = threading.Thread(target=write_to_topic, args=(topics, datas))
+    t.start()
+
+    for topic in topics:
+        recv_data = msg_local_with_data.read(topic, blocks=False)
+        assert recv_data is None
+    t.join()
+
+
+@pytest.mark.parametrize("blocks", [(True), (False)])
+def test_read_exception(msg_local_with_data, blocks):
+    # Internal failures during read() must surface as MsgException.
+    msg_local_with_data.files_read = MagicMock()
+    msg_local_with_data.files_read.__contains__.side_effect = Exception()
+
+    with pytest.raises(MsgException) as excinfo:
+        msg_local_with_data.read("topic1", blocks=blocks)
+    assert str(excinfo.value).startswith(empty_exception_message())
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+@pytest.mark.parametrize(
+    "topics, datas",
+    [
+        (["topic"], [{"key": "value"}]),
+        (["topic1"], [{"key": "value"}]),
+        (["topic2"], [{"key": "value"}]),
+        (["topic", "topic1"], [{"key": "value"}]),
+        (["topic", "topic2"], [{"key": "value"}]),
+        (["topic1", "topic2"], [{"key": "value"}]),
+        (["topic", "topic1", "topic2"], [{"key": "value"}]),
+        (["topic"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic1"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic1"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+        (["topic", "topic1", "topic2"], [{"key": "value"}, {"key1": "value1"}]),
+    ],
+)
+def test_aioread(msg_local_with_data, topics, datas):
+    # Async counterpart of test_read: aioread() must await and deliver each
+    # (topic, key, value) produced by the delayed writer thread.
+    def write_to_topic(topics, datas):
+        time.sleep(2)
+        for topic in topics:
+            for data in datas:
+                with open(msg_local_with_data.path + topic, "a+") as fp:
+                    yaml.safe_dump(data, fp, default_flow_style=True, width=20000)
+                    fp.flush()
+
+    # If file is not opened first, the messages written won't be seen
+    for topic in topics:
+        if topic not in msg_local_with_data.files_read:
+            msg_local_with_data.read(topic, blocks=False)
+
+    t = threading.Thread(target=write_to_topic, args=(topics, datas))
+    t.start()
+    for topic in topics:
+        for data in datas:
+            recv = asyncio.run(msg_local_with_data.aioread(topic))
+            recv_topic, recv_key, recv_msg = recv
+            key = list(data.keys())[0]
+            val = data[key]
+            assert recv_topic == topic
+            assert recv_key == key
+            assert recv_msg == val
+    t.join()
+
+
+def test_aioread_exception(msg_local_with_data):
+    # Failure inside the files_read bookkeeping surfaces as MsgException.
+    msg_local_with_data.files_read = MagicMock()
+    msg_local_with_data.files_read.__contains__.side_effect = Exception()
+
+    with pytest.raises(MsgException) as excinfo:
+        asyncio.run(msg_local_with_data.aioread("topic1"))
+    assert str(excinfo.value).startswith(empty_exception_message())
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+def test_aioread_general_exception(msg_local_with_data):
+    # Failure inside the underlying read() is also wrapped.
+    msg_local_with_data.read = MagicMock()
+    msg_local_with_data.read.side_effect = Exception()
+
+    with pytest.raises(MsgException) as excinfo:
+        asyncio.run(msg_local_with_data.aioread("topic1"))
+    assert str(excinfo.value).startswith(empty_exception_message())
+    assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
+
+
+@pytest.mark.parametrize(
+    "topic, key, msg",
+    [
+        ("test_topic", "test_key", "test_msg"),
+        ("test", "test_key", "test_msg"),
+        ("test_topic", "test", "test_msg"),
+        ("test_topic", "test_key", "test"),
+        ("test_topic", "test_list", ["a", "b", "c"]),
+        ("test_topic", "test_tuple", ("c", "b", "a")),
+        ("test_topic", "test_dict", {"a": 1, "b": 2, "c": 3}),
+        ("test_topic", "test_number", 123),
+        ("test_topic", "test_float", 1.23),
+        ("test_topic", "test_boolean", True),
+        ("test_topic", "test_none", None),
+    ],
+)
+def test_aiowrite(msg_local_config, topic, key, msg):
+    # aiowrite() behaves like write(): appends {key: msg} as YAML
+    # (tuples come back as lists after the YAML round-trip).
+    file_path = msg_local_config.path + topic
+    asyncio.run(msg_local_config.aiowrite(topic, key, msg))
+    assert os.path.exists(file_path)
+
+    with open(file_path, "r") as stream:
+        assert yaml.safe_load(stream) == {
+            key: msg if not isinstance(msg, tuple) else list(msg)
+        }
+
+
+@pytest.mark.parametrize(
+ "topic, key, msg, times",
+ [
+ ("test_topic", "test_key", "test_msg", 2),
+ ("test", "test_key", "test_msg", 3),
+ ("test_topic", "test", "test_msg", 4),
+ ("test_topic", "test_key", "test", 2),
+ ("test_topic", "test_list", ["a", "b", "c"], 3),
+ ("test_topic", "test_tuple", ("c", "b", "a"), 4),
+ ("test_topic", "test_dict", {"a": 1, "b": 2, "c": 3}, 2),
+ ("test_topic", "test_number", 123, 3),
+ ("test_topic", "test_float", 1.23, 4),
+ ("test_topic", "test_boolean", True, 2),
+ ("test_topic", "test_none", None, 3),
+ ],
+)
+def test_aiowrite_with_multiple_calls(msg_local_config, topic, key, msg, times):
+ file_path = msg_local_config.path + topic
+ for _ in range(times):
+ asyncio.run(msg_local_config.aiowrite(topic, key, msg))
+ assert os.path.exists(file_path)
+
+ with open(file_path, "r") as stream:
+ for _ in range(times):
+ data = stream.readline()
+ assert yaml.safe_load(data) == {
+ key: msg if not isinstance(msg, tuple) else list(msg)
+ }
+
+
+def test_aiowrite_exception(msg_local_config):
+ msg_local_config.files_write = MagicMock()
+ msg_local_config.files_write.__contains__.side_effect = Exception()
+
+ with pytest.raises(MsgException) as excinfo:
+ asyncio.run(msg_local_config.aiowrite("test", "test", "test"))
+ assert str(excinfo.value).startswith(empty_exception_message())
+ assert excinfo.value.http_code == http.HTTPStatus.INTERNAL_SERVER_ERROR
[testenv]
usedevelop = True
-basepython = python3.10
+basepython = python3.13
setenv = VIRTUAL_ENV={envdir}
PYTHONDONTWRITEBYTECODE = 1
deps = -r{toxinidir}/requirements.txt
deps = black==23.12.1
skip_install = true
commands =
- black --check --diff osm_common/
- black --check --diff setup.py
+ black --exclude "_version\.py$" --check --diff src/ tests/
#######################################################################################
[testenv:cover]
commands =
sh -c 'rm -f nosetests.xml'
coverage erase
- nose2 -C --coverage osm_common -s osm_common/tests
+ nose2 -C --coverage src -s tests
coverage report --omit='*tests*'
coverage html -d ./cover --omit='*tests*'
coverage xml -o coverage.xml --omit=*tests*
flake8==6.0.0
flake8-import-order
commands =
- flake8 osm_common/ setup.py
+ flake8 src/ tests/
#######################################################################################
[testenv:pylint]
-r{toxinidir}/requirements-test.txt
pylint
commands =
- pylint -E osm_common
+ pylint -E src
#######################################################################################
[testenv:safety]
sed -i -e '1 e head -16 tox.ini' $out ;\
done"
-#######################################################################################
-[testenv:dist]
-deps = {[testenv]deps}
- -r{toxinidir}/requirements-dist.txt
-
-# In the commands, we copy the requirements.txt to be presented as a source file (.py)
-# so it gets included in the .deb package for others to consume
-commands =
- sh -c 'cp requirements.txt osm_common/requirements.txt'
- python3 setup.py --command-packages=stdeb.command sdist_dsc
- sh -c 'cd deb_dist/osm-common*/ && dpkg-buildpackage -rfakeroot -uc -us'
- sh -c 'rm osm_common/requirements.txt'
-allowlist_externals = sh
-
#######################################################################################
[testenv:release_notes]
deps = reno
show-source = True
builtins = _
import-order-style = google
+