220 lines
6 KiB
Python
220 lines
6 KiB
Python
import secrets
|
|
import logging
|
|
import dbm
|
|
from collections import defaultdict
|
|
from mimetypes import guess_type
|
|
from pathlib import Path
|
|
|
|
from asyncio.queues import Queue
|
|
from fastapi import FastAPI, WebSocket, HTTPException, Depends, status, Response
|
|
from starlette.websockets import WebSocketDisconnect
|
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|
from fastapi.encoders import jsonable_encoder
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel, BaseSettings
|
|
|
|
class Settings(BaseSettings):
    """Runtime configuration of the service.

    Values are loaded by pydantic ``BaseSettings``; besides the process
    environment, entries are also read from the ``pizzicore.env`` file.
    """

    # Application name.
    app_name: str = "Numeretti"
    # Directory where the dbm file with persisted counter values lives.
    storage_dir: Path = Path("/var/lib/pizzicore")
    # Number of counters exposed by the store (keys 0 .. queues_number - 1).
    queues_number: int = 1
    # Password of the HTTP Basic user "admin".
    # NOTE(review): the default is a placeholder; make sure deployments override it.
    admin_password: str = "changeme!"

    class Config:
        # Extra source of settings, read in addition to the environment.
        env_file = "pizzicore.env"
|
|
|
|
class BaseStore:
    """In-memory table of integer counters with keys ``0 .. n-1``.

    Every counter starts at zero. Looking up an unknown key raises
    ``KeyError``. ``incr``/``decr`` route their write through ``self.set`` so
    that subclasses and mixins overriding ``set`` see every change.
    """

    def __init__(self, n: int):
        # One zero-initialised slot per configured counter.
        self.values = dict.fromkeys(range(n), 0)

    def get(self, key) -> int:
        """Return the current value of counter *key* (``KeyError`` if unknown)."""
        return self.values[key]

    def incr(self, key) -> int:
        """Add one to counter *key* and return the new value."""
        return self.set(key, self.get(key) + 1)

    def decr(self, key) -> int:
        """Subtract one from counter *key* and return the new value."""
        return self.set(key, self.get(key) - 1)

    def set(self, key, value: int) -> int:
        """Store *value* for *key* and hand the same value back."""
        self.values[key] = value
        return value
|
|
|
|
|
|
class PersistStoreMixin:
    """Mirror every counter write into a ``dbm`` file and restore it on start.

    Place it before a store class (such as ``BaseStore``) in the MRO. The
    next class must create ``self.values`` in ``__init__`` and implement
    ``set``. Requires a ``db_path`` keyword argument; every other argument
    is passed through unchanged.
    """

    def __init__(self, *args, **kwargs):
        self.db_path = kwargs.pop("db_path")
        super().__init__(*args, **kwargs)
        # "c" creates the database file on the very first start.
        with dbm.open(str(self.db_path), "c") as db:
            for raw_key in db.keys():
                key = int(raw_key)
                # Only restore counters that are still configured. A record
                # left over from a run with more counters would otherwise
                # resurrect a counter that no longer exists.
                if key in self.values:
                    self.values[key] = int(db[raw_key])

    def set(self, key, value: int) -> int:
        """Store *value* via the next class in the MRO, then persist it."""
        ret = super().set(key, value)
        # "w" is enough here: __init__ has already created the file.
        with dbm.open(str(self.db_path), "w") as db:
            db[str(key)] = str(value)
        return ret
|
|
|
|
|
|
class SignalStoreMixin:
    """Make any ``BaseStore`` manager-aware: each ``set`` is reported.

    Requires a ``manager`` keyword argument, i.e. an object exposing
    ``notify(key, value)`` such as ``Manager``. The remaining arguments go to
    the next class in the MRO.
    """

    def __init__(self, *args, **kwargs):
        # Take our own argument out first so the next constructor never sees
        # an unexpected keyword.
        mgr = kwargs.pop("manager")
        super().__init__(*args, **kwargs)
        self.manager = mgr

    def set(self, key, value: int) -> int:
        """Delegate the write down the MRO, then notify the manager."""
        stored = super().set(key, value)
        self.manager.notify(key, value)
        return stored
