pizzicore.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import secrets
  2. from collections import defaultdict
  3. from asyncio.queues import Queue
  4. from fastapi import FastAPI, WebSocket, HTTPException, Depends, status
  5. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  6. from pydantic import BaseModel
  class Store:
      """In-memory integer counters addressed by key.

      Keys ``0 .. n-1`` are created at construction with value 0. ``get`` raises
      ``KeyError`` for a key that was never created (``set`` may add new keys).
      """

      def __init__(self, n: int):
          # One counter per key in range(n), all starting at zero.
          self.values = {i: 0 for i in range(n)}

      def get(self, key):
          """Return the value stored under *key*; raise ``KeyError`` if absent."""
          return self.values[key]

      def incr(self, key):
          """Add 1 to the counter under *key* and return the new value.

          Raises ``KeyError`` if *key* does not exist. The new value is returned
          directly instead of relying on ``set()``'s return value, so subclasses
          that override ``set()`` without returning anything still get a
          correct result.
          """
          new_value = self.get(key) + 1
          self.set(key, new_value)
          return new_value

      def set(self, key, value):
          """Store *value* under *key* and return it."""
          self.values[key] = value
          return value
  class SignalStore(Store):
      """A ``Store`` that pushes every newly set value to subscribed queues.

      Each call to ``set()`` for a key puts the stored value on every queue
      registered for that key (via ``put_nowait``). ``incr()`` goes through
      ``set()``, so increments are broadcast too.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # key -> list of queues to notify when that key is set.
          self.registry = defaultdict(list)

      def subscribe(self, key, q):
          """Register queue *q* to receive every new value of *key*."""
          self.registry[key].append(q)

      def unsubscribe(self, key, q):
          """Stop notifying queue *q* about *key*; unknown queues are ignored."""
          # .get() avoids creating an entry in the defaultdict for unknown keys.
          queues = self.registry.get(key)
          if not queues:
              return
          try:
              queues.remove(q)
          except ValueError:
              return
          if not queues:
              # Drop empty lists so keys without listeners do not accumulate.
              del self.registry[key]

      def _notify(self, key):
          """Put the current value of *key* on every queue subscribed to it."""
          value = self.get(key)
          # .get() instead of [] so setting a key nobody listens to does not
          # insert an empty list into the defaultdict.
          for queue in self.registry.get(key, ()):
              queue.put_nowait(value)

      def set(self, key, value):
          """Store *value* under *key*, notify subscribers, and return *value*.

          Returning the value keeps the contract of ``Store.set``.
          """
          super().set(key, value)
          self._notify(key)
          return value
  security = HTTPBasic()
  app = FastAPI()
  # TODO: read the number of counters from a configuration file.
  # NOTE(review): websocket_counter can only push live updates from a store
  # that has subscribe() (SignalStore). Before switching, confirm that
  # SignalStore.incr() returns the new value: its set() override may not
  # return it.
  counter_store = Store(n=1)
  class Value(BaseModel):
      """Response body pairing a counter id with the value it holds."""

      counter: int  # the counter id (the ``cid`` path parameter)
      value: int  # the value stored in that counter when the response is built
  def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
      """Authenticate an HTTP Basic request and return the caller's role.

      Returns ``"admin"`` when the credentials match. Otherwise raises
      ``HTTPException`` 401 with a ``WWW-Authenticate: Basic`` header so the
      client is asked to authenticate.
      """
      # TODO: read the expected username/password from configuration.
      # Compare UTF-8 bytes: secrets.compare_digest() raises TypeError for str
      # arguments containing non-ASCII characters, which would turn a bad
      # login into a 500. Both comparisons always run, so the response time
      # does not reveal which field was wrong.
      username_ok = secrets.compare_digest(
          credentials.username.encode("utf-8"), b"avanti"
      )
      password_ok = secrets.compare_digest(
          credentials.password.encode("utf-8"), b"prossimo"
      )
      if not (username_ok and password_ok):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Incorrect username or password",
              headers={"WWW-Authenticate": "Basic"},
          )
      return "admin"
  @app.get("/counter/{cid}")
  async def get_value(cid: int):
      """Return the current value of counter *cid*, or respond 404 if it does not exist."""
      try:
          current = counter_store.get(cid)
      except KeyError:
          raise HTTPException(detail="Item not found", status_code=status.HTTP_404_NOT_FOUND)
      else:
          return Value(counter=cid, value=current)
  @app.post("/counter/{cid}/increment")
  async def increment(cid: int, role: str = Depends(get_current_username)):
      """Add 1 to counter *cid* and return its new value.

      Requires HTTP Basic credentials (checked by ``get_current_username``).
      A role other than ``"admin"`` gets 403. An unknown *cid* gets 404
      "Item not found", the same answer ``GET /counter/{cid}`` gives, instead
      of an unhandled ``KeyError`` (500).
      """
      if role != "admin":
          raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
      try:
          val = counter_store.incr(cid)
      except KeyError:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
          ) from None
      return Value(counter=cid, value=val)
  @app.websocket("/ws/counter/{cid}")
  async def websocket_counter(websocket: WebSocket, cid: int):
      """Send the value of counter *cid* to the client as text frames.

      The current value is sent right after the handshake. If ``counter_store``
      offers ``subscribe(key, queue)`` (as ``SignalStore`` does), every later
      value is pushed as it is set, until the client disconnects. A store
      without ``subscribe`` cannot report changes, so the socket is closed
      after the first message. An unknown *cid* closes the socket with
      code 1008.
      """
      import asyncio

      await websocket.accept()
      try:
          current = counter_store.get(cid)
      except KeyError:
          await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
          return

      # Subscribe before the next await, so no update can slip in between
      # reading the current value and registering the queue.
      updates = Queue()
      subscribe = getattr(counter_store, "subscribe", None)
      if subscribe is not None:
          subscribe(cid, updates)

      async def push_updates():
          while True:
              value = await updates.get()
              await websocket.send_text(str(value))

      async def wait_for_disconnect():
          # Client messages are ignored; reading is how a disconnect is noticed.
          while True:
              message = await websocket.receive()
              if message["type"] == "websocket.disconnect":
                  return

      try:
          await websocket.send_text(str(current))
          if subscribe is None:
              await websocket.close()
              return
          tasks = [
              asyncio.create_task(push_updates()),
              asyncio.create_task(wait_for_disconnect()),
          ]
          try:
              await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
          finally:
              for task in tasks:
                  task.cancel()
              # Collect outcomes so task errors are not reported as unretrieved.
              await asyncio.gather(*tasks, return_exceptions=True)
      finally:
          # NOTE(review): a store with subscribe() but no unsubscribe() keeps
          # this queue registered after the client leaves.
          unsubscribe = getattr(counter_store, "unsubscribe", None)
          if subscribe is not None and unsubscribe is not None:
              unsubscribe(cid, updates)