"""Numeretti / pizzicore: a small FastAPI service exposing numbered counters.

Counters can be read over HTTP, incremented by an authenticated admin, and
watched live over a WebSocket. Counter values are persisted in a dbm file.
"""
import dbm
import logging
import secrets
from asyncio.queues import Queue
from collections import defaultdict
from mimetypes import guess_type
from pathlib import Path

from fastapi import FastAPI, WebSocket, HTTPException, Depends, status, Response
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, BaseSettings
from starlette.websockets import WebSocketDisconnect

logger = logging.getLogger(__name__)


class Settings(BaseSettings):
    """Application settings, overridable via environment or ``pizzicore.env``."""

    app_name: str = "Numeretti"
    storage_dir: Path = Path("/var/lib/pizzicore")
    queues_number: int = 1
    # Credentials required by the increment endpoint (HTTP Basic auth).
    # Defaults match the values that used to be hard-coded.
    admin_username: str = "avanti"
    admin_password: str = "prossimo"

    class Config:
        env_file = "pizzicore.env"


class BaseStore:
    """In-memory store of ``n`` integer counters keyed ``0 .. n-1``, all starting at 0."""

    def __init__(self, n: int):
        self.values = {i: 0 for i in range(n)}

    def get(self, key) -> int:
        """Return the value of counter *key*; raises ``KeyError`` if it does not exist."""
        return self.values[key]

    def incr(self, key) -> int:
        """Add one to counter *key* (via ``set``) and return the new value."""
        newval = self.get(key) + 1
        return self.set(key, newval)

    def set(self, key, value: int) -> int:
        """Store *value* under *key* and return it."""
        self.values[key] = value
        return value


class PersistStoreMixin:
    """Persist every ``set`` into a dbm file and restore values at startup.

    Requires a ``db_path`` keyword argument.
    """

    def __init__(self, *args, **kwargs):
        self.db_path = kwargs.pop("db_path")
        super().__init__(*args, **kwargs)
        with dbm.open(str(self.db_path), "c") as db:
            for raw_key in db.keys():
                key = int(raw_key)
                # Counters persisted by an earlier run with a larger
                # ``queues_number`` stay on disk but are not exposed, so the
                # set of counters always matches the configuration.
                if key in self.values:
                    self.values[key] = int(db[raw_key])

    def set(self, key, value: int) -> int:
        """Store *value* in memory, then write it through to the dbm file."""
        ret = super().set(key, value)
        with dbm.open(str(self.db_path), "w") as db:
            db[str(key)] = str(value)
        return ret


class SignalStoreMixin:
    """Make any BaseStore manager-aware: every ``set`` notifies ``manager``.

    Requires a ``manager`` keyword argument exposing ``notify(key, value)``.
    """

    def __init__(self, *args, **kwargs):
        manager = kwargs.pop("manager")
        super().__init__(*args, **kwargs)
        self.manager = manager

    def set(self, key, value: int) -> int:
        """Store the value through the rest of the MRO, then notify subscribers."""
        ret = super().set(key, value)
        self.manager.notify(key, value)
        return ret


class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
    """Counter store that persists to dbm and notifies subscribers on change."""


class Manager:
    """Handle notifications logic (for websocket).

    Keeps, per counter key, the list of queues to push new values into.
    """

    def __init__(self):
        self.registry = defaultdict(list)

    def subscribe(self, key, q: Queue):
        """Register *q* to receive values set on counter *key*."""
        self.registry[key].append(q)

    def unsubscribe(self, key, q: Queue):
        """Remove *q* from counter *key*'s subscribers; no-op if not subscribed."""
        try:
            self.registry[key].remove(q)
        except ValueError:  # not found
            pass

    def notify(self, key, val):
        """Push *val* into every queue subscribed to counter *key*."""
        for queue in self.registry[key]:
            queue.put_nowait(val)


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory="static"), name="static")

manager = Manager()
settings = Settings()
counter_store = Store(
    n=settings.queues_number,
    db_path=settings.storage_dir / "pizzicore.dbm",
    manager=manager,
)
security = HTTPBasic()


class CountersDescription(BaseModel):
    """Response body of ``GET /v1/counter/``: how many counters exist."""

    counters: int


class Value(BaseModel):
    """A counter id together with its current value."""

    counter: int
    value: int


def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
    """Authenticate HTTP Basic credentials and return the caller's role.

    Returns ``"admin"`` when the credentials match ``settings.admin_username``
    and ``settings.admin_password``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header otherwise.
    """
    # compare_digest() raises TypeError for str arguments containing non-ASCII
    # characters, so compare UTF-8 bytes. Otherwise a crafted username or
    # password would produce a 500 instead of a 401.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"),
        settings.admin_username.encode("utf-8"),
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        settings.admin_password.encode("utf-8"),
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "admin"


@app.get("/v1/counter/")
async def get_counter_number():
    """Return how many counters the store holds."""
    return CountersDescription(counters=len(counter_store.values))


@app.get("/v1/counter/{cid}")
async def get_value(cid: int):
    """Return the current value of counter *cid*, or 404 if it does not exist."""
    try:
        val = counter_store.get(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    else:
        return Value(counter=cid, value=val)


@app.post("/v1/counter/{cid}/increment")
async def increment(cid: int, role: str = Depends(get_current_role)):
    """Increment counter *cid* (admin only) and return its new value."""
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        val = counter_store.incr(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    return Value(counter=cid, value=val)


@app.websocket("/v1/ws/counter/{cid}")
async def websocket_counter(websocket: WebSocket, cid: int):
    """Stream counter *cid* to the client.

    Sends the current value immediately, then again after every notification
    from ``manager``. An unknown counter is rejected with close code 1008.
    """
    await websocket.accept()
    if cid not in counter_store.values:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    q: Queue = Queue()
    manager.subscribe(cid, q)
    try:
        while True:
            val = counter_store.get(cid)
            await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
            # NOTE(review): a client that disconnects while we wait here is only
            # noticed when the next notification arrives and the send fails.
            await q.get()
    except WebSocketDisconnect:
        logger.debug("client disconnected")
    except Exception:
        logger.exception("unexpected error")
    finally:
        # Always drop the subscription, including on task cancellation, so the
        # queue does not keep receiving values after the handler exits.
        manager.unsubscribe(cid, q)


async def get_page(fname):
    """Return the file at *fname* as a Response, guessing its media type from the name."""
    # Pages are assumed to be UTF-8 encoded. An explicit encoding avoids
    # depending on the server's locale.
    # NOTE(review): this read is synchronous and briefly blocks the event loop.
    with open(fname, encoding="utf-8") as f:
        content = f.read()
    content_type, _ = guess_type(fname)
    return Response(content, media_type=content_type)


@app.get("/")
async def root_page():
    """Serve the main page."""
    return await get_page("pages/index.html")


@app.get("/prenota")
async def prenota_page():
    """Serve the booking page."""
    return await get_page("pages/prenotati.html")