# pizzicore.py (4.0 KB)
import dbm
import logging
import os
import secrets
from asyncio.queues import Queue
from collections import defaultdict

from fastapi import (
    Depends,
    FastAPI,
    HTTPException,
    WebSocket,
    status,
)
from fastapi.encoders import jsonable_encoder
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
  16. class BaseStore:
  17. def __init__(self, n: int):
  18. self.values = {i: 0 for i in range(n)}
  19. def get(self, key) -> int:
  20. return self.values[key]
  21. def incr(self, key) -> int:
  22. newval = self.get(key) + 1
  23. return self.set(key, newval)
  24. def set(self, key, value: int) -> int:
  25. self.values[key] = value
  26. return value
  27. class PersistStoreMixin:
  28. def __init__(self, *args, **kwargs):
  29. self.db_path = kwargs.pop("db_path")
  30. super().__init__(*args, **kwargs)
  31. with dbm.open(self.db_path, "c") as db:
  32. for key in db.keys():
  33. self.values[int(key)] = int(db[key])
  34. def set(self, key, value: int) -> int:
  35. ret = super().set(key, value)
  36. with dbm.open(self.db_path, "w") as db:
  37. db[str(key)] = str(value)
  38. return ret
  39. class SignalStoreMixin:
  40. """make any BaseStore manager-aware"""
  41. def __init__(self, *args, **kwargs):
  42. manager = kwargs.pop("manager")
  43. super().__init__(*args, **kwargs)
  44. self.manager = manager
  45. def set(self, key, value: int) -> int:
  46. ret = super().set(key, value)
  47. self.manager.notify(key, value)
  48. return ret
  49. class Store(SignalStoreMixin, PersistStoreMixin, BaseStore):
  50. pass
  51. class Manager:
  52. """Handle notifications logic (for websocket)."""
  53. def __init__(self):
  54. self.registry = defaultdict(list)
  55. def subscribe(self, key, q: Queue):
  56. self.registry[key].append(q)
  57. def unsubscribe(self, key, q: Queue):
  58. try:
  59. self.registry[key].remove(q)
  60. except ValueError: # not found
  61. pass
  62. def notify(self, key, val):
  63. for queue in self.registry[key]:
  64. queue.put_nowait(val)
  65. app = FastAPI()
  66. manager = Manager()
  67. counter_store = Store(
  68. n=1, db_path="/var/lib/pizzicore/pizzicore.dbm", manager=manager
  69. ) # XXX: pesca da file di conf
  70. security = HTTPBasic()
  71. class Value(BaseModel):
  72. counter: int
  73. value: int
  74. def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
  75. # XXX: read user/pass from config
  76. correct_username = secrets.compare_digest(credentials.username, "avanti")
  77. correct_password = secrets.compare_digest(credentials.password, "prossimo")
  78. if not (correct_username and correct_password):
  79. raise HTTPException(
  80. status_code=status.HTTP_401_UNAUTHORIZED,
  81. detail="Incorrect username or password",
  82. headers={"WWW-Authenticate": "Basic"},
  83. )
  84. return "admin"
  85. @app.get("/v1/counter/{cid}")
  86. async def get_value(cid: int):
  87. try:
  88. val = counter_store.get(cid)
  89. except KeyError:
  90. raise HTTPException(status_code=404, detail="Item not found")
  91. else:
  92. return Value(counter=cid, value=val)
  93. @app.post("/v1/counter/{cid}/increment")
  94. async def increment(cid: int, role: str = Depends(get_current_role)):
  95. if role != "admin":
  96. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  97. val = counter_store.incr(cid)
  98. return Value(counter=cid, value=val)
  99. @app.websocket("/v1/ws/counter/{cid}")
  100. async def websocket_counter(websocket: WebSocket, cid: int):
  101. await websocket.accept()
  102. q: Queue = Queue()
  103. manager.subscribe(cid, q)
  104. while True:
  105. try:
  106. val = counter_store.get(cid)
  107. await websocket.send_json(jsonable_encoder(Value(counter=cid, value=val)))
  108. except WebSocketDisconnect:
  109. logging.debug("client disconnected")
  110. manager.unsubscribe(cid, q)
  111. return
  112. except Exception:
  113. logging.exception("unexpected error")
  114. manager.unsubscribe(cid, q)
  115. return
  116. await q.get()