diff --git a/tresetter.py b/tresetter.py index 11b0d84..2958bae 100644 --- a/tresetter.py +++ b/tresetter.py @@ -5,39 +5,57 @@ import logging import json import uuid from subprocess import Popen, CalledProcessError +from typing import Optional import redis -from fastapi import FastAPI, HTTPException, Cookie -from fastapi.responses import JSONResponse +from fastapi import FastAPI, APIRouter, HTTPException, Cookie, Request +from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, BaseSettings code_dir = Path(".") -redis_params = { - "host": "localhost", - # "password": "foobar", -} class Settings(BaseSettings): app_name: str = "tResetter" validate_login_exe: str change_password_exe: str - redis_params: dict = { - "host": "localhost" - } + redis_host: str = "localhost" expire_time: int = 60 * 20 + root_path: Optional[str] + root_prefix: str = "/tresetter" + + @property + def redis_params(self): + return {"host": self.redis_host} settings = Settings() -app = FastAPI() +kwargs = {} +if settings.root_path: + kwargs["root_path"] = settings.root_path +app = FastAPI(**kwargs) -app.mount("/static", StaticFiles(directory=str(code_dir / "static"))) +router = APIRouter(prefix=settings.root_prefix) + +app.mount( + settings.root_prefix + "/s", + StaticFiles(directory=str(code_dir / "static")), + name="static", +) logger = logging.getLogger("server") +# Session {{{ + + +class SessionNotFoundException(HTTPException): + def __init__(self): + super().__init__(status_code=401) + + def create_session(content) -> str: session_id = str(uuid.uuid4()) set_session(session_id, content) @@ -45,7 +63,7 @@ def create_session(content) -> str: def set_session(session_id: str, content) -> bool: - r = redis.Redis(**redis_params) + r = redis.Redis(**settings.redis_params) serialized = json.dumps(content) ret = r.set(session_id, serialized) r.expire(session_id, settings.expire_time) @@ -53,22 +71,27 @@ def set_session(session_id: str, content) -> bool: def get_session(session_id: str, renew=True): if session_id is None: - raise HTTPException(status_code=400) - r = redis.Redis(**redis_params) + raise SessionNotFoundException() + r = redis.Redis(**settings.redis_params) serialized = r.get(session_id) logger.error("%s", repr(serialized)) if serialized is None: - raise HTTPException(status_code=401) + raise SessionNotFoundException() if renew: r.expire(session_id, settings.expire_time) return json.loads(serialized) def delete_session(session_id: str): - r = redis.Redis(**redis_params) + r = redis.Redis(**settings.redis_params) r.delete(session_id) +# end session }}} + +# Models {{{ + + class LoginData(BaseModel): username: str password: str @@ -77,16 +100,15 @@ class LoginData(BaseModel): class ChangeData(BaseModel): password: str + class SuccessData(BaseModel): success: bool = True -@app.get("/") -async def home(): - # XXX: read index.html - return "Ciao!" +# end Models }}} +# external commands {{{ def validate(username, password): p = Popen( [settings.validate_login_exe], @@ -100,7 +122,24 @@ def validate(username, password): return p.returncode == 0 -@app.post("/login") +def password_generate(): + symbols = list(string.ascii_lowercase) + list(string.digits) + return "".join(random.choices(symbols, k=10)) + + +# end external commands }}} + + +@router.get("/") +async def home(request: Request, session_id: str = Cookie(None)): + try: + get_session(session_id) + except SessionNotFoundException: + return RedirectResponse(app.url_path_for("static", path="login.html")) + return RedirectResponse(app.url_path_for("static", path="change.html")) + + +@router.post("/login", tags=["session"]) async def login(req: LoginData): ok = validate(req.username, req.password) @@ -121,25 +160,20 @@ async def login(req: LoginData): return response -@app.get("/whoami") +@router.get("/whoami", tags=["session"]) async def whoami(session_id: str = Cookie(None)): session = get_session(session_id) return JSONResponse(content={"username": session["username"]}) -@app.post("/logout") +@router.post("/logout", tags=["session"]) async def logout(session_id: str = Cookie(None)): get_session(session_id) delete_session(session_id) return "OKI" -def password_generate(): - symbols = list(string.ascii_lowercase) + list(string.digits) - return "".join(random.choices(symbols, k=10)) - - -@app.post("/generate") +@router.post("/generate", tags=["password"]) async def generate(session_id: str = Cookie(None)): session = get_session(session_id) session["proposed_password"] = password_generate() @@ -148,7 +182,7 @@ async def generate(session_id: str = Cookie(None)): return JSONResponse(content={"password": session["proposed_password"]}) -@app.post("/change") +@router.post("/change", tags=["password"]) async def change(req: ChangeData, session_id: str = Cookie(None)): session = get_session(session_id) if "proposed_password" not in session: @@ -171,3 +205,9 @@ async def change(req: ChangeData, session_id: str = Cookie(None)): return SuccessData(success=False) return SuccessData() + + +app.include_router(router) + + +# vim: set fdm=marker: