numeretti/pizzicore/pizzicore.py

217 lines
5.9 KiB
Python
Raw Normal View History

2021-09-16 01:47:11 +02:00
import secrets
2021-09-16 02:43:03 +02:00
import logging
2021-09-16 02:36:00 +02:00
import dbm
2021-09-16 01:47:11 +02:00
from collections import defaultdict
2021-09-24 00:52:33 +02:00
from mimetypes import guess_type
2022-08-17 19:17:01 +02:00
from pathlib import Path
2021-09-16 01:47:11 +02:00
from asyncio.queues import Queue
2021-09-24 00:52:33 +02:00
from fastapi import FastAPI, WebSocket, HTTPException, Depends, status, Response
2021-09-16 02:43:03 +02:00
from starlette.websockets import WebSocketDisconnect
2021-09-16 01:47:11 +02:00
from fastapi.security import HTTPBasic, HTTPBasicCredentials
2021-09-24 00:52:33 +02:00
from fastapi.encoders import jsonable_encoder
from fastapi.staticfiles import StaticFiles
2021-10-09 20:39:33 +02:00
from fastapi.middleware.cors import CORSMiddleware
2022-08-17 19:17:01 +02:00
from pydantic import BaseModel, BaseSettings
2021-09-16 01:47:11 +02:00
2022-08-17 19:17:01 +02:00
class Settings(BaseSettings):
    """Application configuration handled by pydantic's ``BaseSettings``.

    The values below are the defaults; ``Config.env_file`` names
    ``pizzicore.env`` as a file from which settings may also be read.
    """

    # Application name (not referenced elsewhere in the visible code).
    app_name: str = "Numeretti"
    # Directory that holds the persistent counter database.
    storage_dir: Path = Path("/var/lib/pizzicore")
    # Number of counters created when the store is built.
    queues_number: int = 1
    # Password expected for the "admin" HTTP Basic user.
    admin_password: str = "changeme!"

    class Config:
        env_file = "pizzicore.env"
2021-09-16 01:47:11 +02:00
2021-09-16 02:21:50 +02:00
class BaseStore:
    """In-memory table of integer counters, initially keyed ``0 .. n-1``."""

    def __init__(self, n: int):
        # Every counter starts at zero.
        self.values = dict.fromkeys(range(n), 0)

    def get(self, key) -> int:
        """Return the value stored under ``key``; ``KeyError`` if absent."""
        return self.values[key]

    def set(self, key, value: int) -> int:
        """Store ``value`` under ``key`` and hand the same value back."""
        self.values[key] = value
        return value

    def incr(self, key) -> int:
        """Add one to the counter and return the new value."""
        return self.set(key, self.get(key) + 1)

    def decr(self, key) -> int:
        """Subtract one from the counter and return the new value."""
        return self.set(key, self.get(key) - 1)
2021-09-16 02:36:00 +02:00
class PersistStoreMixin:
    """Mixin that mirrors a store's values into a ``dbm`` database file.

    It relies on the next class in the MRO to provide ``self.values`` and
    ``set``. The constructor requires a ``db_path`` keyword argument, which
    is removed from ``kwargs`` before they are passed on.
    """

    def __init__(self, *args, **kwargs):
        self.db_path = kwargs.pop("db_path")
        super().__init__(*args, **kwargs)
        # Flag "c" creates the database when it does not exist yet. Saved
        # entries then overwrite the in-memory defaults.
        with dbm.open(str(self.db_path), "c") as saved:
            self.values.update(
                (int(raw_key), int(saved[raw_key])) for raw_key in saved.keys()
            )

    def set(self, key, value: int) -> int:
        """Update the value through the next class, then write it to disk."""
        stored = super().set(key, value)
        with dbm.open(str(self.db_path), "w") as saved:
            saved[str(key)] = str(value)
        return stored
class SignalStoreMixin:
    """Make any BaseStore manager-aware.

    The constructor requires a ``manager`` keyword argument; that object's
    ``notify(key, value)`` is called after every ``set``.
    """

    def __init__(self, *args, **kwargs):
        notifier = kwargs.pop("manager")
        super().__init__(*args, **kwargs)
        self.manager = notifier

    def set(self, key, value: int) -> int:
        """Store the value through the next class, then notify the manager."""
        stored = super().set(key, value)
        self.manager.notify(key, value)
        return stored
2021-09-16 02:36:00 +02:00
class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
    """Counter store combining in-memory values, dbm persistence and signals.

    With this base-class order a ``set`` updates the in-memory value, then
    writes it to the dbm file, then notifies the manager.
    """
class Manager:
    """Handle notifications logic (for websocket).

    For every key, keeps the list of queues that want to receive the values
    published for that key.
    """

    def __init__(self):
        # key -> list of subscribed queues
        self.registry = defaultdict(list)

    def subscribe(self, key, q: Queue):
        """Register ``q`` to receive the values published for ``key``."""
        self.registry[key].append(q)

    def unsubscribe(self, key, q: Queue):
        """Remove ``q`` from the subscribers of ``key``; no-op if absent."""
        # Use .get() so that looking up an unknown key does not make the
        # defaultdict create an empty entry for it.
        subscribers = self.registry.get(key)
        if subscribers is None:
            return
        try:
            subscribers.remove(q)
        except ValueError:  # not found
            pass
        # Drop the key once nobody listens, so the registry does not grow
        # with every key that was ever used.
        if not subscribers:
            del self.registry[key]

    def notify(self, key, val):
        """Put ``val`` on the queue of every subscriber of ``key``."""
        # .get() with an empty default: notifying a key nobody subscribed to
        # must not leave an empty list behind in the registry.
        for queue in self.registry.get(key, ()):
            queue.put_nowait(val)
