✏️

- feature access/refresh tokens auth
This commit is contained in:
2026-06-11 15:59:29 +08:00
parent ea8e41e688
commit 3f386e5e38
5 changed files with 189 additions and 52 deletions
+150 -46
View File
@@ -1,3 +1,4 @@
import uuid
from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional
@@ -8,9 +9,12 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette import status
from config import SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer
from config import (
SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer,
ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS,
)
from database import SessionLocal
from models import User
from models import User, RefreshToken
router = APIRouter(
prefix="/auth",
@@ -29,15 +33,9 @@ def get_db():
db_dependency = Annotated[Session, Depends(get_db)]
class CreateUserRequest(BaseModel):
username: str
password: str
email: str = ""
full_name: str = ""
class Token(BaseModel):
access_token: str
refresh_token: str
token_type: str
@@ -50,41 +48,104 @@ def authenticate_user(username: str, password: str, db: Session):
return user
def create_access_token(username: str, user_id: int, expires_delta: timedelta):
encode = {"sub": username, "id": user_id}
expires = datetime.now(timezone.utc) + expires_delta
encode.update({"exp": expires})
return jwt.encode(encode, SECRET_KEY, algorithm=ALGORITHM)
def create_access_token(username: str, user_id: int) -> str:
expires = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
payload = {"sub": username, "id": user_id, "exp": expires, "type": "access"}
return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(username: str, user_id: int, db: Session) -> str:
expires = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
token = jwt.encode(
{"sub": username, "id": user_id, "exp": expires, "type": "refresh", "jti": uuid.uuid4().hex},
SECRET_KEY, algorithm=ALGORITHM,
)
db.add(RefreshToken(user_id=user_id, token=token, expires_at=expires))
db.commit()
return token
def decode_token(token: str) -> dict | None:
try:
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
return None
def _decode_user(payload: dict) -> dict | None:
username = payload.get("sub")
user_id = payload.get("id")
if username is None or user_id is None:
return None
return {"username": username, "id": user_id}
async def get_current_user(
request: Request,
token: Annotated[Optional[str], Depends(oauth2_bearer)],
):
cookie_token = request.cookies.get("access_token")
effective_token = cookie_token or token
if not effective_token:
) -> dict | None:
# Try access token from cookie
access_cookie = request.cookies.get("access_token")
effective = access_cookie or token
if effective:
payload = decode_token(effective)
if payload and payload.get("type") == "access":
return _decode_user(payload)
# Access token missing/expired — try refresh
refresh_cookie = request.cookies.get("refresh_token")
if refresh_cookie:
new_access = try_refresh(refresh_cookie, request)
if new_access:
payload = decode_token(new_access)
if payload:
return _decode_user(payload)
return None
def try_refresh(refresh_token: str, request: Request) -> str | None:
payload = decode_token(refresh_token)
if not payload or payload.get("type") != "refresh":
return None
db = SessionLocal()
try:
payload = jwt.decode(effective_token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
user_id: int = payload.get("id")
if username is None or user_id is None:
db_token = db.query(RefreshToken).filter(
RefreshToken.token == refresh_token,
RefreshToken.revoked == False,
).first()
if not db_token:
return None
return {"username": username, "id": user_id}
except JWTError:
return None
if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
db_token.revoked = True
db.commit()
return None
new_access = create_access_token(payload["sub"], payload["id"])
# Stash new access token on request state so the caller can set the cookie
request.state.new_access_token = new_access
return new_access
finally:
db.close()
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_user(db: db_dependency, create_user_request: CreateUserRequest):
create_user_model = User(
username=create_user_request.username,
email=create_user_request.email or None,
full_name=create_user_request.full_name or None,
hashed_password=bcrypt_context.hash(create_user_request.password),
)
db.add(create_user_model)
async def create_user(db: db_dependency, username: str = Form(...), password: str = Form(...),
email: str = Form(""), full_name: str = Form("")):
if not username or not password:
raise HTTPException(status_code=400, detail="Username and password required")
if len(password) < 6:
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
if db.query(User).filter(User.username == username).first():
raise HTTPException(status_code=400, detail="Username already taken")
db.add(User(
username=username,
email=email or None,
full_name=full_name or None,
hashed_password=bcrypt_context.hash(password),
))
db.commit()
@@ -96,12 +157,49 @@ async def login_for_access_token(
):
user = authenticate_user(username, password, db)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate user.",
)
token = create_access_token(user.username, user.id, timedelta(minutes=60))
return {"access_token": token, "token_type": "bearer"}
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user.")
access = create_access_token(user.username, user.id)
refresh = create_refresh_token(user.username, user.id, db)
return {"access_token": access, "refresh_token": refresh, "token_type": "bearer"}
@router.post("/refresh", response_model=Token)
async def refresh_access_token(
refresh_token: str = Form(...),
db: Session = Depends(get_db),
):
payload = decode_token(refresh_token)
if not payload or payload.get("type") != "refresh":
raise HTTPException(status_code=401, detail="Invalid refresh token")
db_token = db.query(RefreshToken).filter(
RefreshToken.token == refresh_token,
RefreshToken.revoked == False,
).first()
if not db_token:
raise HTTPException(status_code=401, detail="Refresh token revoked or not found")
if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
db_token.revoked = True
db.commit()
raise HTTPException(status_code=401, detail="Refresh token expired")
new_access = create_access_token(payload["sub"], payload["id"])
new_refresh = create_refresh_token(payload["sub"], payload["id"], db)
db_token.revoked = True
db.commit()
return {"access_token": new_access, "refresh_token": new_refresh, "token_type": "bearer"}
def _set_auth_cookies(response: RedirectResponse, access: str, refresh: str) -> RedirectResponse:
response.set_cookie("access_token", access, httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, samesite="lax")
response.set_cookie("refresh_token", refresh, httponly=True, max_age=REFRESH_TOKEN_EXPIRE_DAYS * 86400, samesite="lax")
return response
def _clear_auth_cookies(response: RedirectResponse) -> RedirectResponse:
response.delete_cookie("access_token")
response.delete_cookie("refresh_token")
return response
@router.get("/login", response_class=HTMLResponse)
@@ -119,9 +217,10 @@ async def login_submit(request: Request, db: db_dependency):
return templates.TemplateResponse(request, "auth/login.html", {
"error": "Invalid username or password",
})
token = create_access_token(user.username, user.id, timedelta(minutes=60))
access = create_access_token(user.username, user.id)
refresh = create_refresh_token(user.username, user.id, db)
response = RedirectResponse(url="/", status_code=303)
response.set_cookie(key="access_token", value=token, httponly=True, max_age=3600, samesite="lax")
_set_auth_cookies(response, access, refresh)
return response
@@ -161,19 +260,24 @@ async def register_submit(request: Request, db: db_dependency):
"full_name": full_name,
})
new_user = User(
db.add(User(
username=username,
email=email or None,
full_name=full_name or None,
hashed_password=bcrypt_context.hash(password),
)
db.add(new_user)
))
db.commit()
return RedirectResponse(url="/auth/login?registered=1", status_code=303)
@router.get("/logout")
async def logout():
async def logout(request: Request, db: db_dependency):
refresh_token = request.cookies.get("refresh_token")
if refresh_token:
db_token = db.query(RefreshToken).filter(RefreshToken.token == refresh_token).first()
if db_token:
db_token.revoked = True
db.commit()
response = RedirectResponse(url="/", status_code=303)
response.delete_cookie("access_token")
_clear_auth_cookies(response)
return response