Browse Source

Initial support for multiple dbs

refs #19
boyska 2 years ago
parent
commit
51a7c96ea0
1 changed files with 80 additions and 14 deletions
  1. 80 14
      larigira/db.py

+ 80 - 14
larigira/db.py

@@ -1,24 +1,84 @@
+from logging import getLogger
 from tinydb import TinyDB
+from tinydb.storages import JSONStorage
+from tinydb.middlewares import Middleware
+from pathlib import Path
 
+class ReadOnlyMiddleware(Middleware):
+    """
+    Make sure no write ever occurs
+    """
+
+    def __init__(self, storage_cls=TinyDB.DEFAULT_STORAGE):
+        super().__init__(storage_cls)
+
+
+    def write(self, data):
+        raise ReadOnlyException('You cannot write to a readonly db')
+
+
+class ReadOnlyException(ValueError):
+    pass
 
 class EventModel(object):
-    def __init__(self, uri):
+    def __init__(self, uri, additional_db_dir=None):
         self.uri = uri
-        self.db = None
+        self.additional_db_dir = Path(additional_db_dir) if additional_db_dir else None
+        self._dbs = {}
+        self.log = getLogger(self.__class__.__name__)
         self.reload()
 
     def reload(self):
-        if self.db is not None:
-            self.db.close()
-        self.db = TinyDB(self.uri, indent=2)
-        self._actions = self.db.table("actions")
-        self._alarms = self.db.table("alarms")
+        for db in self._dbs.values():
+            db.close()
+        self._dbs['main'] = TinyDB(self.uri, indent=2)
+        if self.additional_db_dir is not None:
+            if self.additional_db_dir.is_dir():
+                for db_file in self.additional_db_dir.glob('*.db.json'):
+                    name = db_file.name[:-8]
+                    if name == 'main':
+                        self.log.warning("%s db file name is not valid (any other name.db.json would have been ok!", str(db_file.name))
+                        continue
+                    if not name.isalpha():
+                        self.log.warning("%s db file name is not valid: it must be alphabetic only", str(db_file.name))
+                        continue
+                    self._dbs[name] = TinyDB(
+                            str(db_file),
+                            storage=ReadOnlyMiddleware(JSONStorage),
+                            default_table='actions'
+                    )
+
+        self.log.debug('Loaded %d databases: %s', len(self._dbs), ','.join(self._dbs.keys()))
+
+        self._actions = self._dbs['main'].table("actions")
+        self._alarms = self._dbs['main'].table("alarms")
+
+    def canonicalize(self, eid_or_aid):
+        try:
+            int(eid_or_aid)
+        except ValueError:
+            return eid_or_aid
+        return 'main:%d' % eid_or_aid
+
+    def parse_id(self, eid_or_aid):
+        try:
+            int(eid_or_aid)
+        except ValueError:
+            pass
+        else:
+            return ('main', eid_or_aid)
+
+        dbname, num = eid_or_aid.split(':')
+        return (dbname, int(num))
+        
 
     def get_action_by_id(self, action_id):
-        return self._actions.get(eid=action_id)
+        db, action_id = self.parse_id(action_id)
+        return self._dbs[db].table('actions').get(eid=action_id)
 
     def get_alarm_by_id(self, alarm_id):
-        return self._alarms.get(eid=alarm_id)
+        db, alarm_id = self.parse_id(alarm_id)
+        return self._dbs[db].table('alarms').get(eid=alarm_id)
 
     def get_actions_by_alarm(self, alarm):
         for action_id in alarm.get("actions", []):
@@ -27,11 +87,17 @@ class EventModel(object):
                 continue
             yield action
 
-    def get_all_alarms(self):
-        return self._alarms.all()
-
-    def get_all_actions(self):
-        return self._actions.all()
+    def get_all_alarms(self) -> list:
+        out = []
+        for db in self._dbs:
+            out.extend(self._dbs[db].table('alarms').all())
+        return out
+
+    def get_all_actions(self) -> list:
+        out = []
+        for db in self._dbs:
+            out.extend(self._dbs[db].table('actions').all())
+        return out
 
     def get_all_alarms_expanded(self):
         for alarm in self.get_all_alarms():