|
@@ -0,0 +1,89 @@
|
|
|
+import secrets
|
|
|
+from collections import defaultdict
|
|
|
+
|
|
|
+from asyncio.queues import Queue
|
|
|
+from fastapi import FastAPI, WebSocket, HTTPException, Depends, status
|
|
|
+from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|
|
+from pydantic import BaseModel
|
|
|
+
|
|
|
+
|
|
|
+class Store:
|
|
|
+ def __init__(self, n: int):
|
|
|
+ self.values = {i: 0 for i in range(n)}
|
|
|
+
|
|
|
+ def get(self, key):
|
|
|
+ return self.values[key]
|
|
|
+
|
|
|
+ def incr(self, key):
|
|
|
+ return self.set(key, self.get(key) + 1)
|
|
|
+
|
|
|
+ def set(self, key, value):
|
|
|
+ self.values[key] = value
|
|
|
+ return value
|
|
|
+
|
|
|
+
|
|
|
+class SignalStore(Store):
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+ self.registry = defaultdict(list)
|
|
|
+
|
|
|
+ def subscribe(self, key, q):
|
|
|
+ self.registry[key].append(q)
|
|
|
+
|
|
|
+ def _notify(self, key):
|
|
|
+ for queue in self.registry[key]:
|
|
|
+ queue.put_nowait(self.get(key))
|
|
|
+
|
|
|
+ def set(self, key, value):
|
|
|
+ super().set(key, value)
|
|
|
+ self._notify(key)
|
|
|
+
|
|
|
+
|
|
|
+app = FastAPI()
|
|
|
+counter_store = Store(n=1) # XXX: pesca da file di conf
|
|
|
+security = HTTPBasic()
|
|
|
+
|
|
|
+
|
|
|
+class Value(BaseModel):
|
|
|
+ counter: int
|
|
|
+ value: int
|
|
|
+
|
|
|
+
|
|
|
+def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
|
|
|
+ # XXX: read user/pass from config
|
|
|
+ correct_username = secrets.compare_digest(credentials.username, "avanti")
|
|
|
+ correct_password = secrets.compare_digest(credentials.password, "prossimo")
|
|
|
+ if not (correct_username and correct_password):
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
+ detail="Incorrect username or password",
|
|
|
+ headers={"WWW-Authenticate": "Basic"},
|
|
|
+ )
|
|
|
+ return "admin"
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/counter/{cid}")
|
|
|
+async def get_value(cid: int):
|
|
|
+ try:
|
|
|
+ val = counter_store.get(cid)
|
|
|
+ except KeyError:
|
|
|
+ raise HTTPException(status_code=404, detail="Item not found")
|
|
|
+ return Value(counter=cid, value=val)
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/counter/{cid}/increment")
|
|
|
+async def increment(cid: int, role: str = Depends(get_current_username)):
|
|
|
+ if role != "admin":
|
|
|
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
|
|
+ val = counter_store.incr(cid)
|
|
|
+ return Value(counter=cid, value=val)
|
|
|
+
|
|
|
+
|
|
|
+@app.websocket("/ws/counter/{cid}")
|
|
|
+async def websocket_counter(websocket: WebSocket, cid: int):
|
|
|
+ await websocket.accept()
|
|
|
+ # XXX: subscribe to counter
|
|
|
+ while True:
|
|
|
+ # XXX: get notifications
|
|
|
+ val = 1
|
|
|
+ await websocket.send_text(str(val))
|