Browse Source

root_prefix + some refactor

boyska 2 years ago
parent
commit
c3abe6a284
1 changed files with 70 additions and 30 deletions
  1. 70 30
      tresetter.py

+ 70 - 30
tresetter.py

@@ -5,39 +5,57 @@ import logging
 import json
 import uuid
 from subprocess import Popen, CalledProcessError
+from typing import Optional
 
 import redis
-from fastapi import FastAPI, HTTPException, Cookie
-from fastapi.responses import JSONResponse
+from fastapi import FastAPI, APIRouter, HTTPException, Cookie, Request
+from fastapi.responses import JSONResponse, RedirectResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel, BaseSettings
 
 
 code_dir = Path(".")
-redis_params = {
-    "host": "localhost",
-    # "password": "foobar",
-}
 
 
 class Settings(BaseSettings):
     app_name: str = "tResetter"
     validate_login_exe: str
     change_password_exe: str
-    redis_params: dict = {
-            "host": "localhost"
-            }
+    redis_host: str = "localhost"
     expire_time: int = 60 * 20
+    root_path: Optional[str]
+    root_prefix: str = "/tresetter"
+
+    @property
+    def redis_params(self):
+        return {"host": self.redis_host}
 
 
 settings = Settings()
-app = FastAPI()
+kwargs = {}
+if settings.root_path:
+    kwargs["root_path"] = settings.root_path
+app = FastAPI(**kwargs)
+
+router = APIRouter(prefix=settings.root_prefix)
 
-app.mount("/static", StaticFiles(directory=str(code_dir / "static")))
+app.mount(
+    settings.root_prefix + "/s",
+    StaticFiles(directory=str(code_dir / "static")),
+    name="static",
+)
 
 logger = logging.getLogger("server")
 
 
+# Session {{{
+
+
+class SessionNotFoundException(HTTPException):
+    def __init__(self):
+        super().__init__(status_code=401)
+
+
 def create_session(content) -> str:
     session_id = str(uuid.uuid4())
     set_session(session_id, content)
@@ -45,7 +63,7 @@ def create_session(content) -> str:
 
 
 def set_session(session_id: str, content) -> bool:
-    r = redis.Redis(**redis_params)
+    r = redis.Redis(**settings.redis_params)
     serialized = json.dumps(content)
     ret = r.set(session_id, serialized)
     r.expire(session_id, settings.expire_time)
@@ -53,22 +71,27 @@ def set_session(session_id: str, content) -> bool:
 
 def get_session(session_id: str, renew=True):
     if session_id is None:
-        raise HTTPException(status_code=400)
-    r = redis.Redis(**redis_params)
+        raise SessionNotFoundException()
+    r = redis.Redis(**settings.redis_params)
     serialized = r.get(session_id)
     logger.error("%s", repr(serialized))
     if serialized is None:
-        raise HTTPException(status_code=401)
+        raise SessionNotFoundException()
     if renew:
         r.expire(session_id, settings.expire_time)
     return json.loads(serialized)
 
 
 def delete_session(session_id: str):
-    r = redis.Redis(**redis_params)
+    r = redis.Redis(**settings.redis_params)
     r.delete(session_id)
 
 
+# end session }}}
+
+# Models {{{
+
+
 class LoginData(BaseModel):
     username: str
     password: str
@@ -77,16 +100,15 @@ class LoginData(BaseModel):
 class ChangeData(BaseModel):
     password: str
 
+
 class SuccessData(BaseModel):
     success: bool = True
 
 
-@app.get("/")
-async def home():
-    # XXX: read index.html
-    return "Ciao!"
+# end Models }}}
 
 
+# external commands {{{
 def validate(username, password):
     p = Popen(
         [settings.validate_login_exe],
@@ -100,7 +122,24 @@ def validate(username, password):
     return p.returncode == 0
 
 
-@app.post("/login")
+def password_generate():
+    symbols = list(string.ascii_lowercase) + list(string.digits)
+    return "".join(random.choices(symbols, k=10))
+
+
+# end external commands }}}
+
+
+@router.get("/")
+async def home(request: Request, session_id: str = Cookie(None)):
+    try:
+        get_session(session_id)
+    except SessionNotFoundException:
+        return RedirectResponse(app.url_path_for("static", path="login.html"))
+    return RedirectResponse(app.url_path_for("static", path="change.html"))
+
+
+@router.post("/login", tags=["session"])
 async def login(req: LoginData):
 
     ok = validate(req.username, req.password)
@@ -121,25 +160,20 @@ async def login(req: LoginData):
     return response
 
 
-@app.get("/whoami")
+@router.get("/whoami", tags=["session"])
 async def whoami(session_id: str = Cookie(None)):
     session = get_session(session_id)
     return JSONResponse(content={"username": session["username"]})
 
 
-@app.post("/logout")
+@router.post("/logout", tags=["session"])
 async def logout(session_id: str = Cookie(None)):
     get_session(session_id)
     delete_session(session_id)
     return "OKI"
 
 
-def password_generate():
-    symbols = list(string.ascii_lowercase) + list(string.digits)
-    return "".join(random.choices(symbols, k=10))
-
-
-@app.post("/generate")
+@router.post("/generate", tags=["password"])
 async def generate(session_id: str = Cookie(None)):
     session = get_session(session_id)
     session["proposed_password"] = password_generate()
@@ -148,7 +182,7 @@ async def generate(session_id: str = Cookie(None)):
     return JSONResponse(content={"password": session["proposed_password"]})
 
 
-@app.post("/change")
+@router.post("/change", tags=["password"])
 async def change(req: ChangeData, session_id: str = Cookie(None)):
     session = get_session(session_id)
     if "proposed_password" not in session:
@@ -171,3 +205,9 @@ async def change(req: ChangeData, session_id: str = Cookie(None)):
         return SuccessData(success=False)
 
     return SuccessData()
+
+
+app.include_router(router)
+
+
+# vim: set fdm=marker: