Refactors code and adds unit tests
[osm/MON.git] / osm_mon / core / database.py
index d1c2e6b..2f51b1e 100644 (file)
 # contact: bdiaz@whitestack.com or glavado@whitestack.com
 ##
 
-import json
 import logging
 import os
-import uuid
+from typing import Iterable
 
 from peewee import CharField, TextField, FloatField, Model, AutoField, Proxy
 from peewee_migrate import Router
@@ -81,77 +80,33 @@ class DatabaseManager:
             router.run()
         db.close()
 
-    def get_credentials(self, vim_uuid: str = None) -> VimCredentials:
-        db.connect()
-        try:
-            with db.atomic():
-                vim_credentials = VimCredentials.get_or_none(VimCredentials.uuid == vim_uuid)
-                return vim_credentials
-        finally:
-            db.close()
-
-    def save_credentials(self, vim_credentials) -> VimCredentials:
-        """Saves vim credentials. If a record with same uuid exists, overwrite it."""
-        db.connect()
-        try:
-            with db.atomic():
-                exists = VimCredentials.get_or_none(VimCredentials.uuid == vim_credentials.uuid)
-                if exists:
-                    vim_credentials.id = exists.id
-                vim_credentials.save()
-                return vim_credentials
-        finally:
-            db.close()
-
-    def get_alarm(self, alarm_id) -> Alarm:
-        db.connect()
-        try:
-            with db.atomic():
-                alarm = (Alarm.select()
-                         .where(Alarm.alarm_id == alarm_id)
-                         .get())
-                return alarm
-        finally:
-            db.close()
-
-    def save_alarm(self, name, threshold, operation, severity, statistic, metric_name, vdur_name,
-                   vnf_member_index, nsr_id) -> Alarm:
-        """Saves alarm."""
-        # TODO: Add uuid optional param and check if exists to handle updates (see self.save_credentials)
-        db.connect()
-        try:
-            with db.atomic():
-                alarm = Alarm()
-                alarm.uuid = str(uuid.uuid4())
-                alarm.name = name
-                alarm.threshold = threshold
-                alarm.operation = operation
-                alarm.severity = severity
-                alarm.statistic = statistic
-                alarm.monitoring_param = metric_name
-                alarm.vdur_name = vdur_name
-                alarm.vnf_member_index = vnf_member_index
-                alarm.nsr_id = nsr_id
-                alarm.save()
-                return alarm
-        finally:
-            db.close()
-
-    def delete_alarm(self, alarm_uuid) -> None:
-        db.connect()
-        with db.atomic():
-            alarm = (Alarm.select()
-                     .where(Alarm.uuid == alarm_uuid)
-                     .get())
-            alarm.delete_instance()
-        db.close()
 
-    def get_vim_type(self, vim_account_id) -> str:
-        """Get the vim type that is required by the message."""
-        credentials = self.get_credentials(vim_account_id)
-        config = json.loads(credentials.config)
-        if 'vim_type' in config:
-            vim_type = config['vim_type']
-            return str(vim_type.lower())
+class VimCredentialsRepository:
+    @staticmethod
+    def upsert(**query) -> VimCredentials:
+        vim_credentials = VimCredentials.get_or_none(**query)
+        if vim_credentials:
+            query.update({'id': vim_credentials.id})
+        vim_id = VimCredentials.insert(**query).on_conflict_replace().execute()
+        return VimCredentials.get(id=vim_id)
+
+    @staticmethod
+    def get(*expressions) -> VimCredentials:
+        return VimCredentials.select().where(*expressions).get()
+
+
+class AlarmRepository:
+    @staticmethod
+    def create(**query) -> Alarm:
+        return Alarm.create(**query)
+
+    @staticmethod
+    def get(*expressions) -> Alarm:
+        return Alarm.select().where(*expressions).get()
+
+    @staticmethod
+    def list(*expressions) -> Iterable[Alarm]:
+        if expressions == ():
+            return Alarm.select()
         else:
-            return str(credentials.type)
+            return Alarm.select().where(*expressions)