# Repository page metadata (captured along with the file):
# 3
# 0
# Fork 0
# numeretti/pizzicore/pizzicore.py
# 2021-09-24 00:53:44 +02:00
#
# 150 lines
# 4 KiB
# Python

import dbm
import logging
import os
import secrets
from asyncio.queues import Queue
from collections import defaultdict

from fastapi import (
    Depends,
    FastAPI,
    HTTPException,
    WebSocket,
    status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class BaseStore:
    """In-memory table of integer counters keyed ``0 .. n-1``.

    ``set`` is the single write path: ``incr`` goes through it, so mixins
    that override ``set`` also observe increments.
    """

    def __init__(self, n: int):
        # Every counter starts at zero.
        self.values = dict.fromkeys(range(n), 0)

    def get(self, key) -> int:
        """Return the value stored under *key*; ``KeyError`` if absent."""
        return self.values[key]

    def incr(self, key) -> int:
        """Add one to the counter at *key* (via ``set``) and return it."""
        return self.set(key, self.get(key) + 1)

    def set(self, key, value: int) -> int:
        """Store *value* under *key* and return it."""
        self.values[key] = value
        return value
class PersistStoreMixin:
    """Store mixin that mirrors counters into a dbm file at ``db_path``.

    Must precede a class that provides ``values`` and ``set`` in the MRO
    (``BaseStore`` in this module). On construction the file is created if
    missing and any values it already holds are loaded over the in-memory
    defaults; afterwards every ``set`` is written through to the file.
    """

    def __init__(self, *args, **kwargs):
        # Consume our keyword before the rest of the MRO sees the kwargs.
        self.db_path = kwargs.pop("db_path")
        super().__init__(*args, **kwargs)
        # dbm hands back bytes for keys and values; int() parses bytes directly.
        with dbm.open(self.db_path, "c") as db:
            persisted = {int(raw): int(db[raw]) for raw in db.keys()}
        self.values.update(persisted)

    def set(self, key, value: int) -> int:
        """Delegate to the next ``set`` in the MRO, then persist *value*."""
        result = super().set(key, value)
        # Stored as text; __init__ converts it back with int() on reload.
        with dbm.open(self.db_path, "w") as db:
            db[str(key)] = str(value)
        return result
class SignalStoreMixin:
    """Store mixin that reports every ``set`` to a manager.

    Requires a ``manager`` keyword argument. After the next ``set`` in the
    MRO returns, ``manager.notify(key, value)`` is called with the same
    key and value.
    """

    def __init__(self, *args, **kwargs):
        mgr = kwargs.pop("manager")
        super().__init__(*args, **kwargs)
        # Attached only once the rest of the MRO has been initialised.
        self.manager = mgr

    def set(self, key, value: int) -> int:
        outcome = super().set(key, value)
        self.manager.notify(key, value)
        return outcome
class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
    """Counter store combining notification, persistence and in-memory state.

    With this MRO a ``set`` runs ``SignalStoreMixin.set`` ->
    ``PersistStoreMixin.set`` -> ``BaseStore.set``, so the manager is
    notified only after the value has been written to the dbm file.
    Construct with ``n`` plus the ``db_path=`` and ``manager=`` keywords.
    """
class Manager:
    """Fan counter updates out to the asyncio queues subscribed to each key.

    ``registry`` maps a key to the list of queues subscribed to it. An entry
    exists only while it has at least one subscriber, so keys that were never
    subscribed, or whose last subscriber left, do not accumulate.
    """

    def __init__(self):
        self.registry = defaultdict(list)

    def subscribe(self, key, q: Queue):
        """Register *q* to receive every value notified for *key*."""
        self.registry[key].append(q)

    def unsubscribe(self, key, q: Queue):
        """Remove *q* from *key*'s subscribers; a no-op if it is not there."""
        # .get() instead of indexing: indexing the defaultdict would create
        # an empty entry for keys that were never subscribed.
        subscribers = self.registry.get(key)
        if subscribers is None:
            return
        try:
            subscribers.remove(q)
        except ValueError:  # not subscribed
            return
        if not subscribers:
            # Drop the empty list so per-key entries don't pile up.
            del self.registry[key]

    def notify(self, key, val):
        """Push *val* into every queue subscribed to *key* without blocking."""
        for queue in self.registry.get(key, ()):
            queue.put_nowait(val)
app = FastAPI()
manager = Manager()
# The dbm file is opened (and created if missing) at import time.
# Its location can be overridden with the PIZZICORE_DB_PATH environment
# variable; the default is unchanged.
# TODO: read this from a config file.
# n=1 pre-creates counter 0; other ids exist only if the dbm file holds them.
counter_store = Store(
    n=1,
    db_path=os.environ.get(
        "PIZZICORE_DB_PATH", "/var/lib/pizzicore/pizzicore.dbm"
    ),
    manager=manager,
)
security = HTTPBasic()
class Value(BaseModel):
    # Payload returned by the counter endpoints and sent over the websocket.
    # Comments rather than a docstring: pydantic/FastAPI would put a
    # docstring into the published schema.
    counter: int  # counter id (the ``cid`` path parameter)
    value: int  # current value of that counter
def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
    """Check the HTTP Basic credentials and return the caller's role.

    The username and password are both compared with
    ``secrets.compare_digest``. Both comparisons run before either result is
    inspected, so a wrong username does not short-circuit the password check.

    Returns:
        ``"admin"``, the only role currently issued.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            credentials do not match.
    """
    # TODO: read user/pass from config instead of hard-coding them here.
    # Compare bytes, as the FastAPI docs recommend: compare_digest raises
    # TypeError for str arguments that contain non-ASCII characters.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf8"), b"avanti"
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf8"), b"prossimo"
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "admin"
@app.get("/v1/counter/{cid}")
async def get_value(cid: int):
    # Return counter ``cid``, or 404 if the store has no such counter.
    # Read-only: no authentication dependency. (Comments instead of a
    # docstring, which FastAPI would publish in the OpenAPI schema.)
    try:
        current = counter_store.get(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Item not found")
    return Value(counter=cid, value=current)
@app.post("/v1/counter/{cid}/increment")
async def increment(cid: int, role: str = Depends(get_current_role)):
    # Add one to counter ``cid`` and return its new value.
    # Requires HTTP Basic auth: get_current_role answers 401 on bad
    # credentials. (Comments instead of a docstring, which FastAPI would
    # publish in the OpenAPI schema.)
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        val = counter_store.incr(cid)
    except KeyError:
        # Unknown counter: answer 404, like get_value, instead of letting the
        # KeyError escape as a 500. incr reads before it writes, so nothing
        # was stored.
        raise HTTPException(status_code=404, detail="Item not found")
    return Value(counter=cid, value=val)
@app.websocket("/v1/ws/counter/{cid}")
async def websocket_counter(websocket: WebSocket, cid: int):
    # Stream counter ``cid`` to the client: send its value on connect and
    # again after every change the manager reports for it.
    await websocket.accept()
    q: Queue = Queue()
    manager.subscribe(cid, q)
    try:
        while True:
            val = counter_store.get(cid)
            message = jsonable_encoder(Value(counter=cid, value=val))
            await websocket.send_json(message)
            # Queue items are only wake-up signals; the value that gets sent
            # is always re-read from the store.
            await q.get()
            # Coalesce notifications that piled up meanwhile, so a slow client
            # gets one message with the latest value instead of a backlog.
            while not q.empty():
                q.get_nowait()
    except WebSocketDisconnect:
        logging.debug("client disconnected")
    except Exception:
        logging.exception("unexpected error")
    finally:
        # Also runs on task cancellation (asyncio.CancelledError is not an
        # Exception subclass), so the queue never stays subscribed.
        manager.unsubscribe(cid, q)