webserver.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import secrets
  2. from uuid import UUID, uuid4
  3. import time
  4. from typing import Optional, Callable, Any
  5. from pydantic import BaseModel, BaseSettings, Field
  6. from fastapi import FastAPI, Depends, HTTPException, status
  7. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  8. from fastapi.middleware.cors import CORSMiddleware
  class VariableModel(BaseModel):
      """A single named variable: a string ``key`` paired with an integer ``value``."""

      # Name of the variable.
      key: str
      # Integer value associated with ``key``.
      value: int
  class AllVariablesModel(BaseModel):
      """Request body for ``POST /variables``: variable names mapped to integer values."""

      # Every entry is merged into the server-side variable store.
      variables: dict[str, int]
  class MessageModel(BaseModel):
      """One entry of the bounded in-memory message history.

      ``id`` and ``timestamp`` are generated automatically (a random UUID4 and
      the current ``time.time()`` value) when the client does not supply them.
      """

      # Free-form message text.
      message: str
      # Defaults to 0; what the levels mean is not defined in this file.
      level: int = 0
      # Used by GET /messages (``from_id``) to locate a position in the history.
      id: UUID = Field(default_factory=uuid4)
      # Seconds since the epoch, as returned by time.time().
      timestamp: float = Field(default_factory=time.time)
  class Settings(BaseSettings):
      """Application configuration that also holds the server's in-memory state.

      Fields can be supplied by pydantic's settings sources, including the
      ``pizzicore.env`` file named in ``Config``. ``variables`` and
      ``messages`` are additionally mutated at runtime by the request handlers.
      """

      app_name: str = "Squeow"
      # NOTE(review): insecure built-in default; override it through the
      # environment / env file in any real deployment.
      serial_password: str = "hackme"
      variables: dict[str, int] = {}
      # Upper bound on how many messages push_message keeps.
      messages_length: int = 10
      messages: list[MessageModel] = []

      class Config:
          env_file = "pizzicore.env"

      def push_message(self, message):
          """Append *message*; drop the oldest entry once the history exceeds
          ``messages_length``."""
          history = self.messages
          history.append(message)
          over_limit = len(history) > self.messages_length
          if over_limit:
              del history[0]
  # Application object, shared settings/state, and the HTTP Basic auth scheme
  # used by get_current_role.
  app = FastAPI()
  settings = Settings()
  security = HTTPBasic()

  # Accept cross-origin requests from any origin, with any method and header.
  # NOTE(review): wildcard origins combined with allow_credentials=True is a
  # very permissive CORS policy (Starlette may echo the request Origin back in
  # that case); confirm this is intended or restrict allow_origins.
  app.add_middleware(
      CORSMiddleware,
      allow_credentials=True,
      allow_origins=["*"],
      allow_headers=["*"],
      allow_methods=["*"],
  )
  def get_current_role(credentials: HTTPBasicCredentials = Depends(security)) -> str:
      """Authenticate an HTTP Basic request and return the caller's role.

      The only accepted identity is the user ``serial`` with the password
      ``settings.serial_password``; on success the role ``"serial"`` is
      returned.

      Raises:
          HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
              username or password does not match.
      """
      # compare_digest only accepts str arguments that are pure ASCII and
      # raises TypeError otherwise (which would surface as HTTP 500). Comparing
      # UTF-8 bytes makes non-ASCII input a plain mismatch (401) instead.
      correct_username = secrets.compare_digest(
          credentials.username.encode("utf8"), b"serial"
      )
      correct_password = secrets.compare_digest(
          credentials.password.encode("utf8"),
          settings.serial_password.encode("utf8"),
      )
      # Both comparisons run unconditionally before being combined, so the
      # response time does not reveal which of the two was wrong.
      if not (correct_username and correct_password):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Incorrect username or password",
              headers={"WWW-Authenticate": "Basic"},
          )
      return "serial"
  @app.post("/variables")
  async def update_all_variables(
      variables: AllVariablesModel, role: str = Depends(get_current_role)
  ) -> None:
      """Merge the posted name->value pairs into the stored variables.

      Requires the ``serial`` role; any other role is rejected with 403.
      """
      if role == "serial":
          store = settings.variables
          store.update(variables.variables)
          return None
      # Defensive: get_current_role currently only yields "serial" or raises 401.
      raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  @app.post("/messages")
  async def push_message(
      message: MessageModel, role: str = Depends(get_current_role)
  ) -> None:
      """Append *message* to the bounded in-memory history.

      Requires the ``serial`` role; any other role is rejected with 403.
      """
      if role == "serial":
          settings.push_message(message)
          return None
      # Defensive: get_current_role currently only yields "serial" or raises 401.
      raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
  @app.get("/variables")
  async def get_all_variables() -> dict[str, int]:
      """Return every stored variable as a flat ``{name: value}`` mapping."""
      # The annotation must describe the bare dict actually returned: FastAPI
      # (>= 0.89) uses the return annotation as the response model, and the
      # former ``AllVariablesModel`` annotation ({"variables": {...}}) never
      # matched this payload, so response validation failed with HTTP 500.
      return settings.variables
  def first_matching(lst: list, condition: Callable[[Any], bool]) -> Optional[int]:
      """Return the index of the first item in *lst* for which *condition* is true.

      *condition* is evaluated left to right and stops at the first match.

      Returns:
          The index of the first matching item, or ``None`` when no item
          matches (including when *lst* is empty).
      """
      return next((i for i, elem in enumerate(lst) if condition(elem)), None)
  @app.get("/messages")
  async def get_all_messages(from_id: Optional[UUID] = None) -> list[MessageModel]:
      """Return the retained messages, optionally starting at a given message.

      With ``from_id`` set, the result begins at the message with that id (the
      message itself is included). If that id is no longer retained, every
      retained message is returned.
      """
      history = settings.messages
      if from_id is None:
          return history
      start = first_matching(history, lambda msg: msg.id == from_id)
      # An unknown id is assumed to be older than anything still retained, so
      # the whole history is relevant to the caller.
      return history if start is None else history[start:]