tresetter/tresetter.py
2022-05-20 19:17:35 +02:00

257 lines
6.2 KiB
Python

import base64
import hashlib
import hmac
import json
import logging
import random
import secrets
import string
import uuid
from pathlib import Path
from subprocess import Popen, CalledProcessError, check_output
from typing import Optional

import redis
from fastapi import FastAPI, APIRouter, HTTPException, Cookie, Request
from fastapi.responses import Response, RedirectResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, BaseSettings
# Base directory under which the "static" assets directory is looked up when
# the StaticFiles app is mounted.
# NOTE(review): Path(".") resolves against the process working directory, not
# this file's location — confirm the server is always started from the code
# directory.
code_dir = Path(".")
class Settings(BaseSettings):
    """Application configuration (pydantic ``BaseSettings``).

    Fields without a default must be supplied by the environment at startup.
    """

    app_name: str = "tResetter"
    # Executables that check credentials / apply a new password.
    validate_login_exe: str
    change_password_exe: str
    # Optional command whose stdout is used as a generated password.
    generate_password: Optional[str]
    redis_host: str = "localhost"
    # Session TTL: 20 minutes, expressed in seconds (used as a Redis expiry).
    expire_time: int = 20 * 60
    # Passed to FastAPI only when set.
    root_path: Optional[str]
    # URL prefix for the API router and the static mount.
    root_prefix: str = "/tresetter"

    class Config:
        pass

    @property
    def redis_params(self):
        """Keyword arguments for constructing a ``redis.Redis`` client."""
        params = {"host": self.redis_host}
        return params
settings = Settings()

# FastAPI only receives a root_path when one is configured.
kwargs = {}
if settings.root_path:
    kwargs.update(root_path=settings.root_path)
app = FastAPI(**kwargs)

router = APIRouter(prefix=settings.root_prefix)

# Front-end files (login.html, change.html, ...) served under <root_prefix>/s.
app.mount(
    f"{settings.root_prefix}/s",
    StaticFiles(directory=str(code_dir.joinpath("static"))),
    name="static",
)

logger = logging.getLogger("server")
# Session {{{
class SessionNotFoundException(HTTPException):
    """The request has no usable session; surfaces to the client as HTTP 401."""

    def __init__(self):
        HTTPException.__init__(self, status_code=401)
def create_session(content) -> str:
    """Persist *content* under a freshly generated UUID4 id and return the id."""
    new_id = f"{uuid.uuid4()}"
    set_session(new_id, content)
    return new_id
def set_session(session_id: str, content) -> bool:
    """Store *content* (JSON-serialized) under *session_id* with a TTL.

    The value and its expiry (``settings.expire_time`` seconds) are written by
    a single ``SET ... EX`` command. This way a key can never be left in Redis
    without an expiry, which the former SET-then-EXPIRE pair allowed if the
    process died in between.

    Args:
        session_id: Redis key of the session.
        content: JSON-serializable session payload.

    Returns:
        True if Redis acknowledged the write.
    """
    r = redis.Redis(**settings.redis_params)
    serialized = json.dumps(content)
    return bool(r.set(session_id, serialized, ex=settings.expire_time))
def get_session(session_id: Optional[str], renew=True):
    """Load and decode the session stored under *session_id*.

    Args:
        session_id: Session cookie value; None when the cookie is absent.
        renew: When true, reset the key's TTL to ``settings.expire_time``.

    Returns:
        The JSON-decoded session content.

    Raises:
        SessionNotFoundException: No id was supplied or no value is stored
            under it.
    """
    if session_id is None:
        raise SessionNotFoundException()
    r = redis.Redis(**settings.redis_params)
    serialized = r.get(session_id)
    # Never log the payload: it contains the username and the proposed
    # password hash. (It used to be logged at ERROR level on every call.)
    if serialized is None:
        logger.debug("session lookup miss")
        raise SessionNotFoundException()
    if renew:
        r.expire(session_id, settings.expire_time)
    return json.loads(serialized)
def delete_session(session_id: str):
    """Remove the key *session_id* from Redis."""
    redis.Redis(**settings.redis_params).delete(session_id)
