123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- #!/usr/bin/env python3
- import secrets
- from uuid import UUID, uuid4
- import time
- from typing import Optional, Callable, Any
- import datetime
- from pydantic import BaseModel, BaseSettings, Field
- from fastapi import FastAPI, Depends, HTTPException, status, Response, Request
- from fastapi.responses import HTMLResponse
- from fastapi.security import HTTPBasic, HTTPBasicCredentials
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.templating import Jinja2Templates
class VariableModel(BaseModel):
    """A single named integer variable (a key/value pair)."""

    key: str  # variable name
    value: int  # integer value of the variable
class AllVariablesModel(BaseModel):
    """A mapping of variable names to integer values.

    Used as the request body of ``POST /variables`` and as the response of
    ``GET /variables``.
    """

    variables: dict[str, int]
class MessageModel(BaseModel):
    """A text message with an integer level, a unique id and a creation time."""

    message: str
    # NOTE(review): the meaning of the level values is not defined here — confirm with the sender.
    level: int = 0
    # A fresh random UUID is generated per instance unless the client supplies one.
    id: UUID = Field(default_factory=uuid4)
    # Creation time as returned by time.time() (seconds since the epoch), unless supplied.
    timestamp: float = Field(default_factory=time.time)
class Settings(BaseSettings):
    """Application configuration plus the in-memory state shared by the endpoints.

    Field values can be overridden from the environment / ``pizzicore.env``
    (see ``Config.env_file``). The same instance also holds runtime state:
    the latest variables, a bounded message history, and the time data was
    last pushed by the client.
    """

    app_name: str = "Squeow"
    # Password for the HTTP Basic user "serial"; override it in the env file.
    serial_password: str = "hackme"
    variables: dict[str, int] = {}
    # Maximum number of messages retained in ``messages``.
    messages_length: int = 10
    messages: list[MessageModel] = []
    # Evaluated once when the class is defined, so until the first update this
    # is (approximately) the process start-up time.
    last_message: datetime.datetime = datetime.datetime.now()

    class Config:
        env_file = "pizzicore.env"

    def push_message(self, message):
        """Append *message*, then drop the oldest entries beyond ``messages_length``.

        All excess entries are removed, not just one. This keeps the history
        bounded even if ``messages`` started out longer than the limit (e.g.
        loaded from the environment) or the limit was lowered at runtime.
        """
        self.messages.append(message)
        excess = len(self.messages) - self.messages_length
        if excess > 0:
            # Oldest messages sit at the front of the list.
            del self.messages[:excess]

    def update_last_message(self):
        """Record the current local time as the moment the client was last seen."""
        self.last_message = datetime.datetime.now()
# Module-level application objects shared by the route handlers.
app = FastAPI()
# Instantiated once at import time; BaseSettings reads the env file named in Settings.Config.
settings = Settings()
security = HTTPBasic()  # HTTP Basic credentials extractor, used by get_current_role
templates = Jinja2Templates(directory="templates")  # holds the index.html template

# NOTE(review): wildcard origins combined with allow_credentials=True may let any
# web origin make credentialed cross-origin requests to this API; consider
# restricting allow_origins to the known front-end hosts.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
    """Authenticate the request with HTTP Basic auth and return the caller's role.

    The only account is user ``serial`` with password ``settings.serial_password``.

    Returns:
        The role string ``"serial"``.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Basic`` header) when the
            username or password does not match.
    """
    # compare_digest only accepts ASCII-only str. Encoding to bytes first means a
    # non-ASCII username or password gets a 401 instead of raising TypeError (a 500).
    # Both comparisons always run, so response timing does not reveal which one failed.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"), b"serial"
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        settings.serial_password.encode("utf-8"),
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return "serial"
@app.post("/variables")
async def update_all_variables(
    variables: AllVariablesModel, role: str = Depends(get_current_role)
) -> None:
    """Merge the posted variables into the shared variable store.

    Requires the ``serial`` role. Also refreshes the last-seen timestamp.
    Keys that are absent from the request keep their previous values.
    """
    if role == "serial":
        settings.update_last_message()
        settings.variables.update(variables.variables)
        return None
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@app.post("/messages")
async def push_message(
    message: MessageModel, role: str = Depends(get_current_role)
) -> None:
    """Append a message to the bounded message history.

    Requires the ``serial`` role. Also refreshes the last-seen timestamp.
    """
    authorized = role == "serial"
    if not authorized:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    settings.update_last_message()
    settings.push_message(message)
