# Mirror of https://github.com/fawney19/Aether.git
# Synced 2026-01-03 08:12:26 +08:00
#
# - Replace BaseHTTPMiddleware with a pure ASGI implementation in the plugin
#   middleware for better streaming response handling
# - Add trusted proxy count configuration for client IP extraction in reverse
#   proxy environments
# - Implement an audit log cleanup scheduler with a configurable retention period
# - Replace plaintext token logging with SHA256 hash fingerprints for security
# - Fix database session lifecycle management in middleware
# - Improve request tracing and error logging throughout the system
# - Add comprehensive tests for the pipeline architecture
#
# 388 lines, 15 KiB, Python
from __future__ import annotations

import asyncio
import time
from enum import Enum
from typing import Any, Optional, Tuple

from fastapi import HTTPException, Request
from sqlalchemy.orm import Session

from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.system.audit import AuditService
from src.services.usage.service import UsageService

from .adapter import ApiAdapter, ApiMode
from .context import ApiRequestContext


class ApiRequestPipeline:
    """Pipeline that runs the steps shared by every API request.

    For each request it authenticates the caller according to the requested
    ``ApiMode``, reads the request body, builds an ``ApiRequestContext``,
    lets the adapter authorize and handle the request, and records an audit
    event describing the outcome of ``adapter.handle``.
    """

    # Scheme prefix (compared case-insensitively) expected in the
    # ``Authorization`` header for JWT-authenticated modes.
    _BEARER_PREFIX = "bearer "

    def __init__(
        self,
        auth_service: AuthService = AuthService,
        usage_service: UsageService = UsageService,
        audit_service: AuditService = AuditService,
    ):
        """Store the collaborating services.

        The defaults are the service classes themselves (not instances), so
        their methods are invoked directly on the class unless a caller
        injects a replacement.
        """
        self.auth_service = auth_service
        self.usage_service = usage_service
        self.audit_service = audit_service

    async def run(
        self,
        adapter: ApiAdapter,
        http_request: Request,
        db: Session,
        *,
        mode: ApiMode = ApiMode.STANDARD,
        api_format_hint: Optional[str] = None,
        path_params: Optional[dict[str, Any]] = None,
    ) -> Any:
        """Authenticate, build the context, and dispatch to ``adapter``.

        Args:
            adapter: Adapter that authorizes and handles the request.
            http_request: Incoming request.
            db: Request-scoped database session.
            mode: Selects the authentication strategy. ``ADMIN`` and ``USER``
                use a bearer token, ``PUBLIC`` skips authentication, and any
                other mode authenticates a client API key.
            api_format_hint: Forwarded unchanged to the context.
            path_params: Forwarded unchanged to the context.

        Returns:
            Whatever ``adapter.handle`` returns.

        Raises:
            HTTPException: On authentication failure, on a request body that
                is not received within 30 seconds (408), or re-raised from
                the adapter.
            QuotaExceededException: When the client's quota check fails.
            Exception: Any other error raised by the adapter is audited and
                re-raised unchanged.
        """
        logger.debug(f"[Pipeline] START | path={http_request.url.path}")
        logger.debug(f"[Pipeline] Running with mode={mode}, adapter={adapter.__class__.__name__}, "
                     f"adapter.mode={adapter.mode}, path={http_request.url.path}")

        if mode == ApiMode.ADMIN:
            user = await self._authenticate_admin(http_request, db)
            api_key = None
        elif mode == ApiMode.USER:
            user = await self._authenticate_user(http_request, db)
            api_key = None
        elif mode == ApiMode.PUBLIC:
            user = None
            api_key = None
        else:
            logger.debug("[Pipeline] 调用 _authenticate_client")
            user, api_key = self._authenticate_client(http_request, db, adapter)
            logger.debug(f"[Pipeline] 认证完成 | user={user.username if user else None}")

        raw_body = await self._read_request_body(http_request)

        context = ApiRequestContext.build(
            request=http_request,
            db=db,
            user=user,
            api_key=api_key,
            raw_body=raw_body,
            mode=mode.value,
            api_format_hint=api_format_hint,
            path_params=path_params,
        )
        logger.debug(f"[Pipeline] Context构建完成 | adapter={adapter.name} | request_id={context.request_id}")

        if mode != ApiMode.ADMIN and user:
            context.quota_remaining = self._calculate_quota_remaining(user)

        logger.debug(f"[Pipeline] Adapter={adapter.name} | RequestID={context.request_id}")

        logger.debug(f"[Pipeline] Calling authorize on {adapter.__class__.__name__}, user={context.user}")
        # authorize() may be sync or async; await only when it returned an awaitable.
        authorize_result = adapter.authorize(context)
        if hasattr(authorize_result, "__await__"):
            await authorize_result

        try:
            response = await adapter.handle(context)
        except HTTPException as exc:
            err_detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
            self._record_audit_event(
                context,
                adapter,
                success=False,
                status_code=exc.status_code,
                error=err_detail,
            )
            raise
        except Exception as exc:
            self._record_audit_event(
                context,
                adapter,
                success=False,
                status_code=500,
                error=str(exc),
            )
            raise

        status_code = getattr(response, "status_code", None)
        self._record_audit_event(context, adapter, success=True, status_code=status_code)
        return response

    # --------------------------------------------------------------------- #
    # Internal helpers
    # --------------------------------------------------------------------- #

    async def _read_request_body(self, http_request: Request) -> Optional[bytes]:
        """Read the raw body of POST/PUT/PATCH requests.

        Returns ``None`` for every other method. The read is bounded by a
        30-second timeout so a client that never finishes sending its body
        cannot stall the request indefinitely.

        Raises:
            HTTPException: 408 when the body is not received in time.
        """
        if http_request.method not in {"POST", "PUT", "PATCH"}:
            logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")
            return None

        try:
            raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
        except asyncio.TimeoutError as exc:
            logger.error("读取请求体超时(30s),可能客户端未发送完整请求体")
            raise HTTPException(
                status_code=408, detail="Request timeout: body not received within 30 seconds"
            ) from exc

        logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes")
        return raw_body

    def _authenticate_client(
        self, request: Request, db: Session, adapter: ApiAdapter
    ) -> Tuple[User, ApiKey]:
        """Authenticate a client by API key and enforce its quota.

        On success the user id and API key id are stored on ``request.state``.

        Raises:
            HTTPException: 401 when the key is missing or not recognised.
            QuotaExceededException: When the quota/balance check fails; it
                carries the remaining amount, or ``None`` when no finite
                limit is configured.
        """
        logger.debug("[Pipeline._authenticate_client] 开始")
        # Let the adapter extract the key so each API format can use its own auth header.
        client_api_key = adapter.extract_api_key(request)
        logger.debug(f"[Pipeline._authenticate_client] 提取API密钥完成 | key_prefix={client_api_key[:8] if client_api_key else None}...")
        if not client_api_key:
            raise HTTPException(status_code=401, detail="请提供API密钥")

        logger.debug("[Pipeline._authenticate_client] 调用 auth_service.authenticate_api_key")
        auth_result = self.auth_service.authenticate_api_key(db, client_api_key)
        logger.debug(f"[Pipeline._authenticate_client] 认证结果 | result={bool(auth_result)}")
        if not auth_result:
            raise HTTPException(status_code=401, detail="无效的API密钥")

        user, api_key = auth_result
        if not user or not api_key:
            raise HTTPException(status_code=401, detail="无效的API密钥")

        request.state.user_id = user.id
        request.state.api_key_id = api_key.id

        # Check quota or balance (standalone keys are supported).
        quota_ok, _ = self.usage_service.check_user_quota(db, user, api_key=api_key)
        if not quota_ok:
            if api_key.is_standalone:
                # Standalone key: report the key's own remaining balance.
                if api_key.current_balance_usd is None:
                    remaining = None
                else:
                    remaining = float(
                        api_key.current_balance_usd - (api_key.balance_used_usd or 0)
                    )
            else:
                # Regular key: report the user's remaining quota (not clamped).
                remaining = self._user_quota_left(user)
            raise QuotaExceededException(quota_type="USD", remaining=remaining)

        return user, api_key

    async def _authenticate_bearer(
        self,
        request: Request,
        db: Session,
        *,
        missing_detail: str,
        invalid_detail: str,
        log_label: str,
    ) -> User:
        """Resolve the active user identified by a bearer access token.

        Shared implementation behind the admin and user authentication
        modes, which differ only in their error messages and log label.

        Args:
            request: Incoming request; its ``Authorization`` header is read
                and ``request.state.user_id`` is set on success.
            db: Session used to load the user.
            missing_detail: 401 detail used when no bearer credential is sent.
            invalid_detail: 401 detail used when the token cannot be verified
                or carries no ``user_id``.
            log_label: Prefix of the error log line for verification failures.

        Raises:
            HTTPException: 401 for missing/invalid credentials, 403 when the
                user does not exist or is inactive. ``HTTPException`` raised
                by token verification propagates unchanged.
        """
        authorization = request.headers.get("authorization")
        if not authorization or not authorization.lower().startswith(self._BEARER_PREFIX):
            raise HTTPException(status_code=401, detail=missing_detail)

        token = authorization[len(self._BEARER_PREFIX):].strip()
        try:
            payload = await self.auth_service.verify_token(token, token_type="access")
        except HTTPException:
            raise
        except Exception as exc:
            logger.error(f"{log_label} token 验证失败: {exc}")
            raise HTTPException(status_code=401, detail=invalid_detail) from exc

        user_id = payload.get("user_id")
        if not user_id:
            raise HTTPException(status_code=401, detail=invalid_detail)

        # Query the database directly so the returned object is bound to the current Session.
        user = db.query(User).filter(User.id == user_id).first()
        if not user or not user.is_active:
            raise HTTPException(status_code=403, detail="用户不存在或已禁用")

        request.state.user_id = user.id
        return user

    async def _authenticate_admin(self, request: Request, db: Session) -> User:
        """Authenticate an admin-mode request via JWT bearer token.

        NOTE(review): only token validity and ``is_active`` are checked here;
        the user's role is not inspected. Presumably the admin role is
        enforced by ``adapter.authorize`` — confirm.
        """
        return await self._authenticate_bearer(
            request,
            db,
            missing_detail="缺少管理员凭证",
            invalid_detail="无效的管理员令牌",
            log_label="Admin",
        )

    async def _authenticate_user(self, request: Request, db: Session) -> User:
        """Authenticate a regular user via JWT (admin privileges not required)."""
        return await self._authenticate_bearer(
            request,
            db,
            missing_detail="缺少用户凭证",
            invalid_detail="无效的用户令牌",
            log_label="User",
        )

    @staticmethod
    def _user_quota_left(user: User) -> Optional[float]:
        """Return ``quota_usd - used_usd`` without clamping.

        Returns ``None`` when ``quota_usd`` is ``None`` or negative, i.e. when
        no finite quota is configured (presumably "unlimited" — confirm
        against the ``User`` model). A ``None`` ``used_usd`` counts as 0 so
        the subtraction cannot raise ``TypeError``.
        """
        if user.quota_usd is None or user.quota_usd < 0:
            return None
        return float(user.quota_usd - (user.used_usd or 0))

    def _calculate_quota_remaining(self, user: Optional[User]) -> Optional[float]:
        """Return the user's remaining quota, floored at 0.0.

        Returns ``None`` when there is no user or no finite quota.
        """
        if not user:
            return None
        left = self._user_quota_left(user)
        if left is None:
            return None
        return max(left, 0.0)

    def _record_audit_event(
        self,
        context: ApiRequestContext,
        adapter: ApiAdapter,
        *,
        success: bool,
        status_code: Optional[int] = None,
        error: Optional[str] = None,
    ) -> None:
        """Record an audit event for the request outcome.

        Transaction strategy: the request-scoped Session is reused and nothing
        is committed here. The audit record is committed together with the
        main transaction, which is managed by the middleware.

        Nothing is recorded when the adapter sets ``audit_log_enabled`` to a
        falsy value or when the context has no database session.
        """
        if not getattr(adapter, "audit_log_enabled", True):
            return

        if context.db is None:
            return

        try:
            event_type = adapter.audit_success_event if success else adapter.audit_failure_event
            if not event_type:
                # Adapter did not specify an event type: fall back to generic ones.
                if not success and status_code == 401:
                    event_type = AuditEventType.UNAUTHORIZED_ACCESS
                elif success:
                    event_type = AuditEventType.REQUEST_SUCCESS
                else:
                    event_type = AuditEventType.REQUEST_FAILED

            metadata = self._build_audit_metadata(
                context=context,
                adapter=adapter,
                success=success,
                status_code=status_code,
                error=error,
            )

            # Reuse the request-scoped Session rather than opening a new connection;
            # the record is committed with the main transaction by the middleware.
            self.audit_service.log_event(
                db=context.db,
                event_type=event_type,
                description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
                user_id=context.user.id if context.user else None,
                api_key_id=context.api_key.id if context.api_key else None,
                ip_address=context.client_ip,
                user_agent=context.user_agent,
                request_id=context.request_id,
                status_code=status_code,
                error_message=error,
                metadata=metadata,
            )
        except Exception as exc:
            # Auditing is best-effort: a failure here must not affect the main
            # request, so only log a warning.
            logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")

    def _build_audit_metadata(
        self,
        context: ApiRequestContext,
        adapter: ApiAdapter,
        *,
        success: bool,
        status_code: Optional[int],
        error: Optional[str],
    ) -> dict:
        """Assemble the sanitized metadata dict stored with an audit event.

        Combines request facts (path, method, query, body size, duration),
        adapter identity, optional user/API-key information, and any extra
        details supplied by the context or by ``adapter.get_audit_metadata``.
        Adapter-supplied details override context-supplied ones on key clash.
        """
        duration_ms = max((time.time() - context.start_time) * 1000, 0.0)
        request = context.request
        try:
            path_params = dict(getattr(request, "path_params", {}) or {})
        except Exception:
            path_params = {}

        metadata: dict[str, Any] = {
            "path": request.url.path,
            "path_params": path_params,
            "method": request.method,
            "adapter": adapter.name,
            "adapter_class": adapter.__class__.__name__,
            "adapter_mode": getattr(adapter.mode, "value", str(adapter.mode)),
            "mode": context.mode,
            "api_format_hint": context.api_format_hint,
            "query": context.query_params,
            "duration_ms": round(duration_ms, 2),
            "request_body_bytes": len(context.raw_body or b""),
            "has_body": bool(context.raw_body),
            "request_content_type": request.headers.get("content-type"),
            "quota_remaining": context.quota_remaining,
            "success": success,
        }
        if status_code is not None:
            metadata["status_code"] = status_code

        if context.user and getattr(context.user, "role", None):
            role = context.user.role
            metadata["user_role"] = getattr(role, "value", role)

        if context.api_key:
            if getattr(context.api_key, "name", None):
                metadata["api_key_name"] = context.api_key.name
            # Store only the masked display form of the key.
            if hasattr(context.api_key, "get_display_key"):
                metadata["api_key_display"] = context.api_key.get_display_key()

        extra_details: dict[str, Any] = {}
        if context.audit_metadata:
            extra_details.update(context.audit_metadata)

        try:
            adapter_details = adapter.get_audit_metadata(
                context,
                success=success,
                status_code=status_code,
                error=error,
            )
            if adapter_details:
                extra_details.update(adapter_details)
        except Exception as exc:
            logger.warning(f"[Audit] Adapter metadata failed: {adapter.__class__.__name__}: {exc}")

        if extra_details:
            metadata["details"] = extra_details

        if error:
            metadata["error"] = error

        return self._sanitize_metadata(metadata)

    def _sanitize_metadata(self, value: Any, depth: int = 0) -> Any:
        """Recursively reduce ``value`` to plain, serializable Python data.

        ``str``/``int``/``float``/``bool`` pass through, enums become their
        ``.value``, dicts get string keys and lose entries whose sanitized
        value is ``None``, lists/tuples/sets become lists, objects exposing
        ``isoformat`` are rendered with it, and anything else is passed to
        ``str``. Beyond a nesting depth of 5 the value is stringified as is.
        """
        if value is None:
            return None
        if depth > 5:
            return str(value)
        if isinstance(value, (str, int, float, bool)):
            return value
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, dict):
            sanitized = {}
            for key, val in value.items():
                cleaned = self._sanitize_metadata(val, depth + 1)
                if cleaned is not None:
                    sanitized[str(key)] = cleaned
            return sanitized
        if isinstance(value, (list, tuple, set)):
            return [self._sanitize_metadata(item, depth + 1) for item in value]
        if hasattr(value, "isoformat"):
            try:
                return value.isoformat()
            except Exception:
                return str(value)
        return str(value)