mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 16:22:27 +08:00
Initial commit
This commit is contained in:
387
src/api/base/pipeline.py
Normal file
387
src/api/base/pipeline.py
Normal file
@@ -0,0 +1,387 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
from .adapter import ApiAdapter, ApiMode
|
||||
from .context import ApiRequestContext
|
||||
|
||||
|
||||
|
||||
class ApiRequestPipeline:
|
||||
"""负责统一执行认证、配额校验、上下文构建等通用逻辑的管道。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auth_service: AuthService = AuthService,
|
||||
usage_service: UsageService = UsageService,
|
||||
audit_service: AuditService = AuditService,
|
||||
):
|
||||
self.auth_service = auth_service
|
||||
self.usage_service = usage_service
|
||||
self.audit_service = audit_service
|
||||
|
||||
async def run(
|
||||
self,
|
||||
adapter: ApiAdapter,
|
||||
http_request: Request,
|
||||
db: Session,
|
||||
*,
|
||||
mode: ApiMode = ApiMode.STANDARD,
|
||||
api_format_hint: Optional[str] = None,
|
||||
path_params: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
logger.debug(f"[Pipeline] START | path={http_request.url.path}")
|
||||
logger.debug(f"[Pipeline] Running with mode={mode}, adapter={adapter.__class__.__name__}, "
|
||||
f"adapter.mode={adapter.mode}, path={http_request.url.path}")
|
||||
if mode == ApiMode.ADMIN:
|
||||
user = await self._authenticate_admin(http_request, db)
|
||||
api_key = None
|
||||
elif mode == ApiMode.USER:
|
||||
user = await self._authenticate_user(http_request, db)
|
||||
api_key = None
|
||||
elif mode == ApiMode.PUBLIC:
|
||||
user = None
|
||||
api_key = None
|
||||
else:
|
||||
logger.debug("[Pipeline] 调用 _authenticate_client")
|
||||
user, api_key = self._authenticate_client(http_request, db, adapter)
|
||||
logger.debug(f"[Pipeline] 认证完成 | user={user.username if user else None}")
|
||||
|
||||
raw_body = None
|
||||
if http_request.method in {"POST", "PUT", "PATCH"}:
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
# 添加30秒超时防止卡死
|
||||
raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
|
||||
logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("读取请求体超时(30s),可能客户端未发送完整请求体")
|
||||
raise HTTPException(
|
||||
status_code=408, detail="Request timeout: body not received within 30 seconds"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")
|
||||
|
||||
context = ApiRequestContext.build(
|
||||
request=http_request,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
raw_body=raw_body,
|
||||
mode=mode.value,
|
||||
api_format_hint=api_format_hint,
|
||||
path_params=path_params,
|
||||
)
|
||||
logger.debug(f"[Pipeline] Context构建完成 | adapter={adapter.name} | request_id={context.request_id}")
|
||||
|
||||
if mode != ApiMode.ADMIN and user:
|
||||
context.quota_remaining = self._calculate_quota_remaining(user)
|
||||
|
||||
logger.debug(f"[Pipeline] Adapter={adapter.name} | RequestID={context.request_id}")
|
||||
|
||||
logger.debug(f"[Pipeline] Calling authorize on {adapter.__class__.__name__}, user={context.user}")
|
||||
# authorize 可能是异步的,需要检查并 await
|
||||
authorize_result = adapter.authorize(context)
|
||||
if hasattr(authorize_result, "__await__"):
|
||||
await authorize_result
|
||||
|
||||
try:
|
||||
response = await adapter.handle(context)
|
||||
status_code = getattr(response, "status_code", None)
|
||||
self._record_audit_event(context, adapter, success=True, status_code=status_code)
|
||||
return response
|
||||
except HTTPException as exc:
|
||||
err_detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||
self._record_audit_event(
|
||||
context,
|
||||
adapter,
|
||||
success=False,
|
||||
status_code=exc.status_code,
|
||||
error=err_detail,
|
||||
)
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._record_audit_event(
|
||||
context,
|
||||
adapter,
|
||||
success=False,
|
||||
status_code=500,
|
||||
error=str(exc),
|
||||
)
|
||||
raise
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Internal helpers
|
||||
# --------------------------------------------------------------------- #
|
||||
|
||||
def _authenticate_client(
|
||||
self, request: Request, db: Session, adapter: ApiAdapter
|
||||
) -> Tuple[User, ApiKey]:
|
||||
logger.debug("[Pipeline._authenticate_client] 开始")
|
||||
# 使用 adapter 的 extract_api_key 方法,支持不同 API 格式的认证头
|
||||
client_api_key = adapter.extract_api_key(request)
|
||||
logger.debug(f"[Pipeline._authenticate_client] 提取API密钥完成 | key_prefix={client_api_key[:8] if client_api_key else None}...")
|
||||
if not client_api_key:
|
||||
raise HTTPException(status_code=401, detail="请提供API密钥")
|
||||
|
||||
logger.debug("[Pipeline._authenticate_client] 调用 auth_service.authenticate_api_key")
|
||||
auth_result = self.auth_service.authenticate_api_key(db, client_api_key)
|
||||
logger.debug(f"[Pipeline._authenticate_client] 认证结果 | result={bool(auth_result)}")
|
||||
if not auth_result:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
user, api_key = auth_result
|
||||
if not user or not api_key:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
request.state.user_id = user.id
|
||||
request.state.api_key_id = api_key.id
|
||||
|
||||
# 检查配额或余额(支持独立Key)
|
||||
quota_ok, message = self.usage_service.check_user_quota(db, user, api_key=api_key)
|
||||
if not quota_ok:
|
||||
# 根据Key类型计算剩余额度
|
||||
if api_key.is_standalone:
|
||||
# 独立Key:显示剩余余额
|
||||
remaining = (
|
||||
None
|
||||
if api_key.current_balance_usd is None
|
||||
else float(api_key.current_balance_usd - (api_key.balance_used_usd or 0))
|
||||
)
|
||||
else:
|
||||
# 普通Key:显示用户配额剩余
|
||||
remaining = (
|
||||
None
|
||||
if user.quota_usd is None or user.quota_usd < 0
|
||||
else float(user.quota_usd - user.used_usd)
|
||||
)
|
||||
raise QuotaExceededException(quota_type="USD", remaining=remaining)
|
||||
|
||||
return user, api_key
|
||||
|
||||
async def _authenticate_admin(self, request: Request, db: Session) -> User:
|
||||
authorization = request.headers.get("authorization")
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Admin token 验证失败: {exc}")
|
||||
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
request.state.user_id = user.id
|
||||
return user
|
||||
|
||||
async def _authenticate_user(self, request: Request, db: Session) -> User:
|
||||
"""JWT 认证普通用户(不要求管理员权限)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"User token 验证失败: {exc}")
|
||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
request.state.user_id = user.id
|
||||
return user
|
||||
|
||||
def _calculate_quota_remaining(self, user: Optional[User]) -> Optional[float]:
|
||||
if not user:
|
||||
return None
|
||||
if user.quota_usd is None or user.quota_usd < 0:
|
||||
return None
|
||||
return max(float(user.quota_usd - user.used_usd), 0.0)
|
||||
|
||||
def _record_audit_event(
|
||||
self,
|
||||
context: ApiRequestContext,
|
||||
adapter: ApiAdapter,
|
||||
*,
|
||||
success: bool,
|
||||
status_code: Optional[int] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
if not getattr(adapter, "audit_log_enabled", True):
|
||||
return
|
||||
|
||||
bind = context.db.get_bind()
|
||||
if bind is None:
|
||||
return
|
||||
|
||||
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
|
||||
if not event_type:
|
||||
if not success and status_code == 401:
|
||||
event_type = AuditEventType.UNAUTHORIZED_ACCESS
|
||||
else:
|
||||
event_type = (
|
||||
AuditEventType.REQUEST_SUCCESS if success else AuditEventType.REQUEST_FAILED
|
||||
)
|
||||
|
||||
metadata = self._build_audit_metadata(
|
||||
context=context,
|
||||
adapter=adapter,
|
||||
success=success,
|
||||
status_code=status_code,
|
||||
error=error,
|
||||
)
|
||||
|
||||
SessionMaker = sessionmaker(bind=bind)
|
||||
audit_session = SessionMaker()
|
||||
try:
|
||||
self.audit_service.log_event(
|
||||
db=audit_session,
|
||||
event_type=event_type,
|
||||
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
|
||||
user_id=context.user.id if context.user else None,
|
||||
api_key_id=context.api_key.id if context.api_key else None,
|
||||
ip_address=context.client_ip,
|
||||
user_agent=context.user_agent,
|
||||
request_id=context.request_id,
|
||||
status_code=status_code,
|
||||
error_message=error,
|
||||
metadata=metadata,
|
||||
)
|
||||
audit_session.commit()
|
||||
except Exception as exc:
|
||||
audit_session.rollback()
|
||||
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
|
||||
finally:
|
||||
audit_session.close()
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
context: ApiRequestContext,
|
||||
adapter: ApiAdapter,
|
||||
*,
|
||||
success: bool,
|
||||
status_code: Optional[int],
|
||||
error: Optional[str],
|
||||
) -> dict:
|
||||
duration_ms = max((time.time() - context.start_time) * 1000, 0.0)
|
||||
request = context.request
|
||||
path_params = {}
|
||||
try:
|
||||
path_params = dict(getattr(request, "path_params", {}) or {})
|
||||
except Exception:
|
||||
path_params = {}
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"path": request.url.path,
|
||||
"path_params": path_params,
|
||||
"method": request.method,
|
||||
"adapter": adapter.name,
|
||||
"adapter_class": adapter.__class__.__name__,
|
||||
"adapter_mode": getattr(adapter.mode, "value", str(adapter.mode)),
|
||||
"mode": context.mode,
|
||||
"api_format_hint": context.api_format_hint,
|
||||
"query": context.query_params,
|
||||
"duration_ms": round(duration_ms, 2),
|
||||
"request_body_bytes": len(context.raw_body or b""),
|
||||
"has_body": bool(context.raw_body),
|
||||
"request_content_type": request.headers.get("content-type"),
|
||||
"quota_remaining": context.quota_remaining,
|
||||
"success": success,
|
||||
}
|
||||
if status_code is not None:
|
||||
metadata["status_code"] = status_code
|
||||
|
||||
if context.user and getattr(context.user, "role", None):
|
||||
role = context.user.role
|
||||
metadata["user_role"] = getattr(role, "value", role)
|
||||
|
||||
if context.api_key:
|
||||
if getattr(context.api_key, "name", None):
|
||||
metadata["api_key_name"] = context.api_key.name
|
||||
# 使用脱敏后的密钥显示
|
||||
if hasattr(context.api_key, "get_display_key"):
|
||||
metadata["api_key_display"] = context.api_key.get_display_key()
|
||||
|
||||
extra_details: dict[str, Any] = {}
|
||||
if context.audit_metadata:
|
||||
extra_details.update(context.audit_metadata)
|
||||
|
||||
try:
|
||||
adapter_details = adapter.get_audit_metadata(
|
||||
context,
|
||||
success=success,
|
||||
status_code=status_code,
|
||||
error=error,
|
||||
)
|
||||
if adapter_details:
|
||||
extra_details.update(adapter_details)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Audit] Adapter metadata failed: {adapter.__class__.__name__}: {exc}")
|
||||
|
||||
if extra_details:
|
||||
metadata["details"] = extra_details
|
||||
|
||||
if error:
|
||||
metadata["error"] = error
|
||||
|
||||
return self._sanitize_metadata(metadata)
|
||||
|
||||
def _sanitize_metadata(self, value: Any, depth: int = 0):
|
||||
if value is None:
|
||||
return None
|
||||
if depth > 5:
|
||||
return str(value)
|
||||
if isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
if isinstance(value, dict):
|
||||
sanitized = {}
|
||||
for key, val in value.items():
|
||||
cleaned = self._sanitize_metadata(val, depth + 1)
|
||||
if cleaned is not None:
|
||||
sanitized[str(key)] = cleaned
|
||||
return sanitized
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [self._sanitize_metadata(item, depth + 1) for item in value]
|
||||
if hasattr(value, "isoformat"):
|
||||
try:
|
||||
return value.isoformat()
|
||||
except Exception:
|
||||
return str(value)
|
||||
return str(value)
|
||||
Reference in New Issue
Block a user