123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
import asyncio
import dbm
import logging
import os
import secrets
from asyncio.queues import Queue
from collections import defaultdict
from mimetypes import guess_type

from fastapi import FastAPI, WebSocket, HTTPException, Depends, status, Response
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class BaseStore:
    """In-memory store of integer counters keyed ``0 .. n-1``.

    Every counter starts at zero. ``get`` raises ``KeyError`` for a key that
    was never stored, while ``set`` accepts any key. ``incr`` is routed
    through ``set`` so that mixins overriding ``set`` observe increments too.
    """

    def __init__(self, n: int):
        self.values = dict.fromkeys(range(n), 0)

    def get(self, key) -> int:
        """Return the value stored under *key* (``KeyError`` if absent)."""
        return self.values[key]

    def incr(self, key) -> int:
        """Add one to the value under *key* via :meth:`set`; return the result."""
        return self.set(key, self.get(key) + 1)

    def set(self, key, value: int) -> int:
        """Store *value* under *key* and hand it back."""
        self.values[key] = value
        return value
class PersistStoreMixin:
    """Mirror a store's values into a ``dbm`` database located at ``db_path``.

    The ``db_path`` keyword argument is required and is consumed here rather
    than forwarded. After the next initialiser in the MRO has run, values
    already saved in the database overwrite the in-memory ones. Every later
    ``set`` is written through to the database.
    """

    def __init__(self, *args, **kwargs):
        self.db_path = kwargs.pop("db_path")
        super().__init__(*args, **kwargs)
        # Flag "c": open for read/write, creating the database if missing.
        with dbm.open(self.db_path, "c") as db:
            saved = {int(raw_key): int(db[raw_key]) for raw_key in db.keys()}
        self.values.update(saved)

    def set(self, key, value: int) -> int:
        """Update the value through the MRO, persist it, and return the MRO's result."""
        result = super().set(key, value)
        # Flag "w": open the (already created) database for read/write.
        with dbm.open(self.db_path, "w") as db:
            db[str(key)] = str(value)
        return result
class SignalStoreMixin:
    """Make any BaseStore manager-aware.

    The ``manager`` keyword argument is required and is consumed here rather
    than forwarded. Each ``set`` first lets the rest of the MRO handle the
    write. It then reports ``(key, value)`` to ``manager.notify``.
    """

    def __init__(self, *args, **kwargs):
        mgr = kwargs.pop("manager")
        super().__init__(*args, **kwargs)
        self.manager = mgr

    def set(self, key, value: int) -> int:
        """Write *value* through the MRO, then notify the manager; return the MRO's result."""
        stored = super().set(key, value)
        notify = self.manager.notify
        notify(key, value)
        return stored
class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
    """Counter store that persists to dbm and signals a manager on every change.

    Build it with ``n``, ``db_path`` and ``manager`` keyword arguments. Because
    of the base-class order, a ``set`` happens in three steps:

    1. BaseStore updates the in-memory dict.
    2. PersistStoreMixin writes the value to the database.
    3. SignalStoreMixin calls ``manager.notify``.
    """
class Manager:
    """Handle notification logic (for websocket subscribers).

    Keeps, per key, the list of queues subscribed to it. ``notify`` pushes
    a value onto each of those queues without waiting.
    """

    def __init__(self):
        # key -> list of subscribed queues
        self.registry = defaultdict(list)

    def subscribe(self, key, q: Queue):
        """Register *q* to receive every value notified for *key*."""
        self.registry[key].append(q)

    def unsubscribe(self, key, q: Queue):
        """Remove *q* from *key*'s subscribers; unknown key/queue pairs are ignored.

        Once a key has no subscribers left, the key is dropped. Keys come from
        client-supplied ids, so without this the registry would keep growing.
        """
        # .get() rather than indexing: indexing the defaultdict would insert
        # an empty list for a key that was never subscribed.
        subscribers = self.registry.get(key)
        if subscribers is None:
            return
        try:
            subscribers.remove(q)
        except ValueError:  # q was not subscribed to this key
            pass
        if not subscribers:
            del self.registry[key]

    def notify(self, key, val):
        """Push *val* onto every queue subscribed to *key*.

        Notifying a key with no subscribers is a no-op and leaves the
        registry unchanged.
        """
        # put_nowait never waits. It can only fail (QueueFull) on a bounded
        # queue; the Queue() created by websocket_counter is unbounded.
        for queue in self.registry.get(key, ()):
            queue.put_nowait(val)
