Files
NexHome/auth.py
T
Black-Cyan 3f386e5e38 ✏️
- feature access/refresh tokens auth
2026-06-11 15:59:29 +08:00

284 lines
9.5 KiB
Python

import uuid
from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse
from jose import jwt, JWTError
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette import status
from config import (
SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer,
ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS,
)
from database import SessionLocal
from models import User, RefreshToken
# Router for the authentication endpoints; every route registered on it is
# served under the /auth prefix and tagged "auth".
router = APIRouter(prefix="/auth", tags=["auth"])
def get_db():
    """Yield a database session and close it when the consumer is done.

    Written as a generator so it can be used with ``Depends``: a session is
    created from ``SessionLocal``, handed to the caller, and closed in the
    ``finally`` clause even if the caller raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Annotated alias: a parameter typed with it receives a Session produced by
# get_db through Depends.
db_dependency = Annotated[Session, Depends(get_db)]
class Token(BaseModel):
    """Response schema used by the /token and /refresh endpoints."""

    # Encoded JWT whose "type" claim is "access".
    access_token: str
    # Encoded JWT whose "type" claim is "refresh".
    refresh_token: str
    # The endpoints in this module always set this to "bearer".
    token_type: str
def authenticate_user(username: str, password: str, db: Session):
    """Look up *username* and verify *password* against the stored hash.

    Returns the matching ``User`` row on success, or ``False`` when no user
    has that username or the password does not verify.
    """
    candidate = db.query(User).filter(User.username == username).first()
    if candidate and bcrypt_context.verify(password, candidate.hashed_password):
        return candidate
    return False
def create_access_token(username: str, user_id: int) -> str:
    """Return a signed JWT access token for the given user.

    The token carries ``sub`` (username), ``id`` (user id), ``type="access"``
    and an ``exp`` claim ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes from the
    current UTC time.
    """
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "sub": username,
        "id": user_id,
        "exp": datetime.now(timezone.utc) + lifetime,
        "type": "access",
    }
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(username: str, user_id: int, db: Session) -> str:
    """Issue a signed refresh JWT and persist it as a ``RefreshToken`` row.

    The token expires ``REFRESH_TOKEN_EXPIRE_DAYS`` days from the current UTC
    time; the same expiry is stored alongside the encoded token. The session
    is committed before the encoded token is returned.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    claims = {
        "sub": username,
        "id": user_id,
        "exp": expires_at,
        "type": "refresh",
        "jti": uuid.uuid4().hex,  # random identifier generated per token
    }
    encoded = jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
    record = RefreshToken(user_id=user_id, token=encoded, expires_at=expires_at)
    db.add(record)
    db.commit()
    return encoded
def decode_token(token: str) -> dict | None:
    """Decode *token* with the configured key and algorithm.

    Returns the decoded claims, or ``None`` when decoding raises ``JWTError``.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    return claims
def _decode_user(payload: dict) -> dict | None:
username = payload.get("sub")
user_id = payload.get("id")
if username is None or user_id is None:
return None
return {"username": username, "id": user_id}
async def get_current_user(
    request: Request,
    token: Annotated[Optional[str], Depends(oauth2_bearer)],
) -> dict | None:
    """Resolve the user for the current request, or ``None`` if unauthenticated.

    Lookup order:

    1. The ``access_token`` cookie.
    2. The bearer token supplied through ``oauth2_bearer``.
    3. The ``refresh_token`` cookie, via ``try_refresh``, which mints a new
       access token and stashes it on ``request.state.new_access_token``.

    Returns:
        ``{"username": ..., "id": ...}`` for the first credential that
        validates, otherwise ``None``.
    """
    # Try each access-token candidate on its own. Collapsing them with
    # ``cookie or header`` would let a stale/invalid cookie hide a header
    # token that is still valid.
    for candidate in (request.cookies.get("access_token"), token):
        if not candidate:
            continue
        payload = decode_token(candidate)
        if payload and payload.get("type") == "access":
            return _decode_user(payload)

    # No usable access token - fall back to the refresh cookie.
    refresh_cookie = request.cookies.get("refresh_token")
    if refresh_cookie:
        new_access = try_refresh(refresh_cookie, request)
        if new_access:
            payload = decode_token(new_access)
            if payload:
                return _decode_user(payload)
    return None
def try_refresh(refresh_token: str, request: Request) -> str | None:
    """Mint a new access token from *refresh_token* if it is still usable.

    The token must decode with ``type == "refresh"`` and match a non-revoked
    ``RefreshToken`` row. A row whose stored expiry (read as UTC) has passed
    is marked revoked and rejected. On success the new access token is also
    stored on ``request.state.new_access_token`` so the caller can set the
    cookie.

    Returns the new access token, or ``None`` when the refresh token is
    rejected. Uses its own short-lived session, closed before returning.
    """
    claims = decode_token(refresh_token)
    if not (claims and claims.get("type") == "refresh"):
        return None

    session = SessionLocal()
    try:
        stored = (
            session.query(RefreshToken)
            .filter(
                RefreshToken.token == refresh_token,
                RefreshToken.revoked == False,  # noqa: E712 - SQL expression
            )
            .first()
        )
        if not stored:
            return None

        expiry = stored.expires_at.replace(tzinfo=timezone.utc)
        if expiry < datetime.now(timezone.utc):
            stored.revoked = True
            session.commit()
            return None

        fresh_access = create_access_token(claims["sub"], claims["id"])
        # Stash the new access token on request state so the caller can set the cookie.
        request.state.new_access_token = fresh_access
        return fresh_access
    finally:
        session.close()
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_user(
    db: db_dependency,
    username: str = Form(...),
    password: str = Form(...),
    email: str = Form(""),
    full_name: str = Form(""),
):
    """Create a user account from form fields.

    Text fields are stripped of surrounding whitespace and validated the same
    way as the HTML registration form (``register_submit``): username and
    password are required, the password needs at least 6 characters, and both
    username and (when given) email must be unused.

    Raises:
        HTTPException: 400 for any validation failure listed above.
    """
    # Normalise like register_submit so " bob " and "bob" are the same account
    # and a whitespace-only username is rejected. The password is left as-is.
    username = username.strip()
    email = email.strip()
    full_name = full_name.strip()

    if not username or not password:
        raise HTTPException(status_code=400, detail="Username and password required")
    if len(password) < 6:
        raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
    if db.query(User).filter(User.username == username).first():
        raise HTTPException(status_code=400, detail="Username already taken")
    # register_submit rejects duplicate emails; do the same here.
    if email and db.query(User).filter(User.email == email).first():
        raise HTTPException(status_code=400, detail="Email already registered")

    db.add(User(
        username=username,
        email=email or None,  # empty optional fields are stored as NULL
        full_name=full_name or None,
        hashed_password=bcrypt_context.hash(password),
    ))
    db.commit()
@router.post("/token", response_model=Token)
async def login_for_access_token(
    username: str = Form(...),
    password: str = Form(...),
    db: Session = Depends(get_db),
):
    """Exchange form credentials for an access/refresh token pair.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials.
    """
    account = authenticate_user(username, password, db)
    if not account:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user.")
    return {
        "access_token": create_access_token(account.username, account.id),
        "refresh_token": create_refresh_token(account.username, account.id, db),
        "token_type": "bearer",
    }
@router.post("/refresh", response_model=Token)
async def refresh_access_token(
    refresh_token: str = Form(...),
    db: Session = Depends(get_db),
):
    """Exchange a valid refresh token for a new access/refresh token pair.

    The presented refresh token is revoked and a replacement is issued
    (rotation), so each refresh token can be redeemed here only once.

    Raises:
        HTTPException: 401 when the token does not decode as a refresh token,
            has no matching non-revoked row, or its stored expiry has passed.
    """
    payload = decode_token(refresh_token)
    if not payload or payload.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    db_token = db.query(RefreshToken).filter(
        RefreshToken.token == refresh_token,
        RefreshToken.revoked == False,  # noqa: E712 - SQL expression
    ).first()
    if not db_token:
        raise HTTPException(status_code=401, detail="Refresh token revoked or not found")

    # The stored expiry is read as UTC before comparing with the aware "now".
    if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
        db_token.revoked = True
        db.commit()
        raise HTTPException(status_code=401, detail="Refresh token expired")

    new_access = create_access_token(payload["sub"], payload["id"])
    # Revoke the old token BEFORE issuing its replacement. create_refresh_token()
    # adds the new row and commits this same session, so the revocation and the
    # new token are persisted by one commit. Doing it afterwards (two commits)
    # could leave both tokens valid if the second commit failed.
    db_token.revoked = True
    new_refresh = create_refresh_token(payload["sub"], payload["id"], db)
    return {"access_token": new_access, "refresh_token": new_refresh, "token_type": "bearer"}
def _set_auth_cookies(response: RedirectResponse, access: str, refresh: str) -> RedirectResponse:
    """Attach both tokens to *response* as HttpOnly, SameSite=Lax cookies.

    Cookie lifetimes follow the configured token lifetimes, in seconds.
    Returns the same response object.
    """
    cookie_specs = (
        ("access_token", access, ACCESS_TOKEN_EXPIRE_MINUTES * 60),
        ("refresh_token", refresh, REFRESH_TOKEN_EXPIRE_DAYS * 86400),
    )
    for cookie_name, cookie_value, lifetime_seconds in cookie_specs:
        response.set_cookie(
            cookie_name,
            cookie_value,
            httponly=True,
            max_age=lifetime_seconds,
            samesite="lax",
        )
    return response
def _clear_auth_cookies(response: RedirectResponse) -> RedirectResponse:
    """Delete both auth cookies on *response* and return the same response."""
    for cookie_name in ("access_token", "refresh_token"):
        response.delete_cookie(cookie_name)
    return response
@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Render the login form with an empty template context."""
    context = {}
    return templates.TemplateResponse(request, "auth/login.html", context)
