websocket works
This commit is contained in:
parent
3fa81f0ae3
commit
7c9bd6a700
2 changed files with 59 additions and 24 deletions
|
@ -2,45 +2,75 @@ import secrets
|
|||
from collections import defaultdict
|
||||
|
||||
from asyncio.queues import Queue
|
||||
from fastapi import FastAPI, WebSocket, HTTPException, Depends, status
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
WebSocket,
|
||||
HTTPException,
|
||||
Depends,
|
||||
status,
|
||||
WebSocketDisconnect,
|
||||
)
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Store:
|
||||
class BaseStore:
|
||||
def __init__(self, n: int):
|
||||
self.values = {i: 0 for i in range(n)}
|
||||
|
||||
def get(self, key):
|
||||
def get(self, key) -> int:
|
||||
return self.values[key]
|
||||
|
||||
def incr(self, key):
|
||||
return self.set(key, self.get(key) + 1)
|
||||
def incr(self, key) -> int:
|
||||
newval = self.get(key) + 1
|
||||
return self.set(key, newval)
|
||||
|
||||
def set(self, key, value):
|
||||
def set(self, key, value: int) -> int:
|
||||
self.values[key] = value
|
||||
return value
|
||||
|
||||
|
||||
class SignalStore(Store):
|
||||
class SignalStore:
|
||||
"""make any BaseStore manager-aware"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
manager = kwargs.pop("manager")
|
||||
super().__init__(*args, **kwargs)
|
||||
self.manager = manager
|
||||
|
||||
def set(self, key, value: int) -> int:
|
||||
ret = super().set(key, value)
|
||||
self.manager.notify(key, value)
|
||||
return ret
|
||||
|
||||
|
||||
class Store(SignalStore, BaseStore):
|
||||
pass
|
||||
|
||||
|
||||
class Manager:
|
||||
"""Handle notifications logic (for websocket)."""
|
||||
|
||||
def __init__(self):
|
||||
self.registry = defaultdict(list)
|
||||
|
||||
def subscribe(self, key, q):
|
||||
def subscribe(self, key, q: Queue):
|
||||
self.registry[key].append(q)
|
||||
|
||||
def _notify(self, key):
|
||||
for queue in self.registry[key]:
|
||||
queue.put_nowait(self.get(key))
|
||||
def unsubscribe(self, key, q: Queue):
|
||||
try:
|
||||
self.registry[key].remove(q)
|
||||
except ValueError: # not found
|
||||
pass
|
||||
|
||||
def set(self, key, value):
|
||||
super().set(key, value)
|
||||
self._notify(key)
|
||||
def notify(self, key, val):
|
||||
for queue in self.registry[key]:
|
||||
queue.put_nowait(val)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
counter_store = Store(n=1) # XXX: pesca da file di conf
|
||||
manager = Manager()
|
||||
counter_store = Store(n=1, manager=manager) # XXX: pesca da file di conf
|
||||
security = HTTPBasic()
|
||||
|
||||
|
||||
|
@ -49,7 +79,7 @@ class Value(BaseModel):
|
|||
value: int
|
||||
|
||||
|
||||
def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
|
||||
def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
|
||||
# XXX: read user/pass from config
|
||||
correct_username = secrets.compare_digest(credentials.username, "avanti")
|
||||
correct_password = secrets.compare_digest(credentials.password, "prossimo")
|
||||
|
@ -68,11 +98,12 @@ async def get_value(cid: int):
|
|||
val = counter_store.get(cid)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
return Value(counter=cid, value=val)
|
||||
else:
|
||||
return Value(counter=cid, value=val)
|
||||
|
||||
|
||||
@app.post("/counter/{cid}/increment")
|
||||
async def increment(cid: int, role: str = Depends(get_current_username)):
|
||||
async def increment(cid: int, role: str = Depends(get_current_role)):
|
||||
if role != "admin":
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
val = counter_store.incr(cid)
|
||||
|
@ -82,8 +113,13 @@ async def increment(cid: int, role: str = Depends(get_current_username)):
|
|||
@app.websocket("/ws/counter/{cid}")
|
||||
async def websocket_counter(websocket: WebSocket, cid: int):
|
||||
await websocket.accept()
|
||||
# XXX: subscribe to counter
|
||||
q: Queue = Queue()
|
||||
manager.subscribe(cid, q)
|
||||
|
||||
while True:
|
||||
# XXX: get notifications
|
||||
val = 1
|
||||
await websocket.send_text(str(val))
|
||||
try:
|
||||
await websocket.send_text(str(counter_store.get(cid)))
|
||||
except:
|
||||
manager.unsubscribe(cid, q)
|
||||
return
|
||||
await q.get()
|
||||
|
|
|
@ -1,2 +1 @@
|
|||
fastapi==0.62.0
|
||||
uvicorn==0.13.1
|
||||
fastapi[all]==0.62.0
|
||||
|
|
Loading…
Reference in a new issue