|
|
|
|
|
|
class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
    """Counter store used by the app: in-memory, persisted and signalling.

    Because of this MRO, every ``set`` (including those made by ``incr`` and
    ``decr``) first updates the in-memory dict (``BaseStore``). It then
    writes the dbm file (``PersistStoreMixin``) and finally notifies the
    manager (``SignalStoreMixin``). Construct it with ``n`` plus the
    ``db_path`` and ``manager`` keyword arguments.
    """
|
|
|
|
|
|
class Manager:
    """Handle notification logic (for websockets).

    ``registry`` maps a counter key to the list of queues subscribed to it.
    Each new value is pushed to every queue of that key.
    """

    def __init__(self):
        self.registry = defaultdict(list)

    def subscribe(self, key, q: Queue):
        """Register *q* to receive every new value of counter *key*."""
        self.registry[key].append(q)

    def unsubscribe(self, key, q: Queue):
        """Remove *q* from the subscribers of *key*; a no-op if absent.

        The key itself is dropped once its last subscriber leaves. Keys can
        come straight from clients (the websocket path parameter), so they
        must not accumulate forever.
        """
        queues = self.registry.get(key)
        if queues is None:
            return
        if q in queues:
            queues.remove(q)
        if not queues:
            del self.registry[key]

    def notify(self, key, val):
        """Push *val* to every queue subscribed to *key* without waiting."""
        # .get() rather than indexing: indexing the defaultdict would insert
        # an empty entry for every key that is ever notified.
        for queue in self.registry.get(key, ()):
            queue.put_nowait(val)
|
|
|
|
|
|
# Application object; the route decorators below register on it.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is a
# questionable CORS setup for an API protected by HTTP Basic auth; confirm
# which origins really need credentialed access and list them explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the "static" directory under /static (relative path, so it is
# resolved against the process working directory).
app.mount("/static", StaticFiles(directory="static"), name="static")
# Single notification hub, shared by counter_store and websocket_counter.
manager = Manager()
# Configuration read from the environment and pizzicore.env (see Settings).
settings = Settings()
# Counters 0 .. queues_number-1, persisted to <storage_dir>/pizzicore.dbm,
# with every change forwarded to `manager`.
counter_store = Store(
    n=settings.queues_number, db_path=settings.storage_dir / "pizzicore.dbm", manager=manager
)
# HTTP Basic credential extractor used by get_current_role.
security = HTTPBasic()
|
|
|
|
|
|
# Response body of GET /v1/counter/.
class CountersDescription(BaseModel):
    # Number of counters held by the store.
    counters: int
|
|
|
|
# Response body of GET /v1/whoami/.
class UserDescription(BaseModel):
    # Role of the authenticated caller, as returned by get_current_role.
    role: str
|
|
|
|
|
|
# Snapshot of one counter, sent by the counter endpoints and the websocket.
class Value(BaseModel):
    # Counter identifier (the `cid` path parameter).
    counter: int
    # Current value of that counter.
    value: int
|
|
|
|
|
|
def get_current_role(credentials: HTTPBasicCredentials = Depends(security)) -> str:
    """Authenticate an HTTP Basic request and return the caller's role.

    There is a single account: user ``admin`` with ``settings.admin_password``.

    Returns:
        ``"admin"`` when both username and password match.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header otherwise.
    """
    # compare_digest is used to resist timing attacks. On str it only accepts
    # ASCII, so a non-ASCII username/password would raise TypeError (HTTP
    # 500). Comparing UTF-8 bytes makes such input a normal 401 instead.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"), b"admin"
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        settings.admin_password.encode("utf-8"),
    )
    # Evaluate both comparisons before deciding, so a wrong username does not
    # short-circuit the password check.
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "admin"
|
|
|
|
# Report the role of the authenticated caller (401 from get_current_role
# when the credentials are wrong).
@app.get("/v1/whoami/")
async def whoami(role: str = Depends(get_current_role)):
    description = UserDescription(role=role)
    return description