# end session }}}
# Models {{{
class UserData(BaseModel):
    """A bare username, e.g. the /whoami response."""

    username: str
class ChangeData(BaseModel):
    """A password: returned by /generate and submitted to /change."""

    password: str
class LoginData(UserData, ChangeData):
    """Credentials posted to /login: ``username`` plus ``password``."""
class SuccessData(BaseModel):
    """Result of a /change request."""

    success: bool = True
# end Models }}}
# external commands {{{
def validate(username, password):
    """Check credentials by running ``settings.validate_login_exe``.

    The credentials reach the executable only through its environment, which
    replaces the parent's environment entirely:
    ``VERIFY_USERNAME`` / ``VERIFY_PASSWORD``.

    Args:
        username: Login name as submitted by the client.
        password: Password as submitted by the client.

    Returns:
        True if the executable exited with status 0. False if it exited
        non-zero or could not be started.
    """
    env = {
        "VERIFY_USERNAME": username,
        "VERIFY_PASSWORD": password,
        # Misspelled name sent by earlier versions; still provided so that
        # validator scripts written against it keep working.
        "VERIRY_USERNAME": username,
    }
    try:
        # Popen/communicate never raise CalledProcessError. What can fail is
        # starting the process (OSError) or building its environment
        # (ValueError, e.g. a NUL byte in the client-supplied credentials).
        with Popen([settings.validate_login_exe], env=env) as p:
            p.communicate()
    except (OSError, ValueError):
        logger.exception("could not run %s", settings.validate_login_exe)
        return False
    return p.returncode == 0
def password_generate() -> str:
    """Return a newly proposed password.

    If ``settings.generate_password`` is set, that command is run and its
    stripped stdout is used. Otherwise the proposal is 10 characters drawn
    from ``[a-z0-9]``.

    Raises:
        CalledProcessError: The generator command exited non-zero.
        RuntimeError: The generator command printed nothing.
    """
    if settings.generate_password:
        # encoding= makes check_output return str, so no type check is needed.
        proposal = check_output([settings.generate_password], encoding="utf8").strip()
        if not proposal:
            raise RuntimeError("password generator produced an empty password")
        return proposal
    # secrets (OS CSPRNG), not random: this becomes a real account password.
    symbols = string.ascii_lowercase + string.digits
    return "".join(secrets.choice(symbols) for _ in range(10))
def change_password(username: str, new_password: str) -> bool:
    """Apply *new_password* to *username* via ``settings.change_password_exe``.

    The values reach the executable only through its environment, which
    replaces the parent's environment entirely:
    ``CHANGE_USERNAME`` / ``CHANGE_PASSWORD``.

    Returns:
        True if the executable exited with status 0. False if it exited
        non-zero or could not be started.
    """
    env = {"CHANGE_USERNAME": username, "CHANGE_PASSWORD": new_password}
    try:
        # Popen/communicate never raise CalledProcessError. What can fail is
        # starting the process (OSError) or building its environment
        # (ValueError, e.g. a NUL byte in a value).
        with Popen([settings.change_password_exe], env=env) as p:
            p.communicate()
    except (OSError, ValueError):
        logger.exception("could not run %s", settings.change_password_exe)
        return False
    return p.returncode == 0
# end external commands }}}
@router.get("/")
async def home(request: Request, session_id: str = Cookie(None)):
    """Redirect to the change page when a session exists, else to the login page."""
    try:
        get_session(session_id)
    except SessionNotFoundException:
        page = "login.html"
    else:
        page = "change.html"
    return RedirectResponse(app.url_path_for("static", path=page))
@router.post("/login", tags=["session"])
async def login(req: LoginData):
    """Validate the submitted credentials and open a session.

    On success an empty response is returned carrying the new session id in
    the ``session_id`` cookie. Otherwise HTTP 401 is raised.
    """
    if not validate(req.username, req.password):
        raise HTTPException(status_code=401, detail="Authentication error")
    sid = create_session({"username": req.username})
    resp = Response()
    resp.set_cookie(key="session_id", value=sid)
    return resp