@router.post("/login")
async def login_submit(request: Request, db: db_dependency):
    """Handle the login form.

    On valid credentials, issues an access/refresh token pair, sets them as
    cookies and redirects (303) to ``/``. Otherwise re-renders the login
    template with an error message.
    """
    submitted = await request.form()
    account = authenticate_user(
        submitted.get("username", ""),
        submitted.get("password", ""),
        db,
    )
    if account:
        access = create_access_token(account.username, account.id)
        refresh = create_refresh_token(account.username, account.id, db)
        return _set_auth_cookies(RedirectResponse(url="/", status_code=303), access, refresh)

    failure_context = {"error": "Invalid username or password"}
    return templates.TemplateResponse(request, "auth/login.html", failure_context)
@router.get("/register", response_class=HTMLResponse)
async def register_page(request: Request):
    """Render the registration form with an empty template context."""
    context = {}
    return templates.TemplateResponse(request, "auth/register.html", context)
@router.post("/register")
async def register_submit(request: Request, db: db_dependency):
    """Handle the registration form.

    Every validation rule is evaluated and all failures are collected. If any
    fail, the form is re-rendered with the error list and the submitted
    non-password fields. Otherwise the user is created and the client is
    redirected (303) to the login page.
    """
    form = await request.form()
    username, email, full_name = (
        form.get(field, "").strip() for field in ("username", "email", "full_name")
    )
    password = form.get("password", "")
    confirm = form.get("confirm_password", "")

    # (failed?, message) pairs, evaluated in order; every rule is checked so
    # the user sees all problems at once.
    checks = (
        (not username, "Username is required."),
        (not password, "Password is required."),
        (password != confirm, "Passwords do not match."),
        (len(password) < 6, "Password must be at least 6 characters."),
        (
            db.query(User).filter(User.username == username).first(),
            "Username already taken.",
        ),
        (
            email and db.query(User).filter(User.email == email).first(),
            "Email already registered.",
        ),
    )
    errors = [message for failed, message in checks if failed]

    if errors:
        form_context = {
            "errors": errors,
            "username": username,
            "email": email,
            "full_name": full_name,
        }
        return templates.TemplateResponse(request, "auth/register.html", form_context)

    new_user = User(
        username=username,
        email=email or None,
        full_name=full_name or None,
        hashed_password=bcrypt_context.hash(password),
    )
    db.add(new_user)
    db.commit()
    return RedirectResponse(url="/auth/login?registered=1", status_code=303)
@router.get("/logout")
async def logout(request: Request, db: db_dependency):
    """Log the client out.

    If the ``refresh_token`` cookie matches a stored row, that row is marked
    revoked. In all cases both auth cookies are cleared and the client is
    redirected (303) to ``/``.
    """
    presented = request.cookies.get("refresh_token")
    stored = (
        db.query(RefreshToken).filter(RefreshToken.token == presented).first()
        if presented
        else None
    )
    if stored:
        stored.revoked = True
        db.commit()
    return _clear_auth_cookies(RedirectResponse(url="/", status_code=303))