seriow/webserver.py

123 lines
3.3 KiB
Python
Raw Normal View History

2023-07-01 18:12:38 +02:00
#!/usr/bin/env python3
2023-07-01 14:23:33 +02:00
import secrets
from uuid import UUID, uuid4
import time
from typing import Optional, Callable, Any
from pydantic import BaseModel, BaseSettings, Field
from fastapi import FastAPI, Depends, HTTPException, status, Response
2023-07-01 14:23:33 +02:00
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.middleware.cors import CORSMiddleware
class VariableModel(BaseModel):
    # A single named integer variable (name/value pair).
    # NOTE(review): not referenced by any endpoint in this file — confirm
    # whether it is still needed.
    key: str
    value: int
class AllVariablesModel(BaseModel):
    # Body accepted by POST /variables: a mapping of variable name -> integer
    # value, merged into the stored variables by update_all_variables.
    variables: dict[str, int]
class MessageModel(BaseModel):
    # A message posted to POST /messages and served back by GET /messages.
    message: str
    # Defaults to 0 when omitted; the meaning of level values is not defined
    # in this file.
    level: int = 0
    # Freshly generated UUID unless the client supplies one; GET /messages
    # uses it as the `from_id` cursor.
    id: UUID = Field(default_factory=uuid4)
    # Unix time in seconds (time.time()) at model creation, unless supplied.
    timestamp: float = Field(default_factory=time.time)
class Settings(BaseSettings):
    """Process-wide configuration plus the server's in-memory state.

    As a pydantic ``BaseSettings``, field values may be overridden from the
    environment or from the ``pizzicore.env`` file configured in ``Config``.
    """

    app_name: str = "Squeow"
    # Password expected for the "serial" HTTP Basic user.
    serial_password: str = "hackme"
    # Latest value of every variable posted to POST /variables.
    variables: dict[str, int] = {}
    # Upper bound on the number of entries kept in `messages`.
    messages_length: int = 10
    # Message buffer, oldest first.
    messages: list[MessageModel] = []

    class Config:
        env_file = "pizzicore.env"

    def push_message(self, message):
        """Append *message* and drop the oldest entries beyond ``messages_length``.

        All excess entries are removed, not just one. The buffer therefore
        returns to its limit even if it started larger (e.g. ``messages``
        seeded from the environment) or ``messages_length`` was lowered at
        runtime. A negative limit behaves like 0.

        Args:
            message: The message to store.
        """
        self.messages.append(message)
        overflow = len(self.messages) - max(self.messages_length, 0)
        if overflow > 0:
            # Oldest entries are at the front of the list.
            del self.messages[:overflow]
# The ASGI application and the single settings instance that holds all
# endpoint state (variables and the message buffer).
app = FastAPI()
settings = Settings()
# Extracts HTTP Basic credentials for get_current_role.
security = HTTPBasic()

# Let browser clients on any origin call the API with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for a password-protected API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
    """Authenticate an HTTP Basic request and return the caller's role.

    The only account is the username ``"serial"`` with the password from
    ``settings.serial_password``.

    Returns:
        The role name ``"serial"`` when the credentials match.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when
            the username or password is wrong.
    """
    # Compare UTF-8 bytes rather than str: secrets.compare_digest only accepts
    # ASCII-only str arguments and raises TypeError otherwise. That would turn
    # non-ASCII credentials (from the client or the config) into a 500.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"), b"serial"
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        settings.serial_password.encode("utf-8"),
    )
    # Both comparisons run before branching, so timing does not reveal which
    # of the two was wrong.
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "serial"
@app.post("/variables")
async def update_all_variables(
    variables: AllVariablesModel, role: str = Depends(get_current_role)
) -> None:
    """Merge the posted variables into the stored set (serial role only).

    Posted keys overwrite existing values; keys not posted are kept.
    """
    # Defensive role check: get_current_role returns "serial" or raises.
    if role == "serial":
        settings.variables.update(variables.variables)
        return None
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@app.post("/messages")
async def push_message(
    message: MessageModel, role: str = Depends(get_current_role)
) -> None:
    """Store a posted message in the bounded buffer (serial role only)."""
    # Defensive role check: get_current_role returns "serial" or raises.
    if role == "serial":
        # Settings.push_message trims the oldest entries beyond messages_length.
        settings.push_message(message)
        return None
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@app.get("/variables")
async def get_all_variables() -> AllVariablesModel:
    """Return every stored variable as ``{"variables": {name: value, ...}}``.

    The payload has the same shape as the body accepted by POST /variables.
    """
    # Wrap the stored dict in the declared model. Returning the bare dict does
    # not match the AllVariablesModel return annotation (the dict has no
    # "variables" key), so FastAPI versions that validate responses against
    # the return annotation reject it with a server error.
    return AllVariablesModel(variables=settings.variables)
def first_matching(lst: list, condition: Callable[[Any], bool]) -> Optional[int]:
    """Return the index of the first element of *lst* satisfying *condition*.

    Args:
        lst: Sequence scanned front to back.
        condition: Predicate applied to each element until it returns a
            truthy value; later elements are not evaluated.

    Returns:
        The index of the first matching element, or ``None`` when no element
        matches (including when *lst* is empty).
    """
    return next((i for i, elem in enumerate(lst) if condition(elem)), None)
@app.get("/messages")
async def get_all_messages(from_id: Optional[UUID] = None) -> list[MessageModel]:
    """Return buffered messages, optionally starting at the message *from_id*.

    Without *from_id* the whole buffer is returned. With it, the result starts
    at (and includes) the message carrying that id.
    """
    buffered = settings.messages
    if from_id is None:
        return buffered
    start = first_matching(buffered, lambda msg: msg.id == from_id)
    # An unknown id is assumed to belong to a message already evicted from the
    # buffer, so every buffered message is still relevant.
    if start is None:
        return buffered
    return buffered[start:]
@app.get("/metrics")
async def export_prometheus() -> Response:
    """Expose the stored variables in the Prometheus text exposition format.

    Each variable becomes one sample line ``<name> <value>``.
    """
    # The exposition format separates the metric name from its value with
    # whitespace; "name=value" is not valid exposition syntax.
    # NOTE(review): keys are emitted verbatim, so they must already be valid
    # metric names ([a-zA-Z_:][a-zA-Z0-9_:]*) — confirm at the producer.
    text = "".join(
        f"{key} {value}\n" for key, value in settings.variables.items()
    )
    # version=0.0.4 identifies the text exposition format to scrapers.
    return Response(text, media_type="text/plain; version=0.0.4")