2021-09-16 01:47:11 +02:00
# Application object and the module-level singletons used by the handlers.
app = FastAPI()

# NOTE(review): every origin, method and header is allowed together with
# credentials -- confirm this is what the deployment needs.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Files in the "static" directory are served under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

manager = Manager()
settings = Settings()

# The counters: held in memory, saved to a dbm file inside storage_dir, and
# announced to ``manager`` on every change.
counter_store = Store(
    n=settings.queues_number,
    db_path=settings.storage_dir / "pizzicore.dbm",
    manager=manager,
)

security = HTTPBasic()
2022-08-17 19:17:13 +02:00
class CountersDescription(BaseModel):
    """Response body stating how many counters exist."""

    counters: int
class UserDescription(BaseModel):
    """Response body carrying the role of the caller."""

    role: str
2022-08-17 19:17:13 +02:00
2021-09-16 01:47:11 +02:00
class Value(BaseModel):
    """A counter id together with that counter's value."""

    counter: int
    value: int
2021-09-16 02:21:50 +02:00
def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
    """Check HTTP Basic credentials and return the caller's role.

    The only accepted user is ``admin`` with ``settings.admin_password``.

    Returns:
        The string ``"admin"``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            username or the password does not match.
    """
    # compare_digest() raises TypeError for str arguments that contain
    # non-ASCII characters, which would turn a bad login into a 500 instead
    # of a 401. Comparing the UTF-8 bytes accepts any input.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf8"), b"admin"
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf8"),
        settings.admin_password.encode("utf8"),
    )
    # Both comparisons are always evaluated before the result is combined.
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "admin"
@app.get("/v1/whoami/")
async def whoami(role: str = Depends(get_current_role)):
    """Return the role resolved by ``get_current_role`` for this request."""
    description = UserDescription(role=role)
    return description
2022-08-17 19:17:13 +02:00
@app.get("/v1/counter/")
async def get_counter_number():
    """Return how many counters the store currently holds."""
    total = len(counter_store.values)
    return CountersDescription(counters=total)
2021-09-16 01:47:11 +02:00
2021-09-24 00:52:55 +02:00
@app.get("/v1/counter/{cid}")
async def get_value(cid: int):
    """Return the value of counter ``cid``, or a 404 if it does not exist."""
    try:
        current = counter_store.get(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    return Value(counter=cid, value=current)
2021-09-16 01:47:11 +02:00
2021-09-24 00:52:55 +02:00
@app.post("/v1/counter/{cid}/increment")
async def increment(cid: int, role: str = Depends(get_current_role)):
    """Add one to counter ``cid`` and return its new value.

    Responds 403 when the role is not ``admin`` and 404 when the counter
    does not exist.
    """
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

    try:
        updated = counter_store.incr(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")

    return Value(counter=cid, value=updated)
2022-08-17 19:18:16 +02:00
@app.post("/v1/counter/{cid}/decrement")
async def decrement(cid: int, role: str = Depends(get_current_role)):
    """Subtract one from counter ``cid`` and return its new value.

    Responds 403 when the role is not ``admin`` and 404 when the counter
    does not exist.
    """
    # This handler used to be defined as a second ``increment``, which
    # rebound that module-level name to the decrement logic. The route path
    # is unchanged.
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        val = counter_store.decr(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Counter not found")
    return Value(counter=cid, value=val)
2021-09-16 01:47:11 +02:00
2021-09-24 00:52:55 +02:00
@app.websocket("/v1/ws/counter/{cid}")
async def websocket_counter(websocket: WebSocket, cid: int):
    """Stream the value of counter ``cid`` over a websocket.

    The current value is sent right after the connection is accepted, and
    again every time the manager publishes something for ``cid``.
    """
    await websocket.accept()
    q: Queue = Queue()
    # Subscribing before the first read means a change made after that read
    # still wakes the loop below.
    manager.subscribe(cid, q)
    try:
        while True:
            val = counter_store.get(cid)
            await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
            # The queued item only serves as a wake-up signal; the value that
            # gets sent is re-read from the store on the next iteration.
            await q.get()
    except WebSocketDisconnect:
        logging.debug("client disconnected")
    except Exception:
        logging.exception("unexpected error")
    finally:
        # Runs on every exit path, including cancellation of the task while
        # it waits on q.get(), so the queue never stays registered.
        manager.unsubscribe(cid, q)
2021-09-24 00:52:33 +02:00
async def get_page(fname):
    """Read the text file ``fname`` and return it wrapped in a ``Response``.

    The media type is guessed from the file name; ``guess_type`` yields
    ``None`` for names it does not recognise, and that is passed through.
    """
    # Decode explicitly as UTF-8 rather than with the process locale, so the
    # page is read the same way whatever environment the server runs in.
    with open(fname, encoding="utf-8") as page:
        content = page.read()
    content_type, _ = guess_type(fname)
    return Response(content, media_type=content_type)
@app.get("/")
async def root_page():
    """Serve ``pages/index.html`` at the site root."""
    page = await get_page("pages/index.html")
    return page
2021-10-09 20:37:51 +02:00
@app.get("/prenota")
async def prenota_page():
    """Serve ``pages/prenotati.html`` at ``/prenota``."""
    page = await get_page("pages/prenotati.html")
    return page