webserver.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. #!/usr/bin/env python3
  2. import secrets
  3. from uuid import UUID, uuid4
  4. import time
  5. from typing import Optional, Callable, Any
  6. import datetime
  7. from pydantic import BaseModel, BaseSettings, Field
  8. from fastapi import FastAPI, Depends, HTTPException, status, Response, Request
  9. from fastapi.responses import HTMLResponse
  10. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  11. from fastapi.middleware.cors import CORSMiddleware
  12. from fastapi.templating import Jinja2Templates
  13. class VariableModel(BaseModel):
  14. key: str
  15. value: int
  16. class AllVariablesModel(BaseModel):
  17. variables: dict[str, int]
  18. class MessageModel(BaseModel):
  19. message: str
  20. level: int = 0
  21. id: UUID = Field(default_factory=uuid4)
  22. timestamp: float = Field(default_factory=time.time)
  23. class Settings(BaseSettings):
  24. app_name: str = "Squeow"
  25. serial_password: str = "hackme"
  26. variables: dict[str, int] = {}
  27. messages_length: int = 10
  28. messages: list[MessageModel] = []
  29. last_message: datetime.datetime = datetime.datetime.now()
  30. class Config:
  31. env_file = "pizzicore.env"
  32. def push_message(self, message):
  33. self.messages.append(message)
  34. if len(self.messages) > self.messages_length:
  35. self.messages.pop(0)
  36. def update_last_message(self):
  37. self.last_message = datetime.datetime.now()
  38. app = FastAPI()
  39. settings = Settings()
  40. security = HTTPBasic()
  41. templates = Jinja2Templates(directory="templates")
  42. app.add_middleware(
  43. CORSMiddleware,
  44. allow_origins=["*"],
  45. allow_credentials=True,
  46. allow_methods=["*"],
  47. allow_headers=["*"],
  48. )
  49. def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
  50. correct_username = secrets.compare_digest(credentials.username, "serial")
  51. correct_password = secrets.compare_digest(
  52. credentials.password, settings.serial_password
  53. )
  54. if not (correct_username and correct_password):
  55. raise HTTPException(
  56. status_code=status.HTTP_401_UNAUTHORIZED,
  57. detail="Incorrect username or password",
  58. headers={"WWW-Authenticate": "Basic"},
  59. )
  60. return "serial"
  61. @app.post("/variables")
  62. async def update_all_variables(
  63. variables: AllVariablesModel, role: str = Depends(get_current_role)
  64. ) -> None:
  65. if role != "serial":
  66. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  67. settings.update_last_message()
  68. settings.variables.update(variables.variables)
  69. return
  70. @app.post("/messages")
  71. async def push_message(
  72. message: MessageModel, role: str = Depends(get_current_role)
  73. ) -> None:
  74. if role != "serial":
  75. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  76. settings.update_last_message()
  77. settings.push_message(message)
  78. return
  79. @app.get("/variables")
  80. async def get_all_variables() -> AllVariablesModel:
  81. return AllVariablesModel(
  82. variables=settings.variables
  83. )
  84. @app.get("/variables/{key}")
  85. async def get_variable(key: str) -> VariableModel:
  86. try:
  87. value = settings.variables[key]
  88. except KeyError:
  89. raise HTTPException(status_code=404, detail="Variable not found")
  90. return Response(str(value), media_type="text/plain")
  91. def first_matching(lst: list, condition: Callable[[Any], bool]) -> int:
  92. """return the index of the first item that matches condition"""
  93. for i, elem in enumerate(lst):
  94. if condition(elem):
  95. return i
  96. return None
  97. @app.get("/messages")
  98. async def get_all_messages(from_id: Optional[UUID] = None) -> list[MessageModel]:
  99. messages = settings.messages
  100. if from_id is not None:
  101. match = first_matching(messages, lambda x: x.id == from_id)
  102. # if match is not found, we assume that the referred id is very old, so all messages are relevant
  103. if match is not None:
  104. messages = messages[match:]
  105. return messages
  106. @app.get("/metrics")
  107. async def export_prometheus() -> str:
  108. variables: list[tuple[str, int]] = [
  109. (f"squeow_var_{key}", value) for key, value in settings.variables.items()
  110. ]
  111. variables.append(("squeow_variables_count", len(settings.variables)))
  112. time_since_last_seen = (
  113. datetime.datetime.now() - settings.last_message
  114. ).total_seconds()
  115. variables.append(("squeow_time_since_last_seen", int(time_since_last_seen)))
  116. text = "".join(f"{k} {v}\n" for k, v in variables)
  117. return Response(text, media_type="text/plain")
  118. @app.get("/", response_class=HTMLResponse)
  119. async def html_index(request: Request):
  120. autorefresh = request.query_params.get('refresh') == '1'
  121. return templates.TemplateResponse("index.html",
  122. dict(request=request,
  123. autorefresh=autorefresh,
  124. variables=settings.variables))