Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

82
src/api/base/adapter.py Normal file
View File

@@ -0,0 +1,82 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional
from fastapi import Request, Response
from .context import ApiRequestContext
class ApiMode(str, Enum):
STANDARD = "standard"
PROXY = "proxy"
ADMIN = "admin"
USER = "user" # JWT 认证的普通用户(不要求管理员权限)
PUBLIC = "public"
class ApiAdapter(ABC):
"""所有API格式适配器的抽象基类。"""
name: str = "base"
mode: ApiMode = ApiMode.STANDARD
api_format: Optional[str] = None # 对应 Provider API 格式提示
audit_log_enabled: bool = True
audit_success_event = None
audit_failure_event = None
@abstractmethod
async def handle(self, context: ApiRequestContext) -> Response:
"""处理请求并返回 FastAPI Response。"""
def authorize(self, context: ApiRequestContext) -> None:
"""可选的授权钩子,默认允许通过。"""
return None
def extract_api_key(self, request: Request) -> Optional[str]:
"""
从请求中提取客户端 API 密钥。
子类应覆盖此方法以支持各自的认证头格式。
Args:
request: FastAPI Request 对象
Returns:
提取的 API 密钥,如果未找到则返回 None
"""
return None
def get_audit_metadata(
self,
context: ApiRequestContext,
*,
success: bool,
status_code: Optional[int],
error: Optional[str] = None,
) -> Dict[str, Any]:
"""允许适配器在审计日志中追加自定义字段。"""
return {}
def detect_capability_requirements(
self,
headers: Dict[str, str],
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""
检测请求中隐含的能力需求(子类可覆盖)
不同 API 格式有不同的能力声明方式,例如:
- Claude: anthropic-beta: context-1m-xxx 表示需要 1M 上下文
- 其他格式可能有不同的请求头或请求体字段
Args:
headers: 请求头字典
request_body: 请求体字典(可选)
Returns:
检测到的能力需求,如 {"context_1m": True}
"""
return {}

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from fastapi import HTTPException
from src.models.database import UserRole
from .adapter import ApiAdapter, ApiMode
from .context import ApiRequestContext
class AdminApiAdapter(ApiAdapter):
"""管理员端点适配器基类,提供统一的权限校验。"""
mode = ApiMode.ADMIN
required_roles: tuple[UserRole, ...] = (UserRole.ADMIN,)
def authorize(self, context: ApiRequestContext) -> None:
user = context.user
if not user:
raise HTTPException(status_code=401, detail="未登录")
# 检查是否使用独立余额Key访问管理接口
if context.api_key and context.api_key.is_standalone:
raise HTTPException(
status_code=403, detail="独立余额Key不允许访问管理接口仅可用于代理请求"
)
if not any(user.role == role for role in self.required_roles):
raise HTTPException(status_code=403, detail="需要管理员权限")

View File

@@ -0,0 +1,13 @@
from fastapi import HTTPException
from .adapter import ApiAdapter, ApiMode
class AuthenticatedApiAdapter(ApiAdapter):
"""通用需要登录的适配器基类。"""
mode = ApiMode.USER
def authorize(self, context): # type: ignore[override]
if not context.user:
raise HTTPException(status_code=401, detail="未登录")

116
src/api/base/context.py Normal file
View File

@@ -0,0 +1,116 @@
from __future__ import annotations
import json
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import ApiKey, User
@dataclass
class ApiRequestContext:
"""统一的API请求上下文贯穿Pipeline与格式适配器。"""
request: Request
db: Session
user: Optional[User]
api_key: Optional[ApiKey]
request_id: str
start_time: float
client_ip: str
user_agent: str
original_headers: Dict[str, str]
query_params: Dict[str, str]
raw_body: bytes | None = None
json_body: Optional[Dict[str, Any]] = None
quota_remaining: Optional[float] = None
mode: str = "standard" # standard / proxy
api_format_hint: Optional[str] = None
# URL 路径参数(如 Gemini API 的 /v1beta/models/{model}:generateContent
path_params: Dict[str, Any] = field(default_factory=dict)
# 供适配器扩展的状态存储
extra: Dict[str, Any] = field(default_factory=dict)
audit_metadata: Dict[str, Any] = field(default_factory=dict)
def ensure_json_body(self) -> Dict[str, Any]:
"""确保请求体已解析为JSON并返回。"""
if self.json_body is not None:
return self.json_body
if not self.raw_body:
raise HTTPException(status_code=400, detail="请求体不能为空")
try:
self.json_body = json.loads(self.raw_body.decode("utf-8"))
except json.JSONDecodeError as exc:
logger.warning(f"解析JSON失败: {exc}")
raise HTTPException(status_code=400, detail="请求体必须是合法的JSON") from exc
return self.json_body
def add_audit_metadata(self, **values: Any) -> None:
"""向审计日志附加字段(会自动过滤 None"""
for key, value in values.items():
if value is not None:
self.audit_metadata[key] = value
def extend_audit_metadata(self, data: Dict[str, Any]) -> None:
"""批量附加审计字段。"""
for key, value in data.items():
if value is not None:
self.audit_metadata[key] = value
@classmethod
def build(
cls,
request: Request,
db: Session,
user: Optional[User],
api_key: Optional[ApiKey],
raw_body: Optional[bytes] = None,
mode: str = "standard",
api_format_hint: Optional[str] = None,
path_params: Optional[Dict[str, Any]] = None,
) -> "ApiRequestContext":
"""创建上下文实例并提前读取必要的元数据。"""
request_id = getattr(request.state, "request_id", None) or str(uuid.uuid4())[:8]
setattr(request.state, "request_id", request_id)
start_time = time.time()
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
context = cls(
request=request,
db=db,
user=user,
api_key=api_key,
request_id=request_id,
start_time=start_time,
client_ip=client_ip,
user_agent=user_agent,
original_headers=dict(request.headers),
query_params=dict(request.query_params),
raw_body=raw_body,
mode=mode,
api_format_hint=api_format_hint,
path_params=path_params or {},
)
# 便于插件/日志引用
request.state.request_id = request_id
if user:
request.state.user_id = user.id
if api_key:
request.state.api_key_id = api_key.id
return context

View File

@@ -0,0 +1,49 @@
from __future__ import annotations
from dataclasses import asdict, dataclass
from typing import List, Sequence, Tuple, TypeVar
from sqlalchemy.orm import Query
T = TypeVar("T")
@dataclass
class PaginationMeta:
total: int
limit: int
offset: int
count: int
def to_dict(self) -> dict:
return asdict(self)
def paginate_query(query: Query, limit: int, offset: int) -> Tuple[int, List[T]]:
"""
对 SQLAlchemy 查询应用 limit/offset并返回总数与结果列表。
"""
total = query.order_by(None).count()
records = query.offset(offset).limit(limit).all()
return total, records
def paginate_sequence(
items: Sequence[T], limit: int, offset: int
) -> Tuple[List[T], PaginationMeta]:
"""
对内存序列应用分页,返回切片和元数据。
"""
total = len(items)
sliced = list(items[offset : offset + limit])
meta = PaginationMeta(total=total, limit=limit, offset=offset, count=len(sliced))
return sliced, meta
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict:
"""
构建标准分页响应 payload。
"""
payload = {"items": items, "meta": meta.to_dict()}
payload.update(extra)
return payload

387
src/api/base/pipeline.py Normal file
View File

@@ -0,0 +1,387 @@
from __future__ import annotations
import time
from enum import Enum
from typing import Any, Optional, Tuple
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session, sessionmaker
from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.cache.user_cache import UserCacheService
from src.services.system.audit import AuditService
from src.services.usage.service import UsageService
from .adapter import ApiAdapter, ApiMode
from .context import ApiRequestContext
class ApiRequestPipeline:
"""负责统一执行认证、配额校验、上下文构建等通用逻辑的管道。"""
def __init__(
self,
auth_service: AuthService = AuthService,
usage_service: UsageService = UsageService,
audit_service: AuditService = AuditService,
):
self.auth_service = auth_service
self.usage_service = usage_service
self.audit_service = audit_service
async def run(
self,
adapter: ApiAdapter,
http_request: Request,
db: Session,
*,
mode: ApiMode = ApiMode.STANDARD,
api_format_hint: Optional[str] = None,
path_params: Optional[dict[str, Any]] = None,
):
logger.debug(f"[Pipeline] START | path={http_request.url.path}")
logger.debug(f"[Pipeline] Running with mode={mode}, adapter={adapter.__class__.__name__}, "
f"adapter.mode={adapter.mode}, path={http_request.url.path}")
if mode == ApiMode.ADMIN:
user = await self._authenticate_admin(http_request, db)
api_key = None
elif mode == ApiMode.USER:
user = await self._authenticate_user(http_request, db)
api_key = None
elif mode == ApiMode.PUBLIC:
user = None
api_key = None
else:
logger.debug("[Pipeline] 调用 _authenticate_client")
user, api_key = self._authenticate_client(http_request, db, adapter)
logger.debug(f"[Pipeline] 认证完成 | user={user.username if user else None}")
raw_body = None
if http_request.method in {"POST", "PUT", "PATCH"}:
try:
import asyncio
# 添加30秒超时防止卡死
raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes")
except asyncio.TimeoutError:
logger.error("读取请求体超时(30s),可能客户端未发送完整请求体")
raise HTTPException(
status_code=408, detail="Request timeout: body not received within 30 seconds"
)
else:
logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")
context = ApiRequestContext.build(
request=http_request,
db=db,
user=user,
api_key=api_key,
raw_body=raw_body,
mode=mode.value,
api_format_hint=api_format_hint,
path_params=path_params,
)
logger.debug(f"[Pipeline] Context构建完成 | adapter={adapter.name} | request_id={context.request_id}")
if mode != ApiMode.ADMIN and user:
context.quota_remaining = self._calculate_quota_remaining(user)
logger.debug(f"[Pipeline] Adapter={adapter.name} | RequestID={context.request_id}")
logger.debug(f"[Pipeline] Calling authorize on {adapter.__class__.__name__}, user={context.user}")
# authorize 可能是异步的,需要检查并 await
authorize_result = adapter.authorize(context)
if hasattr(authorize_result, "__await__"):
await authorize_result
try:
response = await adapter.handle(context)
status_code = getattr(response, "status_code", None)
self._record_audit_event(context, adapter, success=True, status_code=status_code)
return response
except HTTPException as exc:
err_detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
self._record_audit_event(
context,
adapter,
success=False,
status_code=exc.status_code,
error=err_detail,
)
raise
except Exception as exc:
self._record_audit_event(
context,
adapter,
success=False,
status_code=500,
error=str(exc),
)
raise
# --------------------------------------------------------------------- #
# Internal helpers
# --------------------------------------------------------------------- #
def _authenticate_client(
self, request: Request, db: Session, adapter: ApiAdapter
) -> Tuple[User, ApiKey]:
logger.debug("[Pipeline._authenticate_client] 开始")
# 使用 adapter 的 extract_api_key 方法,支持不同 API 格式的认证头
client_api_key = adapter.extract_api_key(request)
logger.debug(f"[Pipeline._authenticate_client] 提取API密钥完成 | key_prefix={client_api_key[:8] if client_api_key else None}...")
if not client_api_key:
raise HTTPException(status_code=401, detail="请提供API密钥")
logger.debug("[Pipeline._authenticate_client] 调用 auth_service.authenticate_api_key")
auth_result = self.auth_service.authenticate_api_key(db, client_api_key)
logger.debug(f"[Pipeline._authenticate_client] 认证结果 | result={bool(auth_result)}")
if not auth_result:
raise HTTPException(status_code=401, detail="无效的API密钥")
user, api_key = auth_result
if not user or not api_key:
raise HTTPException(status_code=401, detail="无效的API密钥")
request.state.user_id = user.id
request.state.api_key_id = api_key.id
# 检查配额或余额支持独立Key
quota_ok, message = self.usage_service.check_user_quota(db, user, api_key=api_key)
if not quota_ok:
# 根据Key类型计算剩余额度
if api_key.is_standalone:
# 独立Key显示剩余余额
remaining = (
None
if api_key.current_balance_usd is None
else float(api_key.current_balance_usd - (api_key.balance_used_usd or 0))
)
else:
# 普通Key显示用户配额剩余
remaining = (
None
if user.quota_usd is None or user.quota_usd < 0
else float(user.quota_usd - user.used_usd)
)
raise QuotaExceededException(quota_type="USD", remaining=remaining)
return user, api_key
async def _authenticate_admin(self, request: Request, db: Session) -> User:
authorization = request.headers.get("authorization")
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少管理员凭证")
token = authorization.replace("Bearer ", "").strip()
try:
payload = await self.auth_service.verify_token(token)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Admin token 验证失败: {exc}")
raise HTTPException(status_code=401, detail="无效的管理员令牌")
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(status_code=401, detail="无效的管理员令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
request.state.user_id = user.id
return user
async def _authenticate_user(self, request: Request, db: Session) -> User:
"""JWT 认证普通用户(不要求管理员权限)"""
authorization = request.headers.get("authorization")
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少用户凭证")
token = authorization.replace("Bearer ", "").strip()
try:
payload = await self.auth_service.verify_token(token)
except HTTPException:
raise
except Exception as exc:
logger.error(f"User token 验证失败: {exc}")
raise HTTPException(status_code=401, detail="无效的用户令牌")
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(status_code=401, detail="无效的用户令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
request.state.user_id = user.id
return user
def _calculate_quota_remaining(self, user: Optional[User]) -> Optional[float]:
if not user:
return None
if user.quota_usd is None or user.quota_usd < 0:
return None
return max(float(user.quota_usd - user.used_usd), 0.0)
def _record_audit_event(
self,
context: ApiRequestContext,
adapter: ApiAdapter,
*,
success: bool,
status_code: Optional[int] = None,
error: Optional[str] = None,
) -> None:
if not getattr(adapter, "audit_log_enabled", True):
return
bind = context.db.get_bind()
if bind is None:
return
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
if not event_type:
if not success and status_code == 401:
event_type = AuditEventType.UNAUTHORIZED_ACCESS
else:
event_type = (
AuditEventType.REQUEST_SUCCESS if success else AuditEventType.REQUEST_FAILED
)
metadata = self._build_audit_metadata(
context=context,
adapter=adapter,
success=success,
status_code=status_code,
error=error,
)
SessionMaker = sessionmaker(bind=bind)
audit_session = SessionMaker()
try:
self.audit_service.log_event(
db=audit_session,
event_type=event_type,
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
user_id=context.user.id if context.user else None,
api_key_id=context.api_key.id if context.api_key else None,
ip_address=context.client_ip,
user_agent=context.user_agent,
request_id=context.request_id,
status_code=status_code,
error_message=error,
metadata=metadata,
)
audit_session.commit()
except Exception as exc:
audit_session.rollback()
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
finally:
audit_session.close()
def _build_audit_metadata(
self,
context: ApiRequestContext,
adapter: ApiAdapter,
*,
success: bool,
status_code: Optional[int],
error: Optional[str],
) -> dict:
duration_ms = max((time.time() - context.start_time) * 1000, 0.0)
request = context.request
path_params = {}
try:
path_params = dict(getattr(request, "path_params", {}) or {})
except Exception:
path_params = {}
metadata: dict[str, Any] = {
"path": request.url.path,
"path_params": path_params,
"method": request.method,
"adapter": adapter.name,
"adapter_class": adapter.__class__.__name__,
"adapter_mode": getattr(adapter.mode, "value", str(adapter.mode)),
"mode": context.mode,
"api_format_hint": context.api_format_hint,
"query": context.query_params,
"duration_ms": round(duration_ms, 2),
"request_body_bytes": len(context.raw_body or b""),
"has_body": bool(context.raw_body),
"request_content_type": request.headers.get("content-type"),
"quota_remaining": context.quota_remaining,
"success": success,
}
if status_code is not None:
metadata["status_code"] = status_code
if context.user and getattr(context.user, "role", None):
role = context.user.role
metadata["user_role"] = getattr(role, "value", role)
if context.api_key:
if getattr(context.api_key, "name", None):
metadata["api_key_name"] = context.api_key.name
# 使用脱敏后的密钥显示
if hasattr(context.api_key, "get_display_key"):
metadata["api_key_display"] = context.api_key.get_display_key()
extra_details: dict[str, Any] = {}
if context.audit_metadata:
extra_details.update(context.audit_metadata)
try:
adapter_details = adapter.get_audit_metadata(
context,
success=success,
status_code=status_code,
error=error,
)
if adapter_details:
extra_details.update(adapter_details)
except Exception as exc:
logger.warning(f"[Audit] Adapter metadata failed: {adapter.__class__.__name__}: {exc}")
if extra_details:
metadata["details"] = extra_details
if error:
metadata["error"] = error
return self._sanitize_metadata(metadata)
def _sanitize_metadata(self, value: Any, depth: int = 0):
if value is None:
return None
if depth > 5:
return str(value)
if isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, Enum):
return value.value
if isinstance(value, dict):
sanitized = {}
for key, val in value.items():
cleaned = self._sanitize_metadata(val, depth + 1)
if cleaned is not None:
sanitized[str(key)] = cleaned
return sanitized
if isinstance(value, (list, tuple, set)):
return [self._sanitize_metadata(item, depth + 1) for item in value]
if hasattr(value, "isoformat"):
try:
return value.isoformat()
except Exception:
return str(value)
return str(value)