numeretti/pizzicore/pizzicore.py
2022-08-19 19:21:17 +02:00

216 lines
5.9 KiB
Python

import secrets
import logging
import dbm
from collections import defaultdict
from mimetypes import guess_type
from pathlib import Path
from asyncio.queues import Queue
from fastapi import FastAPI, WebSocket, HTTPException, Depends, status, Response
from starlette.websockets import WebSocketDisconnect
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.encoders import jsonable_encoder
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, BaseSettings
class Settings(BaseSettings):
    """Runtime configuration for the counter service.

    Values come from pydantic's BaseSettings machinery; the nested Config
    names "pizzicore.env" as the env file to read.
    """

    # Display name of the application.
    app_name: str = "Numeretti"
    # Directory where the counter database file ("pizzicore.dbm") is kept.
    storage_dir: Path = Path("/var/lib/pizzicore")
    # How many counters the store creates (ids 0 .. queues_number - 1).
    queues_number: int = 1
    # Password for the "admin" HTTP Basic user.
    # NOTE(review): this default is public; confirm every deployment overrides
    # it (e.g. via pizzicore.env) before exposing the service.
    admin_password: str = "changeme!"

    class Config:
        env_file = "pizzicore.env"
class BaseStore:
    """In-memory table of integer counters keyed by id.

    Counters ``0 .. n-1`` exist from construction and all start at zero.
    ``incr`` and ``decr`` read through ``get`` and write through ``set``, so
    subclasses or mixins that override ``set`` observe every change.
    """

    def __init__(self, n: int):
        self.values = dict.fromkeys(range(n), 0)

    def get(self, key) -> int:
        """Return the value of counter ``key``; raises KeyError if unknown."""
        return self.values[key]

    def incr(self, key) -> int:
        """Add one to counter ``key`` and return the new value."""
        return self._shift(key, 1)

    def decr(self, key) -> int:
        """Subtract one from counter ``key`` and return the new value."""
        return self._shift(key, -1)

    def set(self, key, value: int) -> int:
        """Store ``value`` under ``key`` and hand it back."""
        self.values[key] = value
        return value

    def _shift(self, key, delta: int) -> int:
        # Route through get()/set() (not self.values) so overrides apply.
        return self.set(key, self.get(key) + delta)
class PersistStoreMixin:
    """Mixin that mirrors a store's counter values into a ``dbm`` file.

    Requires a ``db_path`` keyword argument. On construction the file is
    created if missing and saved values are loaded back for the counters the
    base store defines; afterwards every ``set`` is written through to disk.
    """

    def __init__(self, *args, **kwargs):
        self.db_path = kwargs.pop("db_path")
        super().__init__(*args, **kwargs)
        with dbm.open(str(self.db_path), "c") as db:
            for raw_key in db.keys():
                key = int(raw_key)
                # Restore only counters the base store already created: keys
                # saved under a larger configuration would otherwise come back
                # (and be served/counted) after the counter count is lowered.
                if key in self.values:
                    self.values[key] = int(db[raw_key])

    def set(self, key, value: int) -> int:
        """Update the value via the next class in the MRO, then persist it."""
        ret = super().set(key, value)
        with dbm.open(str(self.db_path), "w") as db:
            db[str(key)] = str(value)
        return ret
class SignalStoreMixin:
    """Mixin that makes a BaseStore-style class manager-aware.

    Construction requires a ``manager`` keyword argument. After every
    successful ``set`` the manager's ``notify(key, value)`` is invoked.
    """

    def __init__(self, *args, **kwargs):
        # Take our own keyword out before the rest of the MRO sees kwargs.
        notifier = kwargs.pop("manager")
        super().__init__(*args, **kwargs)
        self.manager = notifier

    def set(self, key, value: int) -> int:
        """Delegate the write down the MRO, then signal the manager."""
        stored = super().set(key, value)
        self.manager.notify(key, value)
        return stored
class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
    """Counter store with disk persistence and change notification.

    Through the MRO a ``set`` first updates memory (BaseStore), then writes
    the dbm file (PersistStoreMixin), and finally notifies the manager
    (SignalStoreMixin). Construct with ``n=``, ``db_path=`` and ``manager=``.
    """
class Manager:
    """Handle notifications logic (for websocket).

    Keeps, per key, the list of queues subscribed to it; ``notify`` pushes a
    value onto each of them. Keys are removed once their last subscriber
    leaves, so the registry does not grow with keys nobody listens to.
    """

    def __init__(self):
        # key -> list of subscribed queues
        self.registry = defaultdict(list)

    def subscribe(self, key, q: Queue):
        """Register ``q`` to receive every value notified for ``key``."""
        self.registry[key].append(q)

    def unsubscribe(self, key, q: Queue):
        """Remove ``q`` from ``key``'s subscribers; unknown pairs are ignored."""
        # .get() instead of [] so an unknown key is not inserted into the
        # defaultdict as an empty list.
        queues = self.registry.get(key)
        if queues is None:
            return
        try:
            queues.remove(q)
        except ValueError:  # q was not subscribed to this key
            return
        if not queues:
            del self.registry[key]

    def notify(self, key, val):
        """Put ``val`` on every queue subscribed to ``key`` (via put_nowait)."""
        # .get() so notifying a key without subscribers leaves no entry behind.
        for queue in self.registry.get(key, ()):
            queue.put_nowait(val)
# FastAPI application exposing the counter REST API, websocket and pages.
app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive (any web origin may call the API from a browser); confirm this
# is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# "static" is a relative path, so it depends on the directory the server is
# started from.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Notification hub: the store notifies it on every change and websocket
# handlers subscribe per-counter queues to it.
manager = Manager()
settings = Settings()
# Process-wide counter store: settings.queues_number counters, persisted in
# <storage_dir>/pizzicore.dbm, with every change signalled through `manager`.
# Constructing it opens (and if needed creates) the dbm file at import time.
counter_store = Store(
    n=settings.queues_number, db_path=settings.storage_dir / "pizzicore.dbm", manager=manager
)
# HTTP Basic auth scheme consumed by get_current_role.
security = HTTPBasic()
class CountersDescription(BaseModel):
    # Response body of GET /v1/counter/: how many counters the store holds.
    counters: int
class UserDescription(BaseModel):
    # Response body of GET /v1/whoami/: the authenticated caller's role.
    role: str
class Value(BaseModel):
    # A counter id paired with its current value; returned by the counter
    # endpoints and sent JSON-encoded over the counter websocket.
    counter: int
    value: int
def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
    """Authenticate an HTTP Basic request and return the caller's role.

    Only the user ``admin`` with password ``settings.admin_password`` is
    accepted.

    Returns:
        The role string ``"admin"``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` challenge when
            the username or password does not match.
    """
    # compare_digest accepts str only when it is ASCII-only and raises
    # TypeError otherwise; comparing UTF-8 bytes turns a non-ASCII username
    # or password into a clean 401 instead of a 500. Both comparisons always
    # run, so timing does not reveal which one failed.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"), b"admin"
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        settings.admin_password.encode("utf-8"),
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "admin"
@app.get("/v1/whoami/")
async def whoami(role: str = Depends(get_current_role)):
    """Report the role of the authenticated caller."""
    description = UserDescription(role=role)
    return description
@app.get("/v1/counter/")
async def get_counter_number():
    """Return how many counters the store currently holds."""
    total = len(counter_store.values)
    return CountersDescription(counters=total)
@app.get("/v1/counter/{cid}")
async def get_value(cid: int):
    """Return the current value of counter ``cid`` (404 if it does not exist)."""
    try:
        current = counter_store.get(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    return Value(counter=cid, value=current)
@app.post("/v1/counter/{cid}/increment")
async def increment(cid: int, role: str = Depends(get_current_role)):
    """Increase counter ``cid`` by one and return its new value.

    Requires admin credentials; responds 404 if the counter does not exist.
    """
    # get_current_role only ever returns "admin" (or raises 401), so this is
    # a defensive guard.
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        new_value = counter_store.incr(cid)
    except KeyError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Counter not found"
        )
    return Value(counter=cid, value=new_value)
@app.post("/v1/counter/{cid}/decrement")
async def decrement(cid: int, role: str = Depends(get_current_role)):
    """Decrease counter ``cid`` by one and return its new value.

    Requires admin credentials; responds 404 if the counter does not exist.
    """
    # Named `decrement` (it was a second `increment`, which shadowed the
    # increment handler's module name and mislabelled this route in OpenAPI).
    # get_current_role only ever returns "admin" (or raises 401), so the
    # role check below is a defensive guard.
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        val = counter_store.decr(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    return Value(counter=cid, value=val)
@app.websocket("/v1/ws/counter/{cid}")
async def websocket_counter(websocket: WebSocket, cid: int):
    """Stream the value of counter ``cid`` to a websocket client.

    The current value is sent right after the connection is accepted and
    again after every change notification for ``cid``. The manager
    subscription is released whenever the handler exits, including when the
    task is cancelled while waiting for the next notification.
    """
    await websocket.accept()
    q: Queue = Queue()
    manager.subscribe(cid, q)
    try:
        while True:
            val = counter_store.get(cid)
            await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
            # Block until the store signals a change to this counter.
            await q.get()
    except WebSocketDisconnect:
        logging.debug("client disconnected")
    except KeyError:
        # Unknown counter id: a client error, not a server fault, so close
        # with a policy-violation code instead of logging a traceback.
        logging.debug("websocket requested unknown counter %s", cid)
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    except Exception:
        logging.exception("unexpected error")
    finally:
        manager.unsubscribe(cid, q)
async def get_page(fname):
    """Read the text file ``fname`` and return it as an HTTP response.

    The media type is guessed from the file name (``None`` when it cannot be
    guessed). The file is decoded as UTF-8 regardless of the server locale.
    """
    # Explicit encoding: the platform default depends on the locale and may
    # not be UTF-8, which breaks pages containing non-ASCII text.
    with open(fname, encoding="utf-8") as f:
        content = f.read()
    content_type, _ = guess_type(fname)
    return Response(content, media_type=content_type)
@app.get("/")
async def root_page():
    """Serve the landing page."""
    page = "pages/index.html"
    response = await get_page(page)
    return response
@app.get("/prenota")
async def prenota_page():
    """Serve the booking ("prenota") page."""
    page = "pages/prenotati.html"
    response = await get_page(page)
    return response