"""Counter service: HTTP and WebSocket API over a small persisted counter store."""

import secrets
import logging
import dbm
from collections import defaultdict
from mimetypes import guess_type
from asyncio.queues import Queue

from fastapi import FastAPI, WebSocket, HTTPException, Depends, status, Response
from starlette.websockets import WebSocketDisconnect
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.encoders import jsonable_encoder
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel


class BaseStore:
    """In-memory integer counters keyed ``0 .. n-1``, all starting at 0."""

    def __init__(self, n: int):
        self.values = {i: 0 for i in range(n)}

    def get(self, key) -> int:
        """Return the value stored under *key*; raise ``KeyError`` if absent."""
        return self.values[key]

    def incr(self, key) -> int:
        """Add one to the counter *key* and return the new value.

        Goes through ``self.set`` so that mixins overriding ``set`` see the
        write. Raises ``KeyError`` if *key* is not a known counter.
        """
        newval = self.get(key) + 1
        return self.set(key, newval)

    def set(self, key, value: int) -> int:
        """Store *value* under *key* and return *value*."""
        self.values[key] = value
        return value


class PersistStoreMixin:
    """Mirror every ``set`` to a dbm file and reload it on construction.

    Expects a ``db_path`` keyword argument; all other arguments are passed
    on to the next class in the MRO.
    """

    def __init__(self, *args, **kwargs):
        self.db_path = kwargs.pop("db_path")
        super().__init__(*args, **kwargs)
        # "c" creates the database when it does not exist yet.
        with dbm.open(self.db_path, "c") as db:
            for key in db.keys():
                # Written directly into ``values`` (not via ``set``) so that
                # loading does not rewrite the file it is reading from.
                self.values[int(key)] = int(db[key])

    def set(self, key, value: int) -> int:
        """Store *value* in memory, then write it to the dbm file."""
        ret = super().set(key, value)
        with dbm.open(self.db_path, "w") as db:
            db[str(key)] = str(value)
        return ret


class SignalStoreMixin:
    """Make any BaseStore manager-aware.

    Expects a ``manager`` keyword argument; every ``set`` is followed by
    ``manager.notify(key, value)``.
    """

    def __init__(self, *args, **kwargs):
        manager = kwargs.pop("manager")
        super().__init__(*args, **kwargs)
        self.manager = manager

    def set(self, key, value: int) -> int:
        """Store *value* via the next class in the MRO, then notify."""
        ret = super().set(key, value)
        self.manager.notify(key, value)
        return ret


class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
    """Counter store that persists to dbm and notifies a manager on writes."""


class Manager:
    """Handle notifications logic (for websocket)."""

    def __init__(self):
        # Maps a counter key to the queues subscribed to it.
        self.registry = defaultdict(list)

    def subscribe(self, key, q: Queue):
        """Register *q* to receive values notified for *key*."""
        self.registry[key].append(q)

    def unsubscribe(self, key, q: Queue):
        """Remove *q* from the subscribers of *key*; no-op if not subscribed."""
        try:
            self.registry[key].remove(q)
        except ValueError:  # not found
            pass

    def notify(self, key, val):
        """Put *val* on every queue subscribed to *key*."""
        for queue in self.registry[key]:
            queue.put_nowait(val)


app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
manager = Manager()
# XXX: read from config file
counter_store = Store(
    n=1, db_path="/var/lib/pizzicore/pizzicore.dbm", manager=manager
)
security = HTTPBasic()


class Value(BaseModel):
    """Response payload: a counter id and its current value."""

    counter: int
    value: int


def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
    """Validate HTTP Basic credentials and return the caller's role.

    Returns ``"admin"`` when the credentials match; otherwise raises a 401
    ``HTTPException`` carrying a ``WWW-Authenticate: Basic`` header.
    """
    # XXX: read user/pass from config
    # compare_digest only accepts ASCII-only str, so compare UTF-8 bytes to
    # avoid a TypeError on non-ASCII input. Both comparisons are always
    # evaluated before being combined.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"), b"avanti"
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"), b"prossimo"
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "admin"


@app.get("/v1/counter/{cid}")
async def get_value(cid: int):
    """Return the current value of counter *cid*, or 404 if it is unknown."""
    try:
        val = counter_store.get(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Item not found") from None
    else:
        return Value(counter=cid, value=val)


@app.post("/v1/counter/{cid}/increment")
async def increment(cid: int, role: str = Depends(get_current_role)):
    """Increment counter *cid* and return its new value.

    Requires the ``admin`` role (403 otherwise). Responds 404 for an unknown
    counter, matching ``get_value``.
    """
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        val = counter_store.incr(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Item not found") from None
    return Value(counter=cid, value=val)


@app.websocket("/v1/ws/counter/{cid}")
async def websocket_counter(websocket: WebSocket, cid: int):
    """Stream the value of counter *cid* over a websocket.

    Sends the current value immediately, then again each time the manager
    notifies this connection's queue.
    """
    await websocket.accept()
    q: Queue = Queue()
    manager.subscribe(cid, q)
    try:
        while True:
            # The queue item is only a wake-up signal; the value sent is
            # always re-read from the store.
            val = counter_store.get(cid)
            await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
            await q.get()
    except WebSocketDisconnect:
        logging.debug("client disconnected")
    except Exception:
        logging.exception("unexpected error")
    finally:
        # Also runs when the handler task is cancelled, so the queue is
        # never left behind in the registry.
        manager.unsubscribe(cid, q)


async def get_page(fname):
    """Read *fname* as text and return it with a guessed media type."""
    with open(fname) as f:
        content = f.read()
    content_type, _ = guess_type(fname)
    return Response(content, media_type=content_type)


@app.get("/")
async def root_page():
    """Serve the index page."""
    return await get_page("pages/index.html")