diff --git a/larigira/db.py b/larigira/db.py index 6111856..0c1da6e 100644 --- a/larigira/db.py +++ b/larigira/db.py @@ -1,24 +1,84 @@ +from logging import getLogger from tinydb import TinyDB +from tinydb.storages import JSONStorage +from tinydb.middlewares import Middleware +from pathlib import Path +class ReadOnlyMiddleware(Middleware): + """ + Make sure no write ever occurs + """ + + def __init__(self, storage_cls=TinyDB.DEFAULT_STORAGE): + super().__init__(storage_cls) + + + def write(self, data): + raise ReadOnlyException('You cannot write to a readonly db') + + +class ReadOnlyException(ValueError): + pass class EventModel(object): - def __init__(self, uri): + def __init__(self, uri, additional_db_dir=None): self.uri = uri - self.db = None + self.additional_db_dir = Path(additional_db_dir) if additional_db_dir else None + self._dbs = {} + self.log = getLogger(self.__class__.__name__) self.reload() def reload(self): - if self.db is not None: - self.db.close() - self.db = TinyDB(self.uri, indent=2) - self._actions = self.db.table("actions") - self._alarms = self.db.table("alarms") + for db in self._dbs.values(): + db.close() + self._dbs['main'] = TinyDB(self.uri, indent=2) + if self.additional_db_dir is not None: + if self.additional_db_dir.is_dir(): + for db_file in self.additional_db_dir.glob('*.db.json'): + name = db_file.name[:-8] + if name == 'main': + self.log.warning("%s db file name is not valid (any other name.db.json would have been ok!", str(db_file.name)) + continue + if not name.isalpha(): + self.log.warning("%s db file name is not valid: it must be alphabetic only", str(db_file.name)) + continue + self._dbs[name] = TinyDB( + str(db_file), + storage=ReadOnlyMiddleware(JSONStorage), + default_table='actions' + ) + + self.log.debug('Loaded %d databases: %s', len(self._dbs), ','.join(self._dbs.keys())) + + self._actions = self._dbs['main'].table("actions") + self._alarms = self._dbs['main'].table("alarms") + + def canonicalize(self, eid_or_aid): + try: + int(eid_or_aid) + except ValueError: + return eid_or_aid + return 'main:%d' % eid_or_aid + + def parse_id(self, eid_or_aid): + try: + int(eid_or_aid) + except ValueError: + pass + else: + return ('main', eid_or_aid) + + dbname, num = eid_or_aid.split(':') + return (dbname, int(num)) + def get_action_by_id(self, action_id): - return self._actions.get(eid=action_id) + db, action_id = self.parse_id(action_id) + return self._dbs[db].table('actions').get(eid=action_id) def get_alarm_by_id(self, alarm_id): - return self._alarms.get(eid=alarm_id) + db, alarm_id = self.parse_id(alarm_id) + return self._dbs[db].table('alarms').get(eid=alarm_id) def get_actions_by_alarm(self, alarm): for action_id in alarm.get("actions", []): @@ -27,11 +87,17 @@ class EventModel(object): continue yield action - def get_all_alarms(self): - return self._alarms.all() + def get_all_alarms(self) -> list: + out = [] + for db in self._dbs: + out.extend(self._dbs[db].table('alarms').all()) + return out - def get_all_actions(self): - return self._actions.all() + def get_all_actions(self) -> list: + out = [] + for db in self._dbs: + out.extend(self._dbs[db].table('actions').all()) + return out def get_all_alarms_expanded(self): for alarm in self.get_all_alarms():