pizzicore.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import secrets
  2. import logging
  3. import dbm
  4. from collections import defaultdict
  5. from mimetypes import guess_type
  6. from asyncio.queues import Queue
  7. from fastapi import FastAPI, WebSocket, HTTPException, Depends, status, Response
  8. from starlette.websockets import WebSocketDisconnect
  9. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  10. from fastapi.encoders import jsonable_encoder
  11. from fastapi.staticfiles import StaticFiles
  12. from pydantic import BaseModel
  13. class BaseStore:
  14. def __init__(self, n: int):
  15. self.values = {i: 0 for i in range(n)}
  16. def get(self, key) -> int:
  17. return self.values[key]
  18. def incr(self, key) -> int:
  19. newval = self.get(key) + 1
  20. return self.set(key, newval)
  21. def set(self, key, value: int) -> int:
  22. self.values[key] = value
  23. return value
  24. class PersistStoreMixin:
  25. def __init__(self, *args, **kwargs):
  26. self.db_path = kwargs.pop("db_path")
  27. super().__init__(*args, **kwargs)
  28. with dbm.open(self.db_path, "c") as db:
  29. for key in db.keys():
  30. self.values[int(key)] = int(db[key])
  31. def set(self, key, value: int) -> int:
  32. ret = super().set(key, value)
  33. with dbm.open(self.db_path, "w") as db:
  34. db[str(key)] = str(value)
  35. return ret
  36. class SignalStoreMixin:
  37. """make any BaseStore manager-aware"""
  38. def __init__(self, *args, **kwargs):
  39. manager = kwargs.pop("manager")
  40. super().__init__(*args, **kwargs)
  41. self.manager = manager
  42. def set(self, key, value: int) -> int:
  43. ret = super().set(key, value)
  44. self.manager.notify(key, value)
  45. return ret
  46. class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
  47. pass
  48. class Manager:
  49. """Handle notifications logic (for websocket)."""
  50. def __init__(self):
  51. self.registry = defaultdict(list)
  52. def subscribe(self, key, q: Queue):
  53. self.registry[key].append(q)
  54. def unsubscribe(self, key, q: Queue):
  55. try:
  56. self.registry[key].remove(q)
  57. except ValueError: # not found
  58. pass
  59. def notify(self, key, val):
  60. for queue in self.registry[key]:
  61. queue.put_nowait(val)
  62. app = FastAPI()
  63. app.mount("/static", StaticFiles(directory="static"), name="static")
  64. manager = Manager()
  65. counter_store = Store(
  66. n=1, db_path="/var/lib/pizzicore/pizzicore.dbm", manager=manager
  67. ) # XXX: pesca da file di conf
  68. security = HTTPBasic()
  69. class Value(BaseModel):
  70. counter: int
  71. value: int
  72. def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
  73. # XXX: read user/pass from config
  74. correct_username = secrets.compare_digest(credentials.username, "avanti")
  75. correct_password = secrets.compare_digest(credentials.password, "prossimo")
  76. if not (correct_username and correct_password):
  77. raise HTTPException(
  78. status_code=status.HTTP_401_UNAUTHORIZED,
  79. detail="Incorrect username or password",
  80. headers={"WWW-Authenticate": "Basic"},
  81. )
  82. return "admin"
  83. @app.get("/v1/counter/{cid}")
  84. async def get_value(cid: int):
  85. try:
  86. val = counter_store.get(cid)
  87. except KeyError:
  88. raise HTTPException(status_code=404, detail="Item not found")
  89. else:
  90. return Value(counter=cid, value=val)
  91. @app.post("/v1/counter/{cid}/increment")
  92. async def increment(cid: int, role: str = Depends(get_current_role)):
  93. if role != "admin":
  94. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  95. val = counter_store.incr(cid)
  96. return Value(counter=cid, value=val)
  97. @app.websocket("/v1/ws/counter/{cid}")
  98. async def websocket_counter(websocket: WebSocket, cid: int):
  99. await websocket.accept()
  100. q: Queue = Queue()
  101. manager.subscribe(cid, q)
  102. while True:
  103. try:
  104. val = counter_store.get(cid)
  105. await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
  106. except WebSocketDisconnect:
  107. logging.debug("client disconnected")
  108. manager.unsubscribe(cid, q)
  109. return
  110. except Exception:
  111. logging.exception("unexpected error")
  112. manager.unsubscribe(cid, q)
  113. return
  114. await q.get()
  115. async def get_page(fname):
  116. with open(fname) as f:
  117. content = f.read()
  118. content_type, _ = guess_type(fname)
  119. return Response(content, media_type=content_type)
  120. @app.get("/")
  121. async def root_page():
  122. return await get_page("pages/index.html")