pizzicore.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import secrets
  2. import logging
  3. import dbm
  4. from collections import defaultdict
  5. from mimetypes import guess_type
  6. from pathlib import Path
  7. from asyncio.queues import Queue
  8. from fastapi import FastAPI, WebSocket, HTTPException, Depends, status, Response
  9. from starlette.websockets import WebSocketDisconnect
  10. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  11. from fastapi.encoders import jsonable_encoder
  12. from fastapi.staticfiles import StaticFiles
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from pydantic import BaseModel, BaseSettings
  15. class Settings(BaseSettings):
  16. app_name: str = "Numeretti"
  17. storage_dir: Path = Path("/var/lib/pizzicore")
  18. queues_number: int = 1
  19. admin_password: str = "changeme!"
  20. class Config:
  21. env_file = "pizzicore.env"
  22. class BaseStore:
  23. def __init__(self, n: int):
  24. self.values = {i: 0 for i in range(n)}
  25. def get(self, key) -> int:
  26. return self.values[key]
  27. def incr(self, key) -> int:
  28. newval = self.get(key) + 1
  29. return self.set(key, newval)
  30. def decr(self, key) -> int:
  31. newval = self.get(key) - 1
  32. return self.set(key, newval)
  33. def set(self, key, value: int) -> int:
  34. self.values[key] = value
  35. return value
  36. class PersistStoreMixin:
  37. def __init__(self, *args, **kwargs):
  38. self.db_path = kwargs.pop("db_path")
  39. super().__init__(*args, **kwargs)
  40. with dbm.open(str(self.db_path), "c") as db:
  41. for key in db.keys():
  42. self.values[int(key)] = int(db[key])
  43. def set(self, key, value: int) -> int:
  44. ret = super().set(key, value)
  45. with dbm.open(str(self.db_path), "w") as db:
  46. db[str(key)] = str(value)
  47. return ret
  48. class SignalStoreMixin:
  49. """make any BaseStore manager-aware"""
  50. def __init__(self, *args, **kwargs):
  51. manager = kwargs.pop("manager")
  52. super().__init__(*args, **kwargs)
  53. self.manager = manager
  54. def set(self, key, value: int) -> int:
  55. ret = super().set(key, value)
  56. self.manager.notify(key, value)
  57. return ret
  58. class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
  59. pass
  60. class Manager:
  61. """Handle notifications logic (for websocket)."""
  62. def __init__(self):
  63. self.registry = defaultdict(list)
  64. def subscribe(self, key, q: Queue):
  65. self.registry[key].append(q)
  66. def unsubscribe(self, key, q: Queue):
  67. try:
  68. self.registry[key].remove(q)
  69. except ValueError: # not found
  70. pass
  71. def notify(self, key, val):
  72. for queue in self.registry[key]:
  73. queue.put_nowait(val)
  74. app = FastAPI()
  75. app.add_middleware(
  76. CORSMiddleware,
  77. allow_origins=['*'],
  78. allow_credentials=True,
  79. allow_methods=["*"],
  80. allow_headers=["*"],
  81. )
  82. app.mount("/static", StaticFiles(directory="static"), name="static")
  83. manager = Manager()
  84. settings = Settings()
  85. counter_store = Store(
  86. n=settings.queues_number, db_path=settings.storage_dir / "pizzicore.dbm", manager=manager
  87. )
  88. security = HTTPBasic()
  89. class CountersDescription(BaseModel):
  90. counters: int
  91. class UserDescription(BaseModel):
  92. role: str
  93. class Value(BaseModel):
  94. counter: int
  95. value: int
  96. def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
  97. correct_username = secrets.compare_digest(credentials.username, "admin")
  98. correct_password = secrets.compare_digest(credentials.password, settings.admin_password)
  99. if not (correct_username and correct_password):
  100. raise HTTPException(
  101. status_code=status.HTTP_401_UNAUTHORIZED,
  102. detail="Incorrect username or password",
  103. headers={"WWW-Authenticate": "Basic"},
  104. )
  105. return "admin"
  106. @app.get("/v1/whoami/")
  107. async def whoami(role: str = Depends(get_current_role)):
  108. return UserDescription(role=role)
  109. @app.get("/v1/counter/")
  110. async def get_counter_number():
  111. return CountersDescription(counters=len(counter_store.values))
  112. @app.get("/v1/counter/{cid}")
  113. async def get_value(cid: int):
  114. try:
  115. val = counter_store.get(cid)
  116. except KeyError:
  117. raise HTTPException(status_code=404, detail="Counter not found")
  118. else:
  119. return Value(counter=cid, value=val)
  120. @app.post("/v1/counter/{cid}/increment")
  121. async def increment(cid: int, role: str = Depends(get_current_role)):
  122. if role != "admin":
  123. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  124. try:
  125. val = counter_store.incr(cid)
  126. except KeyError:
  127. raise HTTPException(status_code=404, detail="Counter not found")
  128. return Value(counter=cid, value=val)
  129. @app.post("/v1/counter/{cid}/decrement")
  130. async def increment(cid: int, role: str = Depends(get_current_role)):
  131. if role != "admin":
  132. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  133. try:
  134. val = counter_store.decr(cid)
  135. except KeyError:
  136. raise HTTPException(status_code=404, detail="Counter not found")
  137. return Value(counter=cid, value=val)
  138. @app.websocket("/v1/ws/counter/{cid}")
  139. async def websocket_counter(websocket: WebSocket, cid: int):
  140. await websocket.accept()
  141. q: Queue = Queue()
  142. manager.subscribe(cid, q)
  143. while True:
  144. try:
  145. val = counter_store.get(cid)
  146. await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
  147. except WebSocketDisconnect:
  148. logging.debug("client disconnected")
  149. manager.unsubscribe(cid, q)
  150. return
  151. except Exception:
  152. logging.exception("unexpected error")
  153. manager.unsubscribe(cid, q)
  154. return
  155. await q.get()
  156. async def get_page(fname):
  157. with open(fname) as f:
  158. content = f.read()
  159. content_type, _ = guess_type(fname)
  160. return Response(content, media_type=content_type)
  161. @app.get("/")
  162. async def root_page():
  163. return await get_page("pages/index.html")
  164. @app.get("/prenota")
  165. async def prenota_page():
  166. return await get_page("pages/prenotati.html")