app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# "static" is a relative path, resolved against the process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Fans counter changes out to websocket subscribers.
manager = Manager()
# TODO: read settings from a config file. Until then, the database location
# can be overridden with the PIZZICORE_DB_PATH environment variable; the
# default matches the previously hard-coded path.
counter_store = Store(
    n=1,
    db_path=os.environ.get("PIZZICORE_DB_PATH", "/var/lib/pizzicore/pizzicore.dbm"),
    manager=manager,
)
# HTTP Basic credential extractor used by get_current_role.
security = HTTPBasic()
class Value(BaseModel):
    """Response payload pairing a counter id with its current value."""

    counter: int  # counter id (the ``cid`` path parameter of the endpoints)
    value: int  # current value held by the store for that counter
def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
    """Authenticate an HTTP Basic request and return the caller's role.

    The expected username and password come from the ``PIZZICORE_USERNAME``
    and ``PIZZICORE_PASSWORD`` environment variables. When unset, they fall
    back to the previous built-in values.

    Returns:
        ``"admin"``, the only role currently granted.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
        credentials do not match.
    """
    # TODO: read user/pass from a config file instead of the environment.
    expected_username = os.environ.get("PIZZICORE_USERNAME", "avanti")
    expected_password = os.environ.get("PIZZICORE_PASSWORD", "prossimo")
    # Compare UTF-8 bytes: compare_digest() raises TypeError for non-ASCII str
    # input, which would turn a wrong login into a 500 instead of a 401.
    # Both comparisons always run, keeping the timing independent of which part failed.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"), expected_username.encode("utf-8")
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"), expected_password.encode("utf-8")
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "admin"
@app.get("/v1/counter/{cid}")
async def get_value(cid: int):
    """Return the current value of counter *cid*, or 404 when it does not exist."""
    try:
        current = counter_store.get(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Item not found")
    return Value(counter=cid, value=current)
@app.post("/v1/counter/{cid}/increment")
async def increment(cid: int, role: str = Depends(get_current_role)):
    """Increment counter *cid* and return its new value.

    Requires HTTP Basic credentials (checked by ``get_current_role``). Only
    the ``"admin"`` role may increment. An unknown counter yields 404, the
    same as ``get_value``, instead of leaking the store's ``KeyError`` as a 500.
    """
    if role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    try:
        # BaseStore.incr reads before writing, so a missing key fails
        # before anything is stored, persisted or notified.
        val = counter_store.incr(cid)
    except KeyError:
        raise HTTPException(status_code=404, detail="Item not found") from None
    return Value(counter=cid, value=val)
@app.websocket("/v1/ws/counter/{cid}")
async def websocket_counter(websocket: WebSocket, cid: int):
    """Push the value of counter *cid* to a websocket client.

    The current value is sent right after the handshake. It is sent again
    each time ``manager`` is notified of a change to *cid*. An unknown
    counter is closed with code 1008 (policy violation). The manager
    subscription is removed on every exit path, including task cancellation.
    """
    await websocket.accept()
    try:
        counter_store.get(cid)
    except KeyError:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    q: Queue = Queue()
    disconnected = object()  # sentinel queued once the client has gone away

    async def watch_disconnect():
        # The loop below only sends, so by itself it would not notice a
        # vanished client until the next counter change. Meanwhile the
        # handler and its queue stay alive. Drain incoming frames and wake
        # the loop as soon as the disconnect message arrives.
        try:
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.disconnect":
                    return
        finally:
            q.put_nowait(disconnected)

    manager.subscribe(cid, q)
    watcher = asyncio.create_task(watch_disconnect())
    try:
        while True:
            val = counter_store.get(cid)
            await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
            if await q.get() is disconnected:
                logging.debug("client disconnected")
                return
    except WebSocketDisconnect:
        logging.debug("client disconnected")
    except Exception:
        logging.exception("unexpected error")
    finally:
        watcher.cancel()
        manager.unsubscribe(cid, q)
async def get_page(fname):
    """Return the text file *fname* as a Response with a guessed media type.

    The file is decoded as UTF-8 explicitly. Without ``encoding=``, ``open``
    uses the locale's default encoding, which can mis-decode or reject
    non-ASCII characters in the HTML pages.
    """
    # NOTE(review): this is a blocking read inside an async handler; fine for
    # small pages, but consider caching or a thread offload if pages grow.
    with open(fname, encoding="utf-8") as f:
        content = f.read()
    # content_type is None when the extension is not recognised.
    content_type, _ = guess_type(fname)
    return Response(content, media_type=content_type)
@app.get("/")
async def root_page():
    """Serve ``pages/index.html`` at the site root."""
    response = await get_page("pages/index.html")
    return response
@app.get("/prenota")
async def prenota_page():
    """Serve ``pages/prenotati.html`` at ``/prenota``."""
    response = await get_page("pages/prenotati.html")
    return response
|