webserver.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #!/usr/bin/env python3
  2. import secrets
  3. from uuid import UUID, uuid4
  4. import time
  5. from typing import Optional, Callable, Any
  6. import datetime
  7. from pydantic import BaseModel, BaseSettings, Field
  8. from fastapi import FastAPI, Depends, HTTPException, status, Response
  9. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  10. from fastapi.middleware.cors import CORSMiddleware
  11. class VariableModel(BaseModel):
  12. key: str
  13. value: int
  14. class AllVariablesModel(BaseModel):
  15. variables: dict[str, int]
  16. class MessageModel(BaseModel):
  17. message: str
  18. level: int = 0
  19. id: UUID = Field(default_factory=uuid4)
  20. timestamp: float = Field(default_factory=time.time)
  21. class Settings(BaseSettings):
  22. app_name: str = "Squeow"
  23. serial_password: str = "hackme"
  24. variables: dict[str, int] = {}
  25. messages_length: int = 10
  26. messages: list[MessageModel] = []
  27. last_message: datetime.datetime = datetime.datetime.now()
  28. class Config:
  29. env_file = "pizzicore.env"
  30. def push_message(self, message):
  31. self.messages.append(message)
  32. if len(self.messages) > self.messages_length:
  33. self.messages.pop(0)
  34. def update_last_message(self):
  35. self.last_message = datetime.datetime.now()
  36. app = FastAPI()
  37. settings = Settings()
  38. security = HTTPBasic()
  39. app.add_middleware(
  40. CORSMiddleware,
  41. allow_origins=["*"],
  42. allow_credentials=True,
  43. allow_methods=["*"],
  44. allow_headers=["*"],
  45. )
  46. def get_current_role(credentials: HTTPBasicCredentials = Depends(security)):
  47. correct_username = secrets.compare_digest(credentials.username, "serial")
  48. correct_password = secrets.compare_digest(
  49. credentials.password, settings.serial_password
  50. )
  51. if not (correct_username and correct_password):
  52. raise HTTPException(
  53. status_code=status.HTTP_401_UNAUTHORIZED,
  54. detail="Incorrect username or password",
  55. headers={"WWW-Authenticate": "Basic"},
  56. )
  57. return "serial"
  58. @app.post("/variables")
  59. async def update_all_variables(
  60. variables: AllVariablesModel, role: str = Depends(get_current_role)
  61. ) -> None:
  62. if role != "serial":
  63. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  64. settings.update_last_message()
  65. settings.variables.update(variables.variables)
  66. return
  67. @app.post("/messages")
  68. async def push_message(
  69. message: MessageModel, role: str = Depends(get_current_role)
  70. ) -> None:
  71. if role != "serial":
  72. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  73. settings.update_last_message()
  74. settings.push_message(message)
  75. return
  76. @app.get("/variables")
  77. async def get_all_variables() -> AllVariablesModel:
  78. return settings.variables
  79. @app.get("/variables/{key}")
  80. async def get_variable(key: str) -> VariableModel:
  81. try:
  82. value = settings.variables[key]
  83. except KeyError:
  84. raise HTTPException(status_code=404, detail="Variable not found")
  85. return Response(str(value), media_type="text/plain")
  86. def first_matching(lst: list, condition: Callable[[Any], bool]) -> int:
  87. """return the index of the first item that matches condition"""
  88. for i, elem in enumerate(lst):
  89. if condition(elem):
  90. return i
  91. return None
  92. @app.get("/messages")
  93. async def get_all_messages(from_id: Optional[UUID] = None) -> list[MessageModel]:
  94. messages = settings.messages
  95. if from_id is not None:
  96. match = first_matching(messages, lambda x: x.id == from_id)
  97. # if match is not found, we assume that the referred id is very old, so all messages are relevant
  98. if match is not None:
  99. messages = messages[match:]
  100. return messages
  101. @app.get("/metrics")
  102. async def export_prometheus() -> str:
  103. variables: list[tuple[str, int]] = [
  104. (f"squeow_var_{key}", value) for key, value in settings.variables.items()
  105. ]
  106. variables.append(("squeow_variables_count", len(settings.variables)))
  107. time_since_last_seen = (
  108. datetime.datetime.now() - settings.last_message
  109. ).total_seconds()
  110. variables.append(("squeow_time_since_last_seen", int(time_since_last_seen)))
  111. text = "".join(f"{k} {v}\n" for k, v in variables)
  112. return Response(text, media_type="text/plain")