diff --git a/CLAUDE.md b/CLAUDE.md index f59c887..017f5b0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -27,19 +27,21 @@ uvicorn main:app --reload **`main.py`** — FastAPI entry point. Creates the app, includes auth and properties routers, creates DB tables on startup, serves the homepage. -**`auth.py`** — Authentication router (`/auth` prefix). User registration, login (JWT in HttpOnly cookie), logout. `get_current_user()` checks cookie first, then bearer token, returns `None` for anonymous users. +**`auth.py`** — Authentication router (`/auth` prefix). Access-refresh token system: access token (15 min) + refresh token (7 days, stored in DB). `get_current_user()` checks access token cookie/bearer first, auto-refreshes via refresh token if expired. Web login sets both cookies; logout revokes refresh token in DB. -**`config.py`** — Shared config: `SECRET_KEY`, `templates`, `bcrypt_context`, `oauth2_bearer`. Jinja2 cache disabled for Python 3.14 compatibility. +**`config.py`** — Shared config: `SECRET_KEY`, `ALGORITHM`, `ACCESS_TOKEN_EXPIRE_MINUTES` (15), `REFRESH_TOKEN_EXPIRE_DAYS` (7), `templates`, `bcrypt_context`, `oauth2_bearer`. Jinja2 cache disabled for Python 3.14 compatibility. **`database.py`** — SQLAlchemy setup with MySQL engine, `SessionLocal`, `Base`. -**`models.py`** — ORM models: `User` (with email/full_name/phone), `Property` (with contact_email/contact_phone), `PropertyImage`, `Favorite`. +**`models.py`** — ORM models: `User` (with email/full_name/phone), `Property` (with contact_email/contact_phone), `PropertyImage`, `Favorite`, `RefreshToken` (for token rotation/revocation). **`properties.py`** — Property CRUD, search/filter, favorites, dashboard. Static paths (`/properties/new`) must be defined before parameterized paths (`/properties/{prop_id}`). ## Data Flow -Request → Router → dependency injection (`get_current_user`, `get_db`) → SQLAlchemy → Jinja2 template → HTML response. +Request → `RefreshTokenMiddleware` (sets new access cookie if refreshed) → Router → dependency injection (`get_current_user`, `get_db`) → SQLAlchemy → Jinja2 template → HTML response. + +**Auth flow:** Login → issue access (15 min) + refresh (7 days) tokens → both set as HttpOnly cookies. Page load → `get_current_user()` decodes access token; if expired, uses refresh token cookie to get new access token via `try_refresh()`. Logout → revoke refresh token in DB, clear both cookies. **TemplateResponse API:** `templates.TemplateResponse(request, "template.html", {"user": user, ...})` — request is first arg (Starlette 1.x). diff --git a/auth.py b/auth.py index 10bb97b..75be3f0 100644 --- a/auth.py +++ b/auth.py @@ -1,3 +1,4 @@ +import uuid from datetime import datetime, timedelta, timezone from typing import Annotated, Optional @@ -8,9 +9,12 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from starlette import status -from config import SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer +from config import ( + SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer, + ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS, +) from database import SessionLocal -from models import User +from models import User, RefreshToken router = APIRouter( prefix="/auth", @@ -29,15 +33,9 @@ def get_db(): db_dependency = Annotated[Session, Depends(get_db)] -class CreateUserRequest(BaseModel): - username: str - password: str - email: str = "" - full_name: str = "" - - class Token(BaseModel): access_token: str + refresh_token: str token_type: str @@ -50,41 +48,104 @@ def authenticate_user(username: str, password: str, db: Session): return user -def create_access_token(username: str, user_id: int, expires_delta: timedelta): - encode = {"sub": username, "id": user_id} - expires = datetime.now(timezone.utc) + expires_delta - encode.update({"exp": expires}) - return jwt.encode(encode, SECRET_KEY, algorithm=ALGORITHM) +def create_access_token(username: str, user_id: int) -> str: + expires = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + payload = {"sub": username, "id": user_id, "exp": expires, "type": "access"} + return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) + + +def create_refresh_token(username: str, user_id: int, db: Session) -> str: + expires = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + token = jwt.encode( + {"sub": username, "id": user_id, "exp": expires, "type": "refresh", "jti": uuid.uuid4().hex}, + SECRET_KEY, algorithm=ALGORITHM, + ) + db.add(RefreshToken(user_id=user_id, token=token, expires_at=expires)) + db.commit() + return token + + +def decode_token(token: str) -> dict | None: + try: + return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + except JWTError: + return None + + +def _decode_user(payload: dict) -> dict | None: + username = payload.get("sub") + user_id = payload.get("id") + if username is None or user_id is None: + return None + return {"username": username, "id": user_id} async def get_current_user( request: Request, token: Annotated[Optional[str], Depends(oauth2_bearer)], -): - cookie_token = request.cookies.get("access_token") - effective_token = cookie_token or token - if not effective_token: +) -> dict | None: + # Try access token from cookie + access_cookie = request.cookies.get("access_token") + effective = access_cookie or token + if effective: + payload = decode_token(effective) + if payload and payload.get("type") == "access": + return _decode_user(payload) + + # Access token missing/expired — try refresh + refresh_cookie = request.cookies.get("refresh_token") + if refresh_cookie: + new_access = try_refresh(refresh_cookie, request) + if new_access: + payload = decode_token(new_access) + if payload: + return _decode_user(payload) + + return None + + +def try_refresh(refresh_token: str, request: Request) -> str | None: + payload = decode_token(refresh_token) + if not payload or payload.get("type") != "refresh": return None + + db = SessionLocal() try: - payload = jwt.decode(effective_token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") - user_id: int = payload.get("id") - if username is None or user_id is None: + db_token = db.query(RefreshToken).filter( + RefreshToken.token == refresh_token, + RefreshToken.revoked == False, + ).first() + if not db_token: return None - return {"username": username, "id": user_id} - except JWTError: - return None + if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + db_token.revoked = True + db.commit() + return None + + new_access = create_access_token(payload["sub"], payload["id"]) + + # Stash new access token on request state so the caller can set the cookie + request.state.new_access_token = new_access + return new_access + finally: + db.close() @router.post("/", status_code=status.HTTP_201_CREATED) -async def create_user(db: db_dependency, create_user_request: CreateUserRequest): - create_user_model = User( - username=create_user_request.username, - email=create_user_request.email or None, - full_name=create_user_request.full_name or None, - hashed_password=bcrypt_context.hash(create_user_request.password), - ) - db.add(create_user_model) +async def create_user(db: db_dependency, username: str = Form(...), password: str = Form(...), + email: str = Form(""), full_name: str = Form("")): + if not username or not password: + raise HTTPException(status_code=400, detail="Username and password required") + if len(password) < 6: + raise HTTPException(status_code=400, detail="Password must be at least 6 characters") + if db.query(User).filter(User.username == username).first(): + raise HTTPException(status_code=400, detail="Username already taken") + db.add(User( + username=username, + email=email or None, + full_name=full_name or None, + hashed_password=bcrypt_context.hash(password), + )) db.commit() @@ -96,12 +157,49 @@ async def login_for_access_token( ): user = authenticate_user(username, password, db) if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate user.", - ) - token = create_access_token(user.username, user.id, timedelta(minutes=60)) - return {"access_token": token, "token_type": "bearer"} + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user.") + access = create_access_token(user.username, user.id) + refresh = create_refresh_token(user.username, user.id, db) + return {"access_token": access, "refresh_token": refresh, "token_type": "bearer"} + + +@router.post("/refresh", response_model=Token) +async def refresh_access_token( + refresh_token: str = Form(...), + db: Session = Depends(get_db), +): + payload = decode_token(refresh_token) + if not payload or payload.get("type") != "refresh": + raise HTTPException(status_code=401, detail="Invalid refresh token") + + db_token = db.query(RefreshToken).filter( + RefreshToken.token == refresh_token, + RefreshToken.revoked == False, + ).first() + if not db_token: + raise HTTPException(status_code=401, detail="Refresh token revoked or not found") + if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + db_token.revoked = True + db.commit() + raise HTTPException(status_code=401, detail="Refresh token expired") + + new_access = create_access_token(payload["sub"], payload["id"]) + new_refresh = create_refresh_token(payload["sub"], payload["id"], db) + db_token.revoked = True + db.commit() + return {"access_token": new_access, "refresh_token": new_refresh, "token_type": "bearer"} + + +def _set_auth_cookies(response: RedirectResponse, access: str, refresh: str) -> RedirectResponse: + response.set_cookie("access_token", access, httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, samesite="lax") + response.set_cookie("refresh_token", refresh, httponly=True, max_age=REFRESH_TOKEN_EXPIRE_DAYS * 86400, samesite="lax") + return response + + +def _clear_auth_cookies(response: RedirectResponse) -> RedirectResponse: + response.delete_cookie("access_token") + response.delete_cookie("refresh_token") + return response @router.get("/login", response_class=HTMLResponse) @@ -119,9 +217,10 @@ async def login_submit(request: Request, db: db_dependency): return templates.TemplateResponse(request, "auth/login.html", { "error": "Invalid username or password", }) - token = create_access_token(user.username, user.id, timedelta(minutes=60)) + access = create_access_token(user.username, user.id) + refresh = create_refresh_token(user.username, user.id, db) response = RedirectResponse(url="/", status_code=303) - response.set_cookie(key="access_token", value=token, httponly=True, max_age=3600, samesite="lax") + _set_auth_cookies(response, access, refresh) return response @@ -161,19 +260,24 @@ async def register_submit(request: Request, db: db_dependency): "full_name": full_name, }) - new_user = User( + db.add(User( username=username, email=email or None, full_name=full_name or None, hashed_password=bcrypt_context.hash(password), - ) - db.add(new_user) + )) db.commit() return RedirectResponse(url="/auth/login?registered=1", status_code=303) @router.get("/logout") -async def logout(): +async def logout(request: Request, db: db_dependency): + refresh_token = request.cookies.get("refresh_token") + if refresh_token: + db_token = db.query(RefreshToken).filter(RefreshToken.token == refresh_token).first() + if db_token: + db_token.revoked = True + db.commit() response = RedirectResponse(url="/", status_code=303) - response.delete_cookie("access_token") + _clear_auth_cookies(response) return response diff --git a/config.py b/config.py index 0083a46..befc5a2 100644 --- a/config.py +++ b/config.py @@ -6,6 +6,8 @@ import jinja2 SECRET_KEY = os.getenv("SECRET_KEY", "nexhome-dev-secret-key-change-in-production") ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +REFRESH_TOKEN_EXPIRE_DAYS = 7 # Disable Jinja2 cache to work around Python 3.14 compatibility issue _loader = jinja2.FileSystemLoader("templates") diff --git a/main.py b/main.py index eff8443..79a2976 100644 --- a/main.py +++ b/main.py @@ -3,11 +3,12 @@ from typing import Annotated, Optional from fastapi import FastAPI, Depends, Request from sqlalchemy.orm import Session from starlette import status +from starlette.middleware.base import BaseHTTPMiddleware from starlette.staticfiles import StaticFiles import models from auth import get_db, get_current_user, router as auth_router -from config import templates +from config import templates, ACCESS_TOKEN_EXPIRE_MINUTES from models import Property from properties import router as properties_router @@ -19,6 +20,21 @@ app.mount("/static", StaticFiles(directory="static"), name="static") models.Base.metadata.create_all(bind=__import__("database", fromlist=["engine"]).engine) + +class RefreshTokenMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + new_token = getattr(request.state, "new_access_token", None) + if new_token: + response.set_cookie( + "access_token", new_token, + httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, samesite="lax", + ) + return response + +app.add_middleware(RefreshTokenMiddleware) + + db_dependency = Annotated[Session, Depends(get_db)] user_dependency = Annotated[Optional[dict], Depends(get_current_user)] diff --git a/models.py b/models.py index d93884c..c7a1f25 100644 --- a/models.py +++ b/models.py @@ -13,11 +13,11 @@ class User(Base): hashed_password = Column(String(60)) full_name = Column(String(100), nullable=True) phone = Column(String(20), nullable=True) - is_agent = Column(Boolean, default=False) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) properties = relationship("Property", back_populates="owner") favorites = relationship("Favorite", back_populates="user") + refresh_tokens = relationship("RefreshToken", back_populates="user") class Property(Base): @@ -74,3 +74,16 @@ class Favorite(Base): property = relationship("Property", back_populates="favorites") __table_args__ = (UniqueConstraint("user_id", "property_id", name="uq_user_property"),) + + +class RefreshToken(Base): + __tablename__ = "refresh_tokens" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + token = Column(String(500), unique=True, nullable=False, index=True) + expires_at = Column(DateTime, nullable=False) + revoked = Column(Boolean, default=False) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + + user = relationship("User", back_populates="refresh_tokens")