✏️

- feature access/refresh tokens auth
This commit is contained in:
2026-06-11 15:59:29 +08:00
parent ea8e41e688
commit 3f386e5e38
5 changed files with 189 additions and 52 deletions
+6 -4
View File
@@ -27,19 +27,21 @@ uvicorn main:app --reload
**`main.py`** — FastAPI entry point. Creates the app, includes auth and properties routers, creates DB tables on startup, serves the homepage. **`main.py`** — FastAPI entry point. Creates the app, includes auth and properties routers, creates DB tables on startup, serves the homepage.
**`auth.py`** — Authentication router (`/auth` prefix). User registration, login (JWT in HttpOnly cookie), logout. `get_current_user()` checks cookie first, then bearer token, returns `None` for anonymous users. **`auth.py`** — Authentication router (`/auth` prefix). Access-refresh token system: access token (15 min) + refresh token (7 days, stored in DB). `get_current_user()` checks access token cookie/bearer first, auto-refreshes via refresh token if expired. Web login sets both cookies; logout revokes refresh token in DB.
**`config.py`** — Shared config: `SECRET_KEY`, `templates`, `bcrypt_context`, `oauth2_bearer`. Jinja2 cache disabled for Python 3.14 compatibility. **`config.py`** — Shared config: `SECRET_KEY`, `ALGORITHM`, `ACCESS_TOKEN_EXPIRE_MINUTES` (15), `REFRESH_TOKEN_EXPIRE_DAYS` (7), `templates`, `bcrypt_context`, `oauth2_bearer`. Jinja2 cache disabled for Python 3.14 compatibility.
**`database.py`** — SQLAlchemy setup with MySQL engine, `SessionLocal`, `Base`. **`database.py`** — SQLAlchemy setup with MySQL engine, `SessionLocal`, `Base`.
**`models.py`** — ORM models: `User` (with email/full_name/phone), `Property` (with contact_email/contact_phone), `PropertyImage`, `Favorite`. **`models.py`** — ORM models: `User` (with email/full_name/phone), `Property` (with contact_email/contact_phone), `PropertyImage`, `Favorite`, `RefreshToken` (for token rotation/revocation).
**`properties.py`** — Property CRUD, search/filter, favorites, dashboard. Static paths (`/properties/new`) must be defined before parameterized paths (`/properties/{prop_id}`). **`properties.py`** — Property CRUD, search/filter, favorites, dashboard. Static paths (`/properties/new`) must be defined before parameterized paths (`/properties/{prop_id}`).
## Data Flow ## Data Flow
Request → Router → dependency injection (`get_current_user`, `get_db`) → SQLAlchemy → Jinja2 template → HTML response. Request → `RefreshTokenMiddleware` (sets new access cookie if refreshed) → Router → dependency injection (`get_current_user`, `get_db`) → SQLAlchemy → Jinja2 template → HTML response.
**Auth flow:** Login → issue access (15 min) + refresh (7 days) tokens → both set as HttpOnly cookies. Page load → `get_current_user()` decodes access token; if expired, uses refresh token cookie to get new access token via `try_refresh()`. Logout → revoke refresh token in DB, clear both cookies.
**TemplateResponse API:** `templates.TemplateResponse(request, "template.html", {"user": user, ...})` — request is first arg (Starlette 1.x). **TemplateResponse API:** `templates.TemplateResponse(request, "template.html", {"user": user, ...})` — request is first arg (Starlette 1.x).
+149 -45
View File
@@ -1,3 +1,4 @@
import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional from typing import Annotated, Optional
@@ -8,9 +9,12 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from starlette import status from starlette import status
from config import SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer from config import (
SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer,
ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS,
)
from database import SessionLocal from database import SessionLocal
from models import User from models import User, RefreshToken
router = APIRouter( router = APIRouter(
prefix="/auth", prefix="/auth",
@@ -29,15 +33,9 @@ def get_db():
db_dependency = Annotated[Session, Depends(get_db)] db_dependency = Annotated[Session, Depends(get_db)]
class CreateUserRequest(BaseModel):
username: str
password: str
email: str = ""
full_name: str = ""
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
refresh_token: str
token_type: str token_type: str
@@ -50,41 +48,104 @@ def authenticate_user(username: str, password: str, db: Session):
return user return user
def create_access_token(username: str, user_id: int, expires_delta: timedelta): def create_access_token(username: str, user_id: int) -> str:
encode = {"sub": username, "id": user_id} expires = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expires = datetime.now(timezone.utc) + expires_delta payload = {"sub": username, "id": user_id, "exp": expires, "type": "access"}
encode.update({"exp": expires}) return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
return jwt.encode(encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(username: str, user_id: int, db: Session) -> str:
expires = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
token = jwt.encode(
{"sub": username, "id": user_id, "exp": expires, "type": "refresh", "jti": uuid.uuid4().hex},
SECRET_KEY, algorithm=ALGORITHM,
)
db.add(RefreshToken(user_id=user_id, token=token, expires_at=expires))
db.commit()
return token
def decode_token(token: str) -> dict | None:
try:
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
except JWTError:
return None
def _decode_user(payload: dict) -> dict | None:
username = payload.get("sub")
user_id = payload.get("id")
if username is None or user_id is None:
return None
return {"username": username, "id": user_id}
async def get_current_user( async def get_current_user(
request: Request, request: Request,
token: Annotated[Optional[str], Depends(oauth2_bearer)], token: Annotated[Optional[str], Depends(oauth2_bearer)],
): ) -> dict | None:
cookie_token = request.cookies.get("access_token") # Try access token from cookie
effective_token = cookie_token or token access_cookie = request.cookies.get("access_token")
if not effective_token: effective = access_cookie or token
if effective:
payload = decode_token(effective)
if payload and payload.get("type") == "access":
return _decode_user(payload)
# Access token missing/expired — try refresh
refresh_cookie = request.cookies.get("refresh_token")
if refresh_cookie:
new_access = try_refresh(refresh_cookie, request)
if new_access:
payload = decode_token(new_access)
if payload:
return _decode_user(payload)
return None return None
def try_refresh(refresh_token: str, request: Request) -> str | None:
payload = decode_token(refresh_token)
if not payload or payload.get("type") != "refresh":
return None
db = SessionLocal()
try: try:
payload = jwt.decode(effective_token, SECRET_KEY, algorithms=[ALGORITHM]) db_token = db.query(RefreshToken).filter(
username: str = payload.get("sub") RefreshToken.token == refresh_token,
user_id: int = payload.get("id") RefreshToken.revoked == False,
if username is None or user_id is None: ).first()
if not db_token:
return None return None
return {"username": username, "id": user_id} if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
except JWTError: db_token.revoked = True
db.commit()
return None return None
new_access = create_access_token(payload["sub"], payload["id"])
# Stash new access token on request state so the caller can set the cookie
request.state.new_access_token = new_access
return new_access
finally:
db.close()
@router.post("/", status_code=status.HTTP_201_CREATED) @router.post("/", status_code=status.HTTP_201_CREATED)
async def create_user(db: db_dependency, create_user_request: CreateUserRequest): async def create_user(db: db_dependency, username: str = Form(...), password: str = Form(...),
create_user_model = User( email: str = Form(""), full_name: str = Form("")):
username=create_user_request.username, if not username or not password:
email=create_user_request.email or None, raise HTTPException(status_code=400, detail="Username and password required")
full_name=create_user_request.full_name or None, if len(password) < 6:
hashed_password=bcrypt_context.hash(create_user_request.password), raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
) if db.query(User).filter(User.username == username).first():
db.add(create_user_model) raise HTTPException(status_code=400, detail="Username already taken")
db.add(User(
username=username,
email=email or None,
full_name=full_name or None,
hashed_password=bcrypt_context.hash(password),
))
db.commit() db.commit()
@@ -96,12 +157,49 @@ async def login_for_access_token(
): ):
user = authenticate_user(username, password, db) user = authenticate_user(username, password, db)
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user.")
status_code=status.HTTP_401_UNAUTHORIZED, access = create_access_token(user.username, user.id)
detail="Could not validate user.", refresh = create_refresh_token(user.username, user.id, db)
) return {"access_token": access, "refresh_token": refresh, "token_type": "bearer"}
token = create_access_token(user.username, user.id, timedelta(minutes=60))
return {"access_token": token, "token_type": "bearer"}
@router.post("/refresh", response_model=Token)
async def refresh_access_token(
refresh_token: str = Form(...),
db: Session = Depends(get_db),
):
payload = decode_token(refresh_token)
if not payload or payload.get("type") != "refresh":
raise HTTPException(status_code=401, detail="Invalid refresh token")
db_token = db.query(RefreshToken).filter(
RefreshToken.token == refresh_token,
RefreshToken.revoked == False,
).first()
if not db_token:
raise HTTPException(status_code=401, detail="Refresh token revoked or not found")
if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
db_token.revoked = True
db.commit()
raise HTTPException(status_code=401, detail="Refresh token expired")
new_access = create_access_token(payload["sub"], payload["id"])
new_refresh = create_refresh_token(payload["sub"], payload["id"], db)
db_token.revoked = True
db.commit()
return {"access_token": new_access, "refresh_token": new_refresh, "token_type": "bearer"}
def _set_auth_cookies(response: RedirectResponse, access: str, refresh: str) -> RedirectResponse:
response.set_cookie("access_token", access, httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, samesite="lax")
response.set_cookie("refresh_token", refresh, httponly=True, max_age=REFRESH_TOKEN_EXPIRE_DAYS * 86400, samesite="lax")
return response
def _clear_auth_cookies(response: RedirectResponse) -> RedirectResponse:
response.delete_cookie("access_token")
response.delete_cookie("refresh_token")
return response
@router.get("/login", response_class=HTMLResponse) @router.get("/login", response_class=HTMLResponse)
@@ -119,9 +217,10 @@ async def login_submit(request: Request, db: db_dependency):
return templates.TemplateResponse(request, "auth/login.html", { return templates.TemplateResponse(request, "auth/login.html", {
"error": "Invalid username or password", "error": "Invalid username or password",
}) })
token = create_access_token(user.username, user.id, timedelta(minutes=60)) access = create_access_token(user.username, user.id)
refresh = create_refresh_token(user.username, user.id, db)
response = RedirectResponse(url="/", status_code=303) response = RedirectResponse(url="/", status_code=303)
response.set_cookie(key="access_token", value=token, httponly=True, max_age=3600, samesite="lax") _set_auth_cookies(response, access, refresh)
return response return response
@@ -161,19 +260,24 @@ async def register_submit(request: Request, db: db_dependency):
"full_name": full_name, "full_name": full_name,
}) })
new_user = User( db.add(User(
username=username, username=username,
email=email or None, email=email or None,
full_name=full_name or None, full_name=full_name or None,
hashed_password=bcrypt_context.hash(password), hashed_password=bcrypt_context.hash(password),
) ))
db.add(new_user)
db.commit() db.commit()
return RedirectResponse(url="/auth/login?registered=1", status_code=303) return RedirectResponse(url="/auth/login?registered=1", status_code=303)
@router.get("/logout") @router.get("/logout")
async def logout(): async def logout(request: Request, db: db_dependency):
refresh_token = request.cookies.get("refresh_token")
if refresh_token:
db_token = db.query(RefreshToken).filter(RefreshToken.token == refresh_token).first()
if db_token:
db_token.revoked = True
db.commit()
response = RedirectResponse(url="/", status_code=303) response = RedirectResponse(url="/", status_code=303)
response.delete_cookie("access_token") _clear_auth_cookies(response)
return response return response
+2
View File
@@ -6,6 +6,8 @@ import jinja2
SECRET_KEY = os.getenv("SECRET_KEY", "nexhome-dev-secret-key-change-in-production") SECRET_KEY = os.getenv("SECRET_KEY", "nexhome-dev-secret-key-change-in-production")
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
# Disable Jinja2 cache to work around Python 3.14 compatibility issue # Disable Jinja2 cache to work around Python 3.14 compatibility issue
_loader = jinja2.FileSystemLoader("templates") _loader = jinja2.FileSystemLoader("templates")
+17 -1
View File
@@ -3,11 +3,12 @@ from typing import Annotated, Optional
from fastapi import FastAPI, Depends, Request from fastapi import FastAPI, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from starlette import status from starlette import status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.staticfiles import StaticFiles from starlette.staticfiles import StaticFiles
import models import models
from auth import get_db, get_current_user, router as auth_router from auth import get_db, get_current_user, router as auth_router
from config import templates from config import templates, ACCESS_TOKEN_EXPIRE_MINUTES
from models import Property from models import Property
from properties import router as properties_router from properties import router as properties_router
@@ -19,6 +20,21 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
models.Base.metadata.create_all(bind=__import__("database", fromlist=["engine"]).engine) models.Base.metadata.create_all(bind=__import__("database", fromlist=["engine"]).engine)
class RefreshTokenMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
new_token = getattr(request.state, "new_access_token", None)
if new_token:
response.set_cookie(
"access_token", new_token,
httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, samesite="lax",
)
return response
app.add_middleware(RefreshTokenMiddleware)
db_dependency = Annotated[Session, Depends(get_db)] db_dependency = Annotated[Session, Depends(get_db)]
user_dependency = Annotated[Optional[dict], Depends(get_current_user)] user_dependency = Annotated[Optional[dict], Depends(get_current_user)]
+14 -1
View File
@@ -13,11 +13,11 @@ class User(Base):
hashed_password = Column(String(60)) hashed_password = Column(String(60))
full_name = Column(String(100), nullable=True) full_name = Column(String(100), nullable=True)
phone = Column(String(20), nullable=True) phone = Column(String(20), nullable=True)
is_agent = Column(Boolean, default=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
properties = relationship("Property", back_populates="owner") properties = relationship("Property", back_populates="owner")
favorites = relationship("Favorite", back_populates="user") favorites = relationship("Favorite", back_populates="user")
refresh_tokens = relationship("RefreshToken", back_populates="user")
class Property(Base): class Property(Base):
@@ -74,3 +74,16 @@ class Favorite(Base):
property = relationship("Property", back_populates="favorites") property = relationship("Property", back_populates="favorites")
__table_args__ = (UniqueConstraint("user_id", "property_id", name="uq_user_property"),) __table_args__ = (UniqueConstraint("user_id", "property_id", name="uq_user_property"),)
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
token = Column(String(500), unique=True, nullable=False, index=True)
expires_at = Column(DateTime, nullable=False)
revoked = Column(Boolean, default=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
user = relationship("User", back_populates="refresh_tokens")