3f386e5e38
- feature access/refresh tokens auth
284 lines
9.5 KiB
Python
284 lines
9.5 KiB
Python
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Annotated, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, Form
|
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
|
from jose import jwt, JWTError
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
from starlette import status
|
|
|
|
from config import (
|
|
SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer,
|
|
ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS,
|
|
)
|
|
from database import SessionLocal
|
|
from models import User, RefreshToken
|
|
|
|
# Every endpoint in this module is mounted under the /auth prefix and tagged "auth".
router = APIRouter(prefix="/auth", tags=["auth"])
|
|
|
|
|
|
def get_db():
    """Yield a fresh ``SessionLocal`` session and close it once the caller is done.

    Used as a FastAPI dependency (see ``db_dependency`` below), so the session
    is closed after the endpoint that requested it finishes, even on error.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()


# Annotated alias so endpoints can simply declare ``db: db_dependency``.
db_dependency = Annotated[Session, Depends(get_db)]
|
|
|
|
|
|
class Token(BaseModel):
    """Response model of the ``/auth/token`` and ``/auth/refresh`` endpoints."""

    # Short-lived JWT whose "type" claim is "access".
    access_token: str
    # Longer-lived JWT whose "type" claim is "refresh"; also persisted in RefreshToken.
    refresh_token: str
    # Always "bearer" in this module.
    token_type: str
|
|
|
|
|
|
def authenticate_user(username: str, password: str, db: Session):
    """Return the matching ``User`` if *password* is correct, otherwise ``False``.

    An unknown username and a wrong password both yield ``False``.

    NOTE(review): an unknown username returns before any bcrypt verification,
    so response timing may reveal whether a username exists.
    """
    user = db.query(User).filter(User.username == username).first()
    if user and bcrypt_context.verify(password, user.hashed_password):
        return user
    return False
|
|
|
|
|
|
def create_access_token(username: str, user_id: int) -> str:
    """Sign and return an access JWT for the given user.

    Claims: ``sub`` (username), ``id`` (user id), ``exp`` (now +
    ``ACCESS_TOKEN_EXPIRE_MINUTES``, UTC) and ``type="access"``.
    """
    claims = {
        "sub": username,
        "id": user_id,
        "exp": datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        "type": "access",
    }
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(username: str, user_id: int, db: Session) -> str:
    """Sign a refresh JWT, persist it as a ``RefreshToken`` row, and return it.

    Side effect: adds the row and **commits** *db*. Any other pending changes
    in the session are committed along with it.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    claims = {
        "sub": username,
        "id": user_id,
        "exp": expires_at,
        "type": "refresh",
        # Random id keeps two tokens for the same user/expiry distinct.
        "jti": uuid.uuid4().hex,
    }
    encoded = jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)

    db.add(RefreshToken(user_id=user_id, token=encoded, expires_at=expires_at))
    db.commit()
    return encoded
