- def create_tables(self):
- try:
- db.connect()
- db.create_tables([VimCredentials, Alarm])
- db.close()
- except Exception as e:
- log.exception("Error creating tables: ")
-
- def get_credentials(self, vim_uuid):
- return VimCredentials.get_or_none(VimCredentials.uuid == vim_uuid)
-
- def save_credentials(self, vim_credentials):
- """Saves vim credentials. If a record with same uuid exists, overwrite it."""
- exists = VimCredentials.get_or_none(VimCredentials.uuid == vim_credentials.uuid)
- if exists:
- vim_credentials.id = exists.id
- vim_credentials.save()
-
- def get_credentials_for_alarm_id(self, alarm_id, vim_type):
- alarm = Alarm.select() \
- .where(Alarm.alarm_id == alarm_id) \
- .join(VimCredentials) \
- .where(VimCredentials.type == vim_type).get()
- return alarm.credentials
-
- def save_alarm(self, alarm_id, vim_uuid):
- """Saves alarm. If a record with same id and vim_uuid exists, overwrite it."""
- alarm = Alarm()
- alarm.alarm_id = alarm_id
- creds = VimCredentials.get(VimCredentials.uuid == vim_uuid)
- alarm.credentials = creds
- exists = Alarm.select(Alarm.alarm_id == alarm.alarm_id) \
- .join(VimCredentials) \
- .where(VimCredentials.uuid == vim_uuid)
- if len(exists):
- alarm.id = exists[0].id
- alarm.save()
+ def __init__(self, config: Config):
+ db.initialize(connect(config.get('sql', 'database_uri')))
+
+ def create_tables(self) -> None:
+ db.connect()
+ with db.atomic():
+ router = Router(db, os.path.dirname(migrations.__file__))
+ router.run()
+ db.close()
+
+
+class VimCredentialsRepository:
+ @staticmethod
+ def upsert(**query) -> VimCredentials:
+ vim_credentials = VimCredentials.get_or_none(**query)
+ if vim_credentials:
+ query.update({'id': vim_credentials.id})
+ vim_id = VimCredentials.insert(**query).on_conflict_replace().execute()
+ return VimCredentials.get(id=vim_id)
+
+ @staticmethod
+ def get(*expressions) -> VimCredentials:
+ return VimCredentials.select().where(*expressions).get()
+
+
+class AlarmRepository:
+ @staticmethod
+ def create(**query) -> Alarm:
+ return Alarm.create(**query)
+
+ @staticmethod
+ def get(*expressions) -> Alarm:
+ return Alarm.select().where(*expressions).get()
+
+ @staticmethod
+ def list(*expressions) -> Iterable[Alarm]:
+ if expressions == ():
+ return Alarm.select()
+ else:
+ return Alarm.select().where(*expressions)