"""Authentication routes mounted under ``/auth``.

Provides JSON token endpoints (``/auth/``, ``/auth/token``, ``/auth/refresh``)
and HTML form flows (``/auth/login``, ``/auth/register``, ``/auth/logout``).
Access and refresh tokens are JWTs signed with ``SECRET_KEY``; refresh tokens
are additionally stored in the ``RefreshToken`` table so they can be revoked.
"""
import uuid
from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional

from fastapi import APIRouter, Depends, HTTPException, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse
from jose import jwt, JWTError
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette import status

from config import (
    SECRET_KEY,
    ALGORITHM,
    templates,
    bcrypt_context,
    oauth2_bearer,
    ACCESS_TOKEN_EXPIRE_MINUTES,
    REFRESH_TOKEN_EXPIRE_DAYS,
)
from database import SessionLocal
from models import User, RefreshToken

router = APIRouter(
    prefix="/auth",
    tags=["auth"],
)

# Shared by the JSON and HTML registration flows so the rule cannot drift.
_MIN_PASSWORD_LENGTH = 6


def get_db():
    """Yield a database session for one request and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


db_dependency = Annotated[Session, Depends(get_db)]


class Token(BaseModel):
    """Response body returned by ``/auth/token`` and ``/auth/refresh``."""

    access_token: str
    refresh_token: str
    token_type: str


def authenticate_user(username: str, password: str, db: Session):
    """Return the ``User`` whose username and password match, else ``False``."""
    user = db.query(User).filter(User.username == username).first()
    if not user:
        return False
    if not bcrypt_context.verify(password, user.hashed_password):
        return False
    return user


def create_access_token(username: str, user_id: int) -> str:
    """Create a signed access JWT valid for ``ACCESS_TOKEN_EXPIRE_MINUTES``."""
    expires = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    payload = {"sub": username, "id": user_id, "exp": expires, "type": "access"}
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)


def create_refresh_token(username: str, user_id: int, db: Session) -> str:
    """Create a signed refresh JWT, persist it as a ``RefreshToken`` row and commit.

    The random ``jti`` claim makes every token string unique, so two tokens
    issued in the same second for the same user never collide in the table.
    Note: the commit also flushes any other pending changes on ``db``.
    """
    expires = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    token = jwt.encode(
        {"sub": username, "id": user_id, "exp": expires, "type": "refresh", "jti": uuid.uuid4().hex},
        SECRET_KEY,
        algorithm=ALGORITHM,
    )
    db.add(RefreshToken(user_id=user_id, token=token, expires_at=expires))
    db.commit()
    return token


def decode_token(token: str) -> dict | None:
    """Verify and decode a JWT; return its claims, or ``None`` if it is invalid.

    python-jose also rejects tokens whose ``exp`` claim has passed.
    """
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None


def _decode_user(payload: dict) -> dict | None:
    """Extract ``{"username", "id"}`` from token claims, or ``None`` if either is missing."""
    username = payload.get("sub")
    user_id = payload.get("id")
    if username is None or user_id is None:
        return None
    return {"username": username, "id": user_id}


def _is_expired(expires_at: datetime) -> bool:
    """Return True if a stored expiry timestamp is in the past.

    Naive values are treated as UTC (they were written from UTC datetimes in
    ``create_refresh_token``); aware values are compared as-is instead of
    having their timezone overwritten.
    """
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    return expires_at < datetime.now(timezone.utc)


async def get_current_user(
    request: Request,
    token: Annotated[Optional[str], Depends(oauth2_bearer)],
) -> dict | None:
    """Resolve the current user as ``{"username", "id"}``, or ``None`` if unauthenticated.

    Tries the ``access_token`` cookie, then the bearer token, then falls back
    to minting a new access token from the ``refresh_token`` cookie.
    """
    # Try both sources: a stale cookie must not shadow a valid bearer header.
    for candidate in (request.cookies.get("access_token"), token):
        if not candidate:
            continue
        payload = decode_token(candidate)
        if payload and payload.get("type") == "access":
            user = _decode_user(payload)
            if user is not None:
                return user

    # No usable access token — try the refresh cookie.
    refresh_cookie = request.cookies.get("refresh_token")
    if refresh_cookie:
        new_access = try_refresh(refresh_cookie, request)
        if new_access:
            payload = decode_token(new_access)
            if payload:
                return _decode_user(payload)
    return None


def try_refresh(refresh_token: str, request: Request) -> str | None:
    """Issue a new access token from a valid, unrevoked refresh token.

    Returns the new access token (also stored on
    ``request.state.new_access_token`` so the caller can set the cookie), or
    ``None`` if the refresh token is invalid, unknown, revoked or expired.
    Uses its own session because it is not called as a FastAPI dependency.
    """
    payload = decode_token(refresh_token)
    if not payload or payload.get("type") != "refresh":
        return None
    claims = _decode_user(payload)
    if claims is None:
        # A signed token without sub/id would otherwise raise KeyError below.
        return None

    db = SessionLocal()
    try:
        db_token = db.query(RefreshToken).filter(
            RefreshToken.token == refresh_token,
            RefreshToken.revoked == False,  # noqa: E712 - SQLAlchemy column expression
        ).first()
        if not db_token:
            return None
        if _is_expired(db_token.expires_at):
            db_token.revoked = True
            db.commit()
            return None
        new_access = create_access_token(claims["username"], claims["id"])
        # Stash new access token on request state so the caller can set the cookie
        request.state.new_access_token = new_access
        return new_access
    finally:
        db.close()


def _add_user(db: Session, username: str, password: str, email: str, full_name: str) -> None:
    """Insert a ``User`` with a hashed password (empty email/full_name stored as NULL) and commit."""
    db.add(User(
        username=username,
        email=email or None,
        full_name=full_name or None,
        hashed_password=bcrypt_context.hash(password),
    ))
    db.commit()


@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_user(
    db: db_dependency,
    username: str = Form(...),
    password: str = Form(...),
    email: str = Form(""),
    full_name: str = Form(""),
):
    """Create a user from form fields (JSON/API flow).

    Raises:
        HTTPException(400): missing fields, short password, or username/email
            already in use (same rules as the HTML register form).
    """
    if not username or not password:
        raise HTTPException(status_code=400, detail="Username and password required")
    if len(password) < _MIN_PASSWORD_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Password must be at least {_MIN_PASSWORD_LENGTH} characters",
        )
    if db.query(User).filter(User.username == username).first():
        raise HTTPException(status_code=400, detail="Username already taken")
    if email and db.query(User).filter(User.email == email).first():
        raise HTTPException(status_code=400, detail="Email already registered")
    _add_user(db, username, password, email, full_name)


@router.post("/token", response_model=Token)
async def login_for_access_token(
    username: str = Form(...),
    password: str = Form(...),
    db: Session = Depends(get_db),
):
    """Exchange username/password for a new access + refresh token pair.

    Raises:
        HTTPException(401): credentials do not match a user.
    """
    user = authenticate_user(username, password, db)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user.")
    access = create_access_token(user.username, user.id)
    refresh = create_refresh_token(user.username, user.id, db)
    return {"access_token": access, "refresh_token": refresh, "token_type": "bearer"}


@router.post("/refresh", response_model=Token)
async def refresh_access_token(
    refresh_token: str = Form(...),
    db: Session = Depends(get_db),
):
    """Rotate a refresh token: revoke it and return a new access + refresh pair.

    Raises:
        HTTPException(401): token is malformed, unknown, revoked or expired.
    """
    payload = decode_token(refresh_token)
    if not payload or payload.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    claims = _decode_user(payload)
    if claims is None:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    db_token = db.query(RefreshToken).filter(
        RefreshToken.token == refresh_token,
        RefreshToken.revoked == False,  # noqa: E712 - SQLAlchemy column expression
    ).first()
    if not db_token:
        raise HTTPException(status_code=401, detail="Refresh token revoked or not found")
    if _is_expired(db_token.expires_at):
        db_token.revoked = True
        db.commit()
        raise HTTPException(status_code=401, detail="Refresh token expired")

    # Revoke the old token *before* issuing the new one: create_refresh_token's
    # commit then persists the revocation and the new row together, so a failed
    # second commit can no longer leave both tokens valid.
    db_token.revoked = True
    new_access = create_access_token(claims["username"], claims["id"])
    new_refresh = create_refresh_token(claims["username"], claims["id"], db)
    return {"access_token": new_access, "refresh_token": new_refresh, "token_type": "bearer"}


def _set_auth_cookies(response: RedirectResponse, access: str, refresh: str) -> RedirectResponse:
    """Set httponly ``access_token``/``refresh_token`` cookies whose max-age matches token lifetime."""
    response.set_cookie("access_token", access, httponly=True,
                        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, samesite="lax")
    response.set_cookie("refresh_token", refresh, httponly=True,
                        max_age=REFRESH_TOKEN_EXPIRE_DAYS * 86400, samesite="lax")
    return response


def _clear_auth_cookies(response: RedirectResponse) -> RedirectResponse:
    """Delete both auth cookies from the response."""
    response.delete_cookie("access_token")
    response.delete_cookie("refresh_token")
    return response


@router.get("/login", response_class=HTMLResponse)
async def login_page(request: Request):
    """Render the login form."""
    return templates.TemplateResponse(request, "auth/login.html", {})


@router.post("/login")
async def login_submit(request: Request, db: db_dependency):
    """Handle the login form: set auth cookies and redirect to ``/`` on success.

    On bad credentials the login page is re-rendered with an error message.
    """
    form = await request.form()
    username = form.get("username", "")
    password = form.get("password", "")

    user = authenticate_user(username, password, db)
    if not user:
        return templates.TemplateResponse(request, "auth/login.html", {
            "error": "Invalid username or password",
        })

    access = create_access_token(user.username, user.id)
    refresh = create_refresh_token(user.username, user.id, db)
    response = RedirectResponse(url="/", status_code=303)
    _set_auth_cookies(response, access, refresh)
    return response


@router.get("/register", response_class=HTMLResponse)
async def register_page(request: Request):
    """Render the registration form."""
    return templates.TemplateResponse(request, "auth/register.html", {})


@router.post("/register")
async def register_submit(request: Request, db: db_dependency):
    """Handle the registration form.

    Collects all validation errors and re-renders the form (preserving the
    entered username/email/full name) if any; otherwise creates the user and
    redirects to ``/auth/login?registered=1``.
    """
    form = await request.form()
    username = form.get("username", "").strip()
    email = form.get("email", "").strip()
    full_name = form.get("full_name", "").strip()
    password = form.get("password", "")
    confirm = form.get("confirm_password", "")

    errors = []
    if not username:
        errors.append("Username is required.")
    if not password:
        errors.append("Password is required.")
    if password != confirm:
        errors.append("Passwords do not match.")
    # Only report the length rule for a non-empty password; an empty one
    # already got "Password is required.".
    if password and len(password) < _MIN_PASSWORD_LENGTH:
        errors.append(f"Password must be at least {_MIN_PASSWORD_LENGTH} characters.")
    if username and db.query(User).filter(User.username == username).first():
        errors.append("Username already taken.")
    if email and db.query(User).filter(User.email == email).first():
        errors.append("Email already registered.")

    if errors:
        return templates.TemplateResponse(request, "auth/register.html", {
            "errors": errors,
            "username": username,
            "email": email,
            "full_name": full_name,
        })

    _add_user(db, username, password, email, full_name)
    return RedirectResponse(url="/auth/login?registered=1", status_code=303)


@router.get("/logout")
async def logout(request: Request, db: db_dependency):
    """Revoke the refresh token from the cookie (if known), clear cookies, redirect to ``/``."""
    refresh_token = request.cookies.get("refresh_token")
    if refresh_token:
        db_token = db.query(RefreshToken).filter(RefreshToken.token == refresh_token).first()
        if db_token:
            db_token.revoked = True
            db.commit()
    response = RedirectResponse(url="/", status_code=303)
    _clear_auth_cookies(response)
    return response