|
|
|
|
|
|
def decode_token(token: str) -> dict | None:
    """Verify *token* with ``SECRET_KEY``/``ALGORITHM`` and return its claims.

    Returns ``None`` instead of raising when python-jose reports any
    ``JWTError``, such as a bad signature or an expired token.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    return claims
|
|
|
|
|
|
def _decode_user(payload: dict) -> dict | None:
|
|
username = payload.get("sub")
|
|
user_id = payload.get("id")
|
|
if username is None or user_id is None:
|
|
return None
|
|
return {"username": username, "id": user_id}
|
|
|
|
|
|
async def get_current_user(
    request: Request,
    token: Annotated[Optional[str], Depends(oauth2_bearer)],
) -> dict | None:
    """Resolve the authenticated user for *request*, or return ``None``.

    Resolution order:

    1. The ``access_token`` cookie, then the bearer *token* supplied by
       ``oauth2_bearer``. The first one that decodes as a valid ``"access"``
       token with ``sub``/``id`` claims wins.
    2. Otherwise, if a ``refresh_token`` cookie is present, ``try_refresh``
       mints a new access token. It also stores that token on
       ``request.state.new_access_token`` so the caller can set the cookie.

    *token* is typed ``Optional``, presumably because ``oauth2_bearer`` is
    configured with ``auto_error=False``. TODO: confirm in config.

    Returns:
        ``{"username": ..., "id": ...}`` for an authenticated user, else ``None``.
    """
    # Try every candidate instead of only the first non-empty one. Previously
    # ``cookie or token`` meant a stale or invalid cookie hid a valid
    # Authorization header.
    for candidate in (request.cookies.get("access_token"), token):
        if not candidate:
            continue
        payload = decode_token(candidate)
        if payload and payload.get("type") == "access":
            user = _decode_user(payload)
            if user is not None:
                return user

    # No usable access token: fall back to the refresh-token cookie.
    # NOTE(review): try_refresh does blocking DB I/O inside this async function.
    refresh_cookie = request.cookies.get("refresh_token")
    if refresh_cookie:
        new_access = try_refresh(refresh_cookie, request)
        if new_access:
            payload = decode_token(new_access)
            if payload:
                return _decode_user(payload)

    return None
|
|
|
|
|
|
def try_refresh(refresh_token: str, request: Request) -> str | None:
    """Mint a new access token from a stored, unrevoked refresh token.

    Returns ``None`` in any of these cases:

    - the JWT does not decode;
    - its ``type`` is not ``"refresh"``;
    - no matching unrevoked ``RefreshToken`` row exists;
    - the row's ``expires_at`` has passed. In that case the row is also
      marked revoked and committed.

    On success the new access token is returned and also stored on
    ``request.state.new_access_token`` for the caller to turn into a cookie.
    This function opens and closes its own ``SessionLocal`` session.
    """
    claims = decode_token(refresh_token)
    if not claims or claims.get("type") != "refresh":
        return None

    session = SessionLocal()
    try:
        stored = (
            session.query(RefreshToken)
            .filter(
                RefreshToken.token == refresh_token,
                RefreshToken.revoked == False,  # noqa: E712 - SQLAlchemy column expression
            )
            .first()
        )
        if not stored:
            return None

        # expires_at is treated as naive UTC (tzinfo is overwritten). TODO: confirm the column type.
        if stored.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
            stored.revoked = True
            session.commit()
            return None

        access = create_access_token(claims["sub"], claims["id"])
        request.state.new_access_token = access
        return access
    finally:
        session.close()
|
|
|
|
|
|
@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_user(db: db_dependency, username: str = Form(...), password: str = Form(...),
                      email: str = Form(""), full_name: str = Form("")):
    """Create a user from form fields (the API counterpart of ``register_submit``).

    Username, e-mail and full name are whitespace-stripped exactly as the HTML
    registration form does, so both paths store the same normalized values.
    Empty ``email``/``full_name`` are stored as ``None``.

    Raises:
        HTTPException: 400 if the username or password is missing, the
            password is shorter than 6 characters, or the username or e-mail
            is already in use.
    """
    username = username.strip()
    email = email.strip()
    full_name = full_name.strip()

    if not username or not password:
        raise HTTPException(status_code=400, detail="Username and password required")
    if len(password) < 6:
        raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
    if db.query(User).filter(User.username == username).first():
        raise HTTPException(status_code=400, detail="Username already taken")
    # register_submit rejects duplicate e-mails. Without the same check here,
    # this path could store a duplicate or fail at commit, depending on whether
    # User.email is unique (not visible here).
    if email and db.query(User).filter(User.email == email).first():
        raise HTTPException(status_code=400, detail="Email already registered")

    db.add(User(
        username=username,
        email=email or None,
        full_name=full_name or None,
        hashed_password=bcrypt_context.hash(password),
    ))
    db.commit()
|
|
|
|
|
|
@router.post("/token", response_model=Token)
async def login_for_access_token(
    username: str = Form(...),
    password: str = Form(...),
    db: Session = Depends(get_db),
):
    """Exchange username/password form fields for a new access/refresh token pair.

    Issuing the refresh token also persists and commits a ``RefreshToken`` row.

    Raises:
        HTTPException: 401 when the credentials do not match a user.
    """
    user = authenticate_user(username, password, db)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate user.",
        )
    return {
        "access_token": create_access_token(user.username, user.id),
        "refresh_token": create_refresh_token(user.username, user.id, db),
        "token_type": "bearer",
    }
|
|
|
|
|
|
@router.post("/refresh", response_model=Token)
async def refresh_access_token(
    refresh_token: str = Form(...),
    db: Session = Depends(get_db),
):
    """Rotate a refresh token by returning a new access/refresh pair and revoking the old one.

    Raises:
        HTTPException: 401 if the token does not decode as a ``"refresh"``
            JWT, is revoked or unknown, or its stored ``expires_at`` has
            passed. An expired token is also marked revoked.
    """
    payload = decode_token(refresh_token)
    if not payload or payload.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    db_token = db.query(RefreshToken).filter(
        RefreshToken.token == refresh_token,
        RefreshToken.revoked == False,  # noqa: E712 - SQLAlchemy column expression
    ).first()
    if not db_token:
        raise HTTPException(status_code=401, detail="Refresh token revoked or not found")

    # expires_at is treated as naive UTC (tzinfo is overwritten). TODO: confirm the column type.
    if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
        db_token.revoked = True
        db.commit()
        raise HTTPException(status_code=401, detail="Refresh token expired")

    # Revoke the presented token *before* issuing its replacement. The
    # revocation and the new RefreshToken row are then persisted by the single
    # commit inside create_refresh_token. Previously the new token was committed
    # first and the revocation separately, so a failure in between left both
    # tokens valid.
    # NOTE(review): two concurrent requests can still both pass the query above;
    # closing that race needs row locking (e.g. SELECT ... FOR UPDATE).
    db_token.revoked = True
    new_access = create_access_token(payload["sub"], payload["id"])
    new_refresh = create_refresh_token(payload["sub"], payload["id"], db)

    return {"access_token": new_access, "refresh_token": new_refresh, "token_type": "bearer"}
|
|
|
|
|
|
def _set_auth_cookies(response: RedirectResponse, access: str, refresh: str) -> RedirectResponse:
    """Attach HttpOnly, SameSite=Lax access/refresh cookies to *response* and return it.

    Each cookie's ``max_age`` matches its token's configured lifetime, in seconds.
    NOTE(review): ``secure`` is not set, so the cookies are also sent over plain HTTP.
    """
    cookie_specs = (
        ("access_token", access, ACCESS_TOKEN_EXPIRE_MINUTES * 60),
        ("refresh_token", refresh, REFRESH_TOKEN_EXPIRE_DAYS * 86400),
    )
    for cookie_name, cookie_value, lifetime_seconds in cookie_specs:
        response.set_cookie(
            cookie_name,
            cookie_value,
            httponly=True,
            max_age=lifetime_seconds,
            samesite="lax",
        )
    return response
|
|
|
|
|
|
def _clear_auth_cookies(response: RedirectResponse) -> RedirectResponse:
    """Delete the access and refresh cookies on *response* and return it."""
    for cookie_name in ("access_token", "refresh_token"):
        response.delete_cookie(cookie_name)
    return response
|
|
|
|
|
|
@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Render the login form with an empty template context."""
    context: dict = {}
    return templates.TemplateResponse(request, "auth/login.html", context)
