""" Authentication utilities for JWT-based session management with role-based expiration times. """ from datetime import datetime, timedelta, timezone from typing import Optional, Union from jose import JWTError, jwt from passlib.context import CryptContext from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sqlmodel import Session, select from app.core.config import settings from app.core.db import get_session from app.schemas.models import User from app.schemas.schemas import TokenData, UserResponse from app.schemas.base import UserRole # Password hashing pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # Security scheme security = HTTPBearer() def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a plain password against its hash.""" return pwd_context.verify(plain_password, hashed_password) def get_password_hash(password: str) -> str: """Generate password hash.""" return pwd_context.hash(password) def authenticate_user( session: Session, username: str, password: str ) -> Optional[User]: """Authenticate user with username and password.""" statement = select(User).where(User.username == username) user = session.exec(statement).first() if not user: return None if not verify_password(password, user.password_hash): return None return user def get_token_expiration_minutes(role: UserRole) -> int: """Get token expiration time based on user role.""" role_expiration_map = { UserRole.ADMIN: settings.admin_token_expire_minutes, UserRole.WRITE: settings.write_token_expire_minutes, UserRole.READ_ONLY: settings.read_only_token_expire_minutes, } return role_expiration_map.get(role, settings.read_only_token_expire_minutes) def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): """Create JWT access token.""" to_encode = data.copy() if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: expire = datetime.now(timezone.utc) + timedelta(minutes=15) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) return encoded_jwt def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> TokenData: """Verify JWT token and extract token data.""" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) try: token = credentials.credentials payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm]) username: Optional[str] = payload.get("sub") user_id: Optional[int] = payload.get("user_id") role: Optional[str] = payload.get("role") if username is None or user_id is None or role is None: raise credentials_exception token_data = TokenData( username=username, user_id=user_id, role=UserRole(role) ) except JWTError: raise credentials_exception return token_data def get_current_user( token_data: TokenData = Depends(verify_token), session: Session = Depends(get_session) ) -> User: """Get current user from token.""" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) user = session.get(User, token_data.user_id) if user is None: raise credentials_exception return user def get_current_active_user( current_user: UserResponse = Depends(get_current_user) ) -> UserResponse: """Get current active user (extend this if you add user activation status).""" return current_user def require_role(required_roles: list[UserRole]): """Dependency factory for role-based access control.""" def role_checker(current_user: User = Depends(get_current_active_user)) -> User: if current_user.role not in required_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Operation not permitted for your role" ) return current_user return role_checker # Common role dependencies require_admin = require_role([UserRole.ADMIN]) require_write_access = require_role([UserRole.ADMIN, UserRole.WRITE]) require_any_access = require_role([UserRole.ADMIN, UserRole.WRITE, UserRole.READ_ONLY])