@router.get("/whoami", tags=["session"])
async def whoami(session_id: str = Cookie(None)):
    """Return the username bound to the current session (401 without one)."""
    username = get_session(session_id)["username"]
    return UserData(username=username)
@router.post("/logout", tags=["session"])
async def logout(session_id: str = Cookie(None)) -> BaseModel:
    """End the current session.

    Logging out without a live session yields 401 rather than a silent
    success, because the session is looked up before it is deleted.
    """
    get_session(session_id)  # existence check only; raises when missing
    delete_session(session_id)
    return BaseModel()
KDF_SALT_SIZE = 16
# scrypt cost parameters. The previous n=2, r=1 provided essentially no work
# factor. These are the commonly recommended interactive values, needing about
# 16 MiB per hash, which is within hashlib's default memory limit.
KDF_SCRYPT_N = 2 ** 14
KDF_SCRYPT_R = 8
KDF_SCRYPT_P = 1


def kdf_gen(password, salt=None) -> str:
    """Hash *password* with scrypt and return ``base64(salt + digest)``.

    Args:
        password: ``str`` (encoded as UTF-8) or ``bytes``.
        salt: Salt bytes. A fresh ``KDF_SALT_SIZE``-byte random salt is drawn
            when None.

    Returns:
        ASCII base64 text of the salt followed by the scrypt digest.
        kdf_get_salt can later recover the salt from it.
    """
    if salt is None:
        # secrets (OS CSPRNG) instead of the non-cryptographic random module.
        salt = secrets.token_bytes(KDF_SALT_SIZE)
    if isinstance(password, str):
        password = password.encode("utf8")
    raw = hashlib.scrypt(
        password, salt=salt, n=KDF_SCRYPT_N, r=KDF_SCRYPT_R, p=KDF_SCRYPT_P
    )
    return base64.b64encode(salt + raw).decode("ascii")


def kdf_get_salt(hashed: str):
    """Return the salt prefix of a value produced by kdf_gen.

    *hashed* may be ``str`` or ASCII ``bytes``; b64decode accepts both.
    """
    return base64.b64decode(hashed)[:KDF_SALT_SIZE]


def kdf_verify(hashed: str, password: str) -> bool:
    """Return True if *password* hashes to *hashed* (a kdf_gen result).

    *hashed* may be ``str`` or ASCII ``bytes``. Previously a bytes value could
    never match, because bytes were compared with str. The final comparison
    is constant-time.
    """
    if isinstance(hashed, (bytes, bytearray)):
        hashed = hashed.decode("ascii")
    expected = kdf_gen(password, salt=kdf_get_salt(hashed))
    return hmac.compare_digest(hashed, expected)
@router.post("/generate", tags=["password"])
async def generate(session_id: str = Cookie(None)):
    """Propose a new password for the logged-in user.

    The proposal is returned to the client. Only its salted hash is stored in
    the session, where /change later checks the submitted password against it.
    """
    current = get_session(session_id)
    proposal = password_generate()
    set_session(session_id, {**current, "proposed_password_hash": kdf_gen(proposal)})
    return ChangeData(password=proposal)
@router.post("/change", tags=["password"])
async def change(req: ChangeData, session_id: str = Cookie(None)) -> SuccessData:
    """Apply the password previously proposed by /generate.

    Responses:
    - 400 if nothing was generated yet.
    - 409 if the submitted password differs from the proposal.

    The session is deleted before the external change command runs, so a
    failed change requires logging in again.
    """
    session = get_session(session_id)
    try:
        expected_hash = session["proposed_password_hash"]
    except KeyError:
        raise HTTPException(status_code=400, detail="You must generate it first") from None
    if not kdf_verify(expected_hash, req.password):
        raise HTTPException(status_code=409)
    delete_session(session_id)
    return SuccessData(success=change_password(session["username"], req.password))
# Attach the API router (prefixed with settings.root_prefix) to the app.
# This sits after every @router handler declared above. Presumably it must,
# so that all routes exist when they are included — confirm before moving it.
app.include_router(router)
# vim: set fdm=marker: