♑✏️
- feature access/refresh tokens auth
This commit is contained in:
@@ -27,19 +27,21 @@ uvicorn main:app --reload
|
|||||||
|
|
||||||
**`main.py`** — FastAPI entry point. Creates the app, includes auth and properties routers, creates DB tables on startup, serves the homepage.
|
**`main.py`** — FastAPI entry point. Creates the app, includes auth and properties routers, creates DB tables on startup, serves the homepage.
|
||||||
|
|
||||||
**`auth.py`** — Authentication router (`/auth` prefix). User registration, login (JWT in HttpOnly cookie), logout. `get_current_user()` checks cookie first, then bearer token, returns `None` for anonymous users.
|
**`auth.py`** — Authentication router (`/auth` prefix). Access-refresh token system: access token (15 min) + refresh token (7 days, stored in DB). `get_current_user()` checks access token cookie/bearer first, auto-refreshes via refresh token if expired. Web login sets both cookies; logout revokes refresh token in DB.
|
||||||
|
|
||||||
**`config.py`** — Shared config: `SECRET_KEY`, `templates`, `bcrypt_context`, `oauth2_bearer`. Jinja2 cache disabled for Python 3.14 compatibility.
|
**`config.py`** — Shared config: `SECRET_KEY`, `ALGORITHM`, `ACCESS_TOKEN_EXPIRE_MINUTES` (15), `REFRESH_TOKEN_EXPIRE_DAYS` (7), `templates`, `bcrypt_context`, `oauth2_bearer`. Jinja2 cache disabled for Python 3.14 compatibility.
|
||||||
|
|
||||||
**`database.py`** — SQLAlchemy setup with MySQL engine, `SessionLocal`, `Base`.
|
**`database.py`** — SQLAlchemy setup with MySQL engine, `SessionLocal`, `Base`.
|
||||||
|
|
||||||
**`models.py`** — ORM models: `User` (with email/full_name/phone), `Property` (with contact_email/contact_phone), `PropertyImage`, `Favorite`.
|
**`models.py`** — ORM models: `User` (with email/full_name/phone), `Property` (with contact_email/contact_phone), `PropertyImage`, `Favorite`, `RefreshToken` (for token rotation/revocation).
|
||||||
|
|
||||||
**`properties.py`** — Property CRUD, search/filter, favorites, dashboard. Static paths (`/properties/new`) must be defined before parameterized paths (`/properties/{prop_id}`).
|
**`properties.py`** — Property CRUD, search/filter, favorites, dashboard. Static paths (`/properties/new`) must be defined before parameterized paths (`/properties/{prop_id}`).
|
||||||
|
|
||||||
## Data Flow
|
## Data Flow
|
||||||
|
|
||||||
Request → Router → dependency injection (`get_current_user`, `get_db`) → SQLAlchemy → Jinja2 template → HTML response.
|
Request → `RefreshTokenMiddleware` (sets new access cookie if refreshed) → Router → dependency injection (`get_current_user`, `get_db`) → SQLAlchemy → Jinja2 template → HTML response.
|
||||||
|
|
||||||
|
**Auth flow:** Login → issue access (15 min) + refresh (7 days) tokens → both set as HttpOnly cookies. Page load → `get_current_user()` decodes access token; if expired, uses refresh token cookie to get new access token via `try_refresh()`. Logout → revoke refresh token in DB, clear both cookies.
|
||||||
|
|
||||||
**TemplateResponse API:** `templates.TemplateResponse(request, "template.html", {"user": user, ...})` — request is first arg (Starlette 1.x).
|
**TemplateResponse API:** `templates.TemplateResponse(request, "template.html", {"user": user, ...})` — request is first arg (Starlette 1.x).
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
@@ -8,9 +9,12 @@ from pydantic import BaseModel
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from starlette import status
|
from starlette import status
|
||||||
|
|
||||||
from config import SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer
|
from config import (
|
||||||
|
SECRET_KEY, ALGORITHM, templates, bcrypt_context, oauth2_bearer,
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS,
|
||||||
|
)
|
||||||
from database import SessionLocal
|
from database import SessionLocal
|
||||||
from models import User
|
from models import User, RefreshToken
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/auth",
|
prefix="/auth",
|
||||||
@@ -29,15 +33,9 @@ def get_db():
|
|||||||
db_dependency = Annotated[Session, Depends(get_db)]
|
db_dependency = Annotated[Session, Depends(get_db)]
|
||||||
|
|
||||||
|
|
||||||
class CreateUserRequest(BaseModel):
|
|
||||||
username: str
|
|
||||||
password: str
|
|
||||||
email: str = ""
|
|
||||||
full_name: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
@@ -50,41 +48,104 @@ def authenticate_user(username: str, password: str, db: Session):
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(username: str, user_id: int, expires_delta: timedelta):
|
def create_access_token(username: str, user_id: int) -> str:
|
||||||
encode = {"sub": username, "id": user_id}
|
expires = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
expires = datetime.now(timezone.utc) + expires_delta
|
payload = {"sub": username, "id": user_id, "exp": expires, "type": "access"}
|
||||||
encode.update({"exp": expires})
|
return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
return jwt.encode(encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
||||||
|
|
||||||
|
def create_refresh_token(username: str, user_id: int, db: Session) -> str:
|
||||||
|
expires = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
token = jwt.encode(
|
||||||
|
{"sub": username, "id": user_id, "exp": expires, "type": "refresh", "jti": uuid.uuid4().hex},
|
||||||
|
SECRET_KEY, algorithm=ALGORITHM,
|
||||||
|
)
|
||||||
|
db.add(RefreshToken(user_id=user_id, token=token, expires_at=expires))
|
||||||
|
db.commit()
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(token: str) -> dict | None:
|
||||||
|
try:
|
||||||
|
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
except JWTError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_user(payload: dict) -> dict | None:
|
||||||
|
username = payload.get("sub")
|
||||||
|
user_id = payload.get("id")
|
||||||
|
if username is None or user_id is None:
|
||||||
|
return None
|
||||||
|
return {"username": username, "id": user_id}
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
request: Request,
|
request: Request,
|
||||||
token: Annotated[Optional[str], Depends(oauth2_bearer)],
|
token: Annotated[Optional[str], Depends(oauth2_bearer)],
|
||||||
):
|
) -> dict | None:
|
||||||
cookie_token = request.cookies.get("access_token")
|
# Try access token from cookie
|
||||||
effective_token = cookie_token or token
|
access_cookie = request.cookies.get("access_token")
|
||||||
if not effective_token:
|
effective = access_cookie or token
|
||||||
|
if effective:
|
||||||
|
payload = decode_token(effective)
|
||||||
|
if payload and payload.get("type") == "access":
|
||||||
|
return _decode_user(payload)
|
||||||
|
|
||||||
|
# Access token missing/expired — try refresh
|
||||||
|
refresh_cookie = request.cookies.get("refresh_token")
|
||||||
|
if refresh_cookie:
|
||||||
|
new_access = try_refresh(refresh_cookie, request)
|
||||||
|
if new_access:
|
||||||
|
payload = decode_token(new_access)
|
||||||
|
if payload:
|
||||||
|
return _decode_user(payload)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def try_refresh(refresh_token: str, request: Request) -> str | None:
|
||||||
|
payload = decode_token(refresh_token)
|
||||||
|
if not payload or payload.get("type") != "refresh":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(effective_token, SECRET_KEY, algorithms=[ALGORITHM])
|
db_token = db.query(RefreshToken).filter(
|
||||||
username: str = payload.get("sub")
|
RefreshToken.token == refresh_token,
|
||||||
user_id: int = payload.get("id")
|
RefreshToken.revoked == False,
|
||||||
if username is None or user_id is None:
|
).first()
|
||||||
|
if not db_token:
|
||||||
return None
|
return None
|
||||||
return {"username": username, "id": user_id}
|
if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||||
except JWTError:
|
db_token.revoked = True
|
||||||
return None
|
db.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
new_access = create_access_token(payload["sub"], payload["id"])
|
||||||
|
|
||||||
|
# Stash new access token on request state so the caller can set the cookie
|
||||||
|
request.state.new_access_token = new_access
|
||||||
|
return new_access
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", status_code=status.HTTP_201_CREATED)
|
@router.post("/", status_code=status.HTTP_201_CREATED)
|
||||||
async def create_user(db: db_dependency, create_user_request: CreateUserRequest):
|
async def create_user(db: db_dependency, username: str = Form(...), password: str = Form(...),
|
||||||
create_user_model = User(
|
email: str = Form(""), full_name: str = Form("")):
|
||||||
username=create_user_request.username,
|
if not username or not password:
|
||||||
email=create_user_request.email or None,
|
raise HTTPException(status_code=400, detail="Username and password required")
|
||||||
full_name=create_user_request.full_name or None,
|
if len(password) < 6:
|
||||||
hashed_password=bcrypt_context.hash(create_user_request.password),
|
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||||
)
|
if db.query(User).filter(User.username == username).first():
|
||||||
db.add(create_user_model)
|
raise HTTPException(status_code=400, detail="Username already taken")
|
||||||
|
db.add(User(
|
||||||
|
username=username,
|
||||||
|
email=email or None,
|
||||||
|
full_name=full_name or None,
|
||||||
|
hashed_password=bcrypt_context.hash(password),
|
||||||
|
))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
@@ -96,12 +157,49 @@ async def login_for_access_token(
|
|||||||
):
|
):
|
||||||
user = authenticate_user(username, password, db)
|
user = authenticate_user(username, password, db)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user.")
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
access = create_access_token(user.username, user.id)
|
||||||
detail="Could not validate user.",
|
refresh = create_refresh_token(user.username, user.id, db)
|
||||||
)
|
return {"access_token": access, "refresh_token": refresh, "token_type": "bearer"}
|
||||||
token = create_access_token(user.username, user.id, timedelta(minutes=60))
|
|
||||||
return {"access_token": token, "token_type": "bearer"}
|
|
||||||
|
@router.post("/refresh", response_model=Token)
|
||||||
|
async def refresh_access_token(
|
||||||
|
refresh_token: str = Form(...),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
payload = decode_token(refresh_token)
|
||||||
|
if not payload or payload.get("type") != "refresh":
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||||
|
|
||||||
|
db_token = db.query(RefreshToken).filter(
|
||||||
|
RefreshToken.token == refresh_token,
|
||||||
|
RefreshToken.revoked == False,
|
||||||
|
).first()
|
||||||
|
if not db_token:
|
||||||
|
raise HTTPException(status_code=401, detail="Refresh token revoked or not found")
|
||||||
|
if db_token.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc):
|
||||||
|
db_token.revoked = True
|
||||||
|
db.commit()
|
||||||
|
raise HTTPException(status_code=401, detail="Refresh token expired")
|
||||||
|
|
||||||
|
new_access = create_access_token(payload["sub"], payload["id"])
|
||||||
|
new_refresh = create_refresh_token(payload["sub"], payload["id"], db)
|
||||||
|
db_token.revoked = True
|
||||||
|
db.commit()
|
||||||
|
return {"access_token": new_access, "refresh_token": new_refresh, "token_type": "bearer"}
|
||||||
|
|
||||||
|
|
||||||
|
def _set_auth_cookies(response: RedirectResponse, access: str, refresh: str) -> RedirectResponse:
|
||||||
|
response.set_cookie("access_token", access, httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, samesite="lax")
|
||||||
|
response.set_cookie("refresh_token", refresh, httponly=True, max_age=REFRESH_TOKEN_EXPIRE_DAYS * 86400, samesite="lax")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_auth_cookies(response: RedirectResponse) -> RedirectResponse:
|
||||||
|
response.delete_cookie("access_token")
|
||||||
|
response.delete_cookie("refresh_token")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@router.get("/login", response_class=HTMLResponse)
|
@router.get("/login", response_class=HTMLResponse)
|
||||||
@@ -119,9 +217,10 @@ async def login_submit(request: Request, db: db_dependency):
|
|||||||
return templates.TemplateResponse(request, "auth/login.html", {
|
return templates.TemplateResponse(request, "auth/login.html", {
|
||||||
"error": "Invalid username or password",
|
"error": "Invalid username or password",
|
||||||
})
|
})
|
||||||
token = create_access_token(user.username, user.id, timedelta(minutes=60))
|
access = create_access_token(user.username, user.id)
|
||||||
|
refresh = create_refresh_token(user.username, user.id, db)
|
||||||
response = RedirectResponse(url="/", status_code=303)
|
response = RedirectResponse(url="/", status_code=303)
|
||||||
response.set_cookie(key="access_token", value=token, httponly=True, max_age=3600, samesite="lax")
|
_set_auth_cookies(response, access, refresh)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@@ -161,19 +260,24 @@ async def register_submit(request: Request, db: db_dependency):
|
|||||||
"full_name": full_name,
|
"full_name": full_name,
|
||||||
})
|
})
|
||||||
|
|
||||||
new_user = User(
|
db.add(User(
|
||||||
username=username,
|
username=username,
|
||||||
email=email or None,
|
email=email or None,
|
||||||
full_name=full_name or None,
|
full_name=full_name or None,
|
||||||
hashed_password=bcrypt_context.hash(password),
|
hashed_password=bcrypt_context.hash(password),
|
||||||
)
|
))
|
||||||
db.add(new_user)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return RedirectResponse(url="/auth/login?registered=1", status_code=303)
|
return RedirectResponse(url="/auth/login?registered=1", status_code=303)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/logout")
|
@router.get("/logout")
|
||||||
async def logout():
|
async def logout(request: Request, db: db_dependency):
|
||||||
|
refresh_token = request.cookies.get("refresh_token")
|
||||||
|
if refresh_token:
|
||||||
|
db_token = db.query(RefreshToken).filter(RefreshToken.token == refresh_token).first()
|
||||||
|
if db_token:
|
||||||
|
db_token.revoked = True
|
||||||
|
db.commit()
|
||||||
response = RedirectResponse(url="/", status_code=303)
|
response = RedirectResponse(url="/", status_code=303)
|
||||||
response.delete_cookie("access_token")
|
_clear_auth_cookies(response)
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import jinja2
|
|||||||
|
|
||||||
SECRET_KEY = os.getenv("SECRET_KEY", "nexhome-dev-secret-key-change-in-production")
|
SECRET_KEY = os.getenv("SECRET_KEY", "nexhome-dev-secret-key-change-in-production")
|
||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
||||||
|
|
||||||
# Disable Jinja2 cache to work around Python 3.14 compatibility issue
|
# Disable Jinja2 cache to work around Python 3.14 compatibility issue
|
||||||
_loader = jinja2.FileSystemLoader("templates")
|
_loader = jinja2.FileSystemLoader("templates")
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ from typing import Annotated, Optional
|
|||||||
from fastapi import FastAPI, Depends, Request
|
from fastapi import FastAPI, Depends, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from starlette import status
|
from starlette import status
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.staticfiles import StaticFiles
|
from starlette.staticfiles import StaticFiles
|
||||||
|
|
||||||
import models
|
import models
|
||||||
from auth import get_db, get_current_user, router as auth_router
|
from auth import get_db, get_current_user, router as auth_router
|
||||||
from config import templates
|
from config import templates, ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
from models import Property
|
from models import Property
|
||||||
from properties import router as properties_router
|
from properties import router as properties_router
|
||||||
|
|
||||||
@@ -19,6 +20,21 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|||||||
|
|
||||||
models.Base.metadata.create_all(bind=__import__("database", fromlist=["engine"]).engine)
|
models.Base.metadata.create_all(bind=__import__("database", fromlist=["engine"]).engine)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
new_token = getattr(request.state, "new_access_token", None)
|
||||||
|
if new_token:
|
||||||
|
response.set_cookie(
|
||||||
|
"access_token", new_token,
|
||||||
|
httponly=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, samesite="lax",
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
app.add_middleware(RefreshTokenMiddleware)
|
||||||
|
|
||||||
|
|
||||||
db_dependency = Annotated[Session, Depends(get_db)]
|
db_dependency = Annotated[Session, Depends(get_db)]
|
||||||
user_dependency = Annotated[Optional[dict], Depends(get_current_user)]
|
user_dependency = Annotated[Optional[dict], Depends(get_current_user)]
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,11 @@ class User(Base):
|
|||||||
hashed_password = Column(String(60))
|
hashed_password = Column(String(60))
|
||||||
full_name = Column(String(100), nullable=True)
|
full_name = Column(String(100), nullable=True)
|
||||||
phone = Column(String(20), nullable=True)
|
phone = Column(String(20), nullable=True)
|
||||||
is_agent = Column(Boolean, default=False)
|
|
||||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
properties = relationship("Property", back_populates="owner")
|
properties = relationship("Property", back_populates="owner")
|
||||||
favorites = relationship("Favorite", back_populates="user")
|
favorites = relationship("Favorite", back_populates="user")
|
||||||
|
refresh_tokens = relationship("RefreshToken", back_populates="user")
|
||||||
|
|
||||||
|
|
||||||
class Property(Base):
|
class Property(Base):
|
||||||
@@ -74,3 +74,16 @@ class Favorite(Base):
|
|||||||
property = relationship("Property", back_populates="favorites")
|
property = relationship("Property", back_populates="favorites")
|
||||||
|
|
||||||
__table_args__ = (UniqueConstraint("user_id", "property_id", name="uq_user_property"),)
|
__table_args__ = (UniqueConstraint("user_id", "property_id", name="uq_user_property"),)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base):
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
|
token = Column(String(500), unique=True, nullable=False, index=True)
|
||||||
|
expires_at = Column(DateTime, nullable=False)
|
||||||
|
revoked = Column(Boolean, default=False)
|
||||||
|
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
user = relationship("User", back_populates="refresh_tokens")
|
||||||
|
|||||||
Reference in New Issue
Block a user