tresetter.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from pathlib import Path
  2. import string
  3. import random
  4. import logging
  5. import json
  6. import uuid
  7. from subprocess import Popen, CalledProcessError, check_output
  8. from typing import Optional
  9. import hashlib
  10. import base64
  11. import redis
  12. from fastapi import FastAPI, APIRouter, HTTPException, Cookie, Request
  13. from fastapi.responses import Response, RedirectResponse
  14. from fastapi.staticfiles import StaticFiles
  15. from pydantic import BaseModel, BaseSettings
  16. code_dir = Path(".")
  17. class Settings(BaseSettings):
  18. app_name: str = "tResetter"
  19. validate_login_exe: str
  20. change_password_exe: str
  21. generate_password: Optional[str]
  22. redis_host: str = "localhost"
  23. expire_time: int = 60 * 20
  24. root_path: Optional[str]
  25. root_prefix: str = "/tresetter"
  26. class Config:
  27. pass
  28. @property
  29. def redis_params(self):
  30. return {"host": self.redis_host}
  31. settings = Settings()
  32. kwargs = {}
  33. if settings.root_path:
  34. kwargs["root_path"] = settings.root_path
  35. app = FastAPI(**kwargs)
  36. router = APIRouter(prefix=settings.root_prefix)
  37. app.mount(
  38. settings.root_prefix + "/s",
  39. StaticFiles(directory=str(code_dir / "static")),
  40. name="static",
  41. )
  42. logger = logging.getLogger("server")
  43. # Session {{{
  44. class SessionNotFoundException(HTTPException):
  45. def __init__(self):
  46. super().__init__(status_code=401)
  47. def create_session(content) -> str:
  48. session_id = str(uuid.uuid4())
  49. set_session(session_id, content)
  50. return session_id
  51. def set_session(session_id: str, content) -> bool:
  52. r = redis.Redis(**settings.redis_params)
  53. serialized = json.dumps(content)
  54. ret = r.set(session_id, serialized)
  55. r.expire(session_id, settings.expire_time)
  56. def get_session(session_id: str, renew=True):
  57. if session_id is None:
  58. raise SessionNotFoundException()
  59. r = redis.Redis(**settings.redis_params)
  60. serialized = r.get(session_id)
  61. logger.error("%s", repr(serialized))
  62. if serialized is None:
  63. raise SessionNotFoundException()
  64. if renew:
  65. r.expire(session_id, settings.expire_time)
  66. return json.loads(serialized)
  67. def delete_session(session_id: str):
  68. r = redis.Redis(**settings.redis_params)
  69. r.delete(session_id)
  70. # end session }}}
  71. # Models {{{
  72. class UserData(BaseModel):
  73. username: str
  74. class ChangeData(BaseModel):
  75. password: str
  76. class LoginData(UserData, ChangeData):
  77. pass
  78. class SuccessData(BaseModel):
  79. success: bool = True
  80. # end Models }}}
  81. # external commands {{{
  82. def validate(username, password):
  83. p = Popen(
  84. [settings.validate_login_exe],
  85. env={"VERIRY_USERNAME": username, "VERIFY_PASSWORD": password},
  86. )
  87. try:
  88. p.communicate()
  89. except CalledProcessError:
  90. return False
  91. return p.returncode == 0
  92. def password_generate() -> str:
  93. if settings.generate_password:
  94. s = check_output([settings.generate_password], encoding='utf8')
  95. assert type(s) is str
  96. return s.strip()
  97. else:
  98. symbols = list(string.ascii_lowercase) + list(string.digits)
  99. return "".join(random.choices(symbols, k=10))
  100. def change_password(username: str, new_password: str) -> bool:
  101. p = Popen(
  102. [settings.change_password_exe],
  103. env={"CHANGE_USERNAME": username, "CHANGE_PASSWORD": new_password},
  104. )
  105. try:
  106. p.communicate()
  107. except CalledProcessError:
  108. return False
  109. if p.returncode != 0:
  110. return False
  111. return True
  112. # end external commands }}}
  113. @router.get("/")
  114. async def home(request: Request, session_id: str = Cookie(None)):
  115. """redirects to the user to the home"""
  116. try:
  117. get_session(session_id)
  118. except SessionNotFoundException:
  119. return RedirectResponse(app.url_path_for("static", path="login.html"))
  120. return RedirectResponse(app.url_path_for("static", path="change.html"))
  121. @router.post("/login", tags=["session"])
  122. async def login(req: LoginData):
  123. """
  124. performs login
  125. """
  126. ok = validate(req.username, req.password)
  127. if not ok:
  128. raise HTTPException(status_code=401, detail="Authentication error")
  129. session_id = create_session(
  130. {
  131. "username": req.username,
  132. }
  133. )
  134. response = Response()
  135. response.set_cookie(key="session_id", value=session_id)
  136. return response
  137. @router.get("/whoami", tags=["session"])
  138. async def whoami(session_id: str = Cookie(None)):
  139. """Confirm login information"""
  140. session = get_session(session_id)
  141. return UserData(username=session["username"])
  142. @router.post("/logout", tags=["session"])
  143. async def logout(session_id: str = Cookie(None)) -> BaseModel:
  144. get_session(session_id)
  145. delete_session(session_id)
  146. return BaseModel()
  147. KDF_SALT_SIZE = 16
  148. def kdf_gen(password, salt=None) -> str:
  149. if salt is None:
  150. salt = random.randbytes(KDF_SALT_SIZE)
  151. if hasattr(password, 'encode'):
  152. password = password.encode('utf8')
  153. raw = hashlib.scrypt(password, n=2, r=1, p=1, salt=salt)
  154. with_salt = salt + raw
  155. return base64.b64encode(with_salt).decode('ascii')
  156. def kdf_get_salt(hashed: str):
  157. hashed_str = hashed.decode('ascii') if hasattr(hashed, 'decode') else hashed
  158. with_salt = base64.b64decode(hashed_str)
  159. salt = with_salt[:KDF_SALT_SIZE]
  160. return salt
  161. def kdf_verify(hashed: str, password: str) -> bool:
  162. salt = kdf_get_salt(hashed)
  163. hashed2 = kdf_gen(password, salt=salt)
  164. return hashed == hashed2
  165. @router.post("/generate", tags=["password"])
  166. async def generate(session_id: str = Cookie(None)):
  167. session = get_session(session_id)
  168. proposed_password = password_generate()
  169. session["proposed_password_hash"] = kdf_gen(proposed_password)
  170. set_session(session_id, session)
  171. return ChangeData(password=proposed_password)
  172. @router.post("/change", tags=["password"])
  173. async def change(req: ChangeData, session_id: str = Cookie(None)) -> SuccessData:
  174. session = get_session(session_id)
  175. if "proposed_password_hash" not in session:
  176. raise HTTPException(status_code=400, detail="You must generate it first")
  177. hashed = session["proposed_password_hash"]
  178. if not kdf_verify(hashed, req.password):
  179. raise HTTPException(status_code=409)
  180. delete_session(session_id)
  181. success = change_password(session["username"], req.password)
  182. return SuccessData(success=success)
  183. app.include_router(router)
  184. # vim: set fdm=marker: