From eb0f6c0310ea0ff2d2f3e6b11f946b28a7032687 Mon Sep 17 00:00:00 2001 From: boyska Date: Mon, 17 Jan 2022 00:46:24 +0100 Subject: [PATCH] change operations only work on "main" DB --- larigira/db.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/larigira/db.py b/larigira/db.py index 7aaa40c..a36468b 100644 --- a/larigira/db.py +++ b/larigira/db.py @@ -22,6 +22,16 @@ class ReadOnlyMiddleware(Middleware): class ReadOnlyException(ValueError): pass +def only_main(f): + '''assumes first argument is id, and must be "main"''' + def wrapper(self, *args): + _id = args[0] + db, db_id = EventModel.parse_id(_id) + if db != 'main': + raise ReadOnlyException('You called a write operation on a readonly db') + return f(self, db_id, *args[1:]) + return wrapper + class EventModel(object): def __init__(self, uri, additional_db_dir=None): @@ -63,31 +73,36 @@ class EventModel(object): self._actions = self._dbs['main'].table("actions") self._alarms = self._dbs['main'].table("alarms") - def canonicalize(self, eid_or_aid: Union[str, int]) -> str: + @staticmethod + def canonicalize(eid_or_aid: Union[str, int]) -> str: try: int(eid_or_aid) except ValueError: return eid_or_aid return 'main:%d' % eid_or_aid - def parse_id(self, eid_or_aid: Union[str, int]) -> Tuple[str, int]: + @staticmethod + def parse_id(eid_or_aid: Union[str, int]) -> Tuple[str, int]: try: int(eid_or_aid) except ValueError: pass else: - return ('main', eid_or_aid) + return ('main', int(eid_or_aid)) dbname, num = eid_or_aid.split(':') return (dbname, int(num)) def get_action_by_id(self, action_id: Union[str, int]): - db, db_action_id = self.parse_id(action_id) - return self._dbs[db].table('actions').get(eid=db_action_id) + canonical = self.canonicalize(action_id) + db, db_action_id = self.__class__.parse_id(canonical) + out = self._dbs[db].table('actions').get(eid=db_action_id) + out.doc_id = canonical + return out def get_alarm_by_id(self, alarm_id): - db, alarm_id = self.parse_id(alarm_id) + db, alarm_id = self.__class__.parse_id(alarm_id) return self._dbs[db].table('alarms').get(eid=alarm_id) def get_actions_by_alarm(self, alarm): @@ -125,14 +140,18 @@ class EventModel(object): def add_alarm(self, alarm): return self.add_event(alarm, []) + @only_main def update_alarm(self, alarmid, new_fields={}): return self._alarms.update(new_fields, eids=[alarmid]) + @only_main def update_action(self, actionid, new_fields={}): return self._actions.update(new_fields, eids=[actionid]) + @only_main def delete_alarm(self, alarmid): return self._alarms.remove(eids=[alarmid]) + @only_main def delete_action(self, actionid): return self._actions.remove(eids=[actionid])