|
|
|
|
# Report how many counters the store holds.
@app.get("/v1/counter/")
async def get_counter_number():
    total = len(counter_store.values)
    return CountersDescription(counters=total)
|
|
|
|
|
|
# Public read of one counter; unknown ids map to 404.
@app.get("/v1/counter/{cid}")
async def get_value(cid: int):
    try:
        current = counter_store.get(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    return Value(counter=cid, value=current)
|
|
|
|
|
|
# Increase counter `cid` by one (admin only) and return its new value.
@app.post("/v1/counter/{cid}/increment")
async def increment(cid: int, role: str = Depends(get_current_role)):
    # get_current_role currently only yields "admin" (or raises 401), so this
    # guard is defensive.
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        new_value = counter_store.incr(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    else:
        return Value(counter=cid, value=new_value)
|
|
|
|
# Decrease counter `cid` by one (admin only) and return its new value.
# Named `decrement`: reusing `increment` would shadow the increment handler
# at module level and give both routes the same route name.
@app.post("/v1/counter/{cid}/decrement")
async def decrement(cid: int, role: str = Depends(get_current_role)):
    # get_current_role currently only yields "admin" (or raises 401), so this
    # guard is defensive.
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        val = counter_store.decr(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    return Value(counter=cid, value=val)
|
|
|
|
|
|
@app.websocket("/v1/ws/counter/{cid}")
async def websocket_counter(websocket: WebSocket, cid: int):
    """Stream the value of counter *cid* to a websocket client.

    The current value is sent right after the handshake. It is sent again
    each time ``manager`` pushes a change for *cid* into this connection's
    queue. The handler never reads from the socket, so a disconnect is only
    noticed when a later send fails.
    """
    await websocket.accept()
    q: Queue = Queue()
    manager.subscribe(cid, q)
    # try/finally guarantees the queue leaves the registry on every exit
    # path. That includes task cancellation (CancelledError is not an
    # Exception subclass) while waiting in q.get().
    try:
        while True:
            try:
                val = counter_store.get(cid)
                await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
            except WebSocketDisconnect:
                logging.debug("client disconnected")
                return
            except Exception:
                # NOTE(review): an unknown cid (KeyError from get) also lands
                # here and is logged as unexpected.
                logging.exception("unexpected error")
                return
            # Wait for the store to signal a change of this counter.
            await q.get()
    finally:
        manager.unsubscribe(cid, q)
|
|
|
|
|
|
async def get_page(fname):
    """Return the file *fname* as a response with a guessed media type.

    The file is read as UTF-8 text on every call (no caching). The media
    type comes from ``mimetypes.guess_type`` and is ``None`` when the
    extension is unknown.
    """
    # Explicit encoding: without it open() uses the locale's preferred
    # encoding, while the Response encodes text content as UTF-8.
    # NOTE(review): this is blocking file I/O inside a coroutine.
    with open(fname, encoding="utf-8") as f:
        content = f.read()
    content_type, _ = guess_type(fname)
    return Response(content, media_type=content_type)
|
|
|
|
|
|
# Serve pages/index.html at the site root.
@app.get("/")
async def root_page():
    response = await get_page("pages/index.html")
    return response
|
|
|
|
# Serve pages/prenotati.html at /prenota.
@app.get("/prenota")
async def prenota_page():
    response = await get_page("pages/prenotati.html")
    return response
|
|
|
|
# Serve pages/spa.html at /new.
# Named `new_page`: reusing `prenota_page` would shadow the /prenota handler
# at module level and give both routes the same route name.
@app.get("/new")
async def new_page():
    return await get_page("pages/spa.html")
|