|
|
|
|
|
|
@router.post("/login")
async def login_submit(request: Request, db: db_dependency):
    """Handle the HTML login form.

    On bad credentials, re-render ``auth/login.html`` with an ``error`` message.
    On success, issue a token pair (persisting the refresh token), store both
    tokens as cookies, and redirect to ``/`` with status 303.
    """
    form = await request.form()
    user = authenticate_user(form.get("username", ""), form.get("password", ""), db)
    if not user:
        return templates.TemplateResponse(
            request,
            "auth/login.html",
            {"error": "Invalid username or password"},
        )

    access = create_access_token(user.username, user.id)
    refresh = create_refresh_token(user.username, user.id, db)
    return _set_auth_cookies(RedirectResponse(url="/", status_code=303), access, refresh)
|
|
|
|
|
|
@router.get("/register", response_class=HTMLResponse)
async def register_page(request: Request):
    """Render the registration form with an empty template context."""
    context: dict = {}
    return templates.TemplateResponse(request, "auth/register.html", context)
|
|
|
|
|
|
@router.post("/register")
async def register_submit(request: Request, db: db_dependency):
    """Handle the HTML registration form.

    Username, e-mail and full name are whitespace-stripped; the password is
    used as typed. On validation errors, re-render ``auth/register.html`` with
    the list of ``errors`` and the previously entered (non-secret) fields. On
    success, create the user and redirect (303) to ``/auth/login?registered=1``.
    """
    form = await request.form()
    username = form.get("username", "").strip()
    email = form.get("email", "").strip()
    full_name = form.get("full_name", "").strip()
    password = form.get("password", "")
    confirm = form.get("confirm_password", "")

    errors = []
    if not username:
        errors.append("Username is required.")
    if not password:
        errors.append("Password is required.")
    if password != confirm:
        errors.append("Passwords do not match.")
    # Only report the length rule for a non-empty password. An empty one
    # already produced "Password is required." and should not be reported twice.
    if password and len(password) < 6:
        errors.append("Password must be at least 6 characters.")
    # Skip the lookup for an empty username; it is already reported above.
    if username and db.query(User).filter(User.username == username).first():
        errors.append("Username already taken.")
    if email and db.query(User).filter(User.email == email).first():
        errors.append("Email already registered.")

    if errors:
        return templates.TemplateResponse(request, "auth/register.html", {
            "errors": errors,
            "username": username,
            "email": email,
            "full_name": full_name,
        })

    db.add(User(
        username=username,
        email=email or None,
        full_name=full_name or None,
        hashed_password=bcrypt_context.hash(password),
    ))
    db.commit()
    return RedirectResponse(url="/auth/login?registered=1", status_code=303)
|
|
|
|
|
|
@router.get("/logout")
async def logout(request: Request, db: db_dependency):
    """Revoke the caller's refresh token, if it is stored, then clear both auth cookies.

    Always redirects (303) to ``/``, whether or not a token was found.
    """
    presented = request.cookies.get("refresh_token")
    stored = (
        db.query(RefreshToken).filter(RefreshToken.token == presented).first()
        if presented
        else None
    )
    if stored:
        stored.revoked = True
        db.commit()

    return _clear_auth_cookies(RedirectResponse(url="/", status_code=303))
|