@app.get("/variables")
async def get_all_variables() -> AllVariablesModel:
    """Return every known variable as ``{"variables": {name: value, ...}}``."""
    current = settings.variables
    return AllVariablesModel(variables=current)
@app.get("/variables/{key}")
async def get_variable(key: str) -> Response:
    """Return the value of variable *key* as a plain-text integer (e.g. ``42``).

    Raises:
        HTTPException: 404 if no variable named *key* is known.
    """
    # The endpoint answers with text/plain, so the return annotation is Response.
    # Annotating it as VariableModel made FastAPI advertise a JSON VariableModel
    # schema in OpenAPI that this endpoint never returns.
    try:
        value = settings.variables[key]
    except KeyError:
        raise HTTPException(status_code=404, detail="Variable not found") from None
    return Response(str(value), media_type="text/plain")
def first_matching(lst: list, condition: Callable[[Any], bool]) -> Optional[int]:
    """Return the index of the first item in *lst* for which *condition* is true.

    Args:
        lst: Items to scan, in order.
        condition: Predicate called on each item until it returns a truthy value.

    Returns:
        The 0-based index of the first matching item, or ``None`` when no item
        matches (including when *lst* is empty).
    """
    return next((i for i, elem in enumerate(lst) if condition(elem)), None)
@app.get("/messages")
async def get_all_messages(from_id: Optional[UUID] = None) -> list[MessageModel]:
    """Return the retained message history, optionally starting at *from_id*.

    When *from_id* matches a retained message, the result starts at (and
    includes) that message. When it matches nothing, the whole history is
    returned.
    """
    history = settings.messages
    if from_id is None:
        return history
    start = first_matching(history, lambda msg: msg.id == from_id)
    # Unknown id: assumed to be older than anything still retained, so every
    # message is relevant.
    return history if start is None else history[start:]
@app.get("/metrics")
async def export_prometheus() -> Response:
    """Expose the current state as plain-text Prometheus metrics.

    Emits one ``squeow_var_<key> <value>`` line per variable, then
    ``squeow_variables_count`` and ``squeow_time_since_last_seen``. The latter
    is the whole number of seconds since the last authenticated POST (or
    since start-up if there has been none).
    """
    import re

    # Prometheus metric names must match [a-zA-Z_:][a-zA-Z0-9_:]*. Variable keys
    # come from the client and may contain other characters (even newlines),
    # which would make the whole scrape unparseable, so each invalid character
    # is replaced with "_". The "squeow_var_" prefix guarantees a valid first character.
    invalid_chars = re.compile(r"[^a-zA-Z0-9_:]")

    variables: list[tuple[str, int]] = []
    for key, value in settings.variables.items():
        safe_key = invalid_chars.sub("_", key)
        variables.append((f"squeow_var_{safe_key}", value))
    variables.append(("squeow_variables_count", len(settings.variables)))

    time_since_last_seen = (
        datetime.datetime.now() - settings.last_message
    ).total_seconds()
    variables.append(("squeow_time_since_last_seen", int(time_since_last_seen)))

    text = "".join(f"{k} {v}\n" for k, v in variables)
    # Returning a Response (hence the annotation) bypasses FastAPI's JSON encoding.
    return Response(text, media_type="text/plain")
@app.get("/", response_class=HTMLResponse)
async def html_index(request: Request):
    """Render ``index.html`` with the current variables.

    The query parameter ``?refresh=1`` sets the template's ``autorefresh`` flag.
    """
    context = {
        "request": request,
        "autorefresh": request.query_params.get("refresh") == "1",
        "variables": settings.variables,
    }
    return templates.TemplateResponse("index.html", context)
|