Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

11
src/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
"""AI Proxy
A proxy server that enables AI models to work with multiple API providers.
"""
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
__version__ = "9.1.0"
__author__ = "AI Proxy"

0
src/api/__init__.py Normal file
View File

30
src/api/admin/__init__.py Normal file
View File

@@ -0,0 +1,30 @@
"""Admin API routers."""
from fastapi import APIRouter
from .adaptive import router as adaptive_router
from .api_keys import router as api_keys_router
from .endpoints import router as endpoints_router
from .models import router as models_router
from .monitoring import router as monitoring_router
from .provider_strategy import router as provider_strategy_router
from .providers import router as providers_router
from .security import router as security_router
from .system import router as system_router
from .usage import router as usage_router
from .users import router as users_router
router = APIRouter()
router.include_router(system_router)
router.include_router(users_router)
router.include_router(providers_router)
router.include_router(api_keys_router)
router.include_router(usage_router)
router.include_router(monitoring_router)
router.include_router(endpoints_router)
router.include_router(provider_strategy_router)
router.include_router(adaptive_router)
router.include_router(models_router)
router.include_router(security_router)
__all__ = ["router"]

377
src/api/admin/adaptive.py Normal file
View File

@@ -0,0 +1,377 @@
"""
自适应并发管理 API 端点
设计原则:
- 自适应模式由 max_concurrent 字段决定:
- max_concurrent = NULL启用自适应模式系统自动学习并调整并发限制
- max_concurrent = 数字:固定限制模式,使用用户指定的并发限制
- learned_max_concurrent自适应模式下学习到的并发限制值
- adaptive_mode 是计算字段,基于 max_concurrent 是否为 NULL
"""
from dataclasses import dataclass
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.database import get_db
from src.models.database import ProviderAPIKey
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"])
pipeline = ApiRequestPipeline()
# ==================== Pydantic Models ====================
class EnableAdaptiveRequest(BaseModel):
"""启用自适应模式请求"""
enabled: bool = Field(..., description="是否启用自适应模式true=自适应false=固定限制)")
fixed_limit: Optional[int] = Field(
None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)"
)
class AdaptiveStatsResponse(BaseModel):
"""自适应统计响应"""
adaptive_mode: bool = Field(..., description="是否为自适应模式max_concurrent=NULL")
max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制NULL=自适应)")
effective_limit: Optional[int] = Field(
None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
)
learned_limit: Optional[int] = Field(None, description="学习到的并发限制")
concurrent_429_count: int
rpm_429_count: int
last_429_at: Optional[str]
last_429_type: Optional[str]
adjustment_count: int
recent_adjustments: List[dict]
class KeyListItem(BaseModel):
"""Key 列表项"""
id: str
name: Optional[str]
endpoint_id: str
is_adaptive: bool = Field(..., description="是否为自适应模式max_concurrent=NULL")
max_concurrent: Optional[int] = Field(None, description="固定并发限制NULL=自适应)")
effective_limit: Optional[int] = Field(None, description="当前有效限制")
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
concurrent_429_count: int
rpm_429_count: int
# ==================== API Endpoints ====================
@router.get(
"/keys",
response_model=List[KeyListItem],
summary="获取所有启用自适应模式的Key",
)
async def list_adaptive_keys(
request: Request,
endpoint_id: Optional[str] = Query(None, description="按 Endpoint 过滤"),
db: Session = Depends(get_db),
):
"""
获取所有启用自适应模式的Key列表
可选参数:
- endpoint_id: 按 Endpoint 过滤
"""
adapter = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch(
"/keys/{key_id}/mode",
summary="Toggle key's concurrency control mode",
)
async def toggle_adaptive_mode(
key_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""
Toggle the concurrency control mode for a specific key
Parameters:
- enabled: true=adaptive mode (max_concurrent=NULL), false=fixed limit mode
- fixed_limit: fixed limit value (required when enabled=false)
"""
adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get(
"/keys/{key_id}/stats",
response_model=AdaptiveStatsResponse,
summary="获取Key的自适应统计",
)
async def get_adaptive_stats(
key_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""
获取指定Key的自适应并发统计信息
包括:
- 当前配置
- 学习到的限制
- 429错误统计
- 调整历史
"""
adapter = GetAdaptiveStatsAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete(
"/keys/{key_id}/learning",
summary="Reset key's learning state",
)
async def reset_adaptive_learning(
key_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""
Reset the adaptive learning state for a specific key
Clears:
- Learned concurrency limit (learned_max_concurrent)
- 429 error counts
- Adjustment history
Does not change:
- max_concurrent config (determines adaptive mode)
"""
adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch(
"/keys/{key_id}/limit",
summary="Set key to fixed concurrency limit mode",
)
async def set_concurrent_limit(
key_id: str,
request: Request,
limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"),
db: Session = Depends(get_db),
):
"""
Set key to fixed concurrency limit mode
Note:
- After setting this value, key switches to fixed limit mode and won't auto-adjust
- To restore adaptive mode, use PATCH /keys/{key_id}/mode
"""
adapter = SetConcurrentLimitAdapter(key_id=key_id, limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get(
"/summary",
summary="获取自适应并发的全局统计",
)
async def get_adaptive_summary(
request: Request,
db: Session = Depends(get_db),
):
"""
获取自适应并发的全局统计摘要
包括:
- 启用自适应模式的Key数量
- 总429错误数
- 并发限制调整次数
"""
adapter = AdaptiveSummaryAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ==================== Pipeline 适配器 ====================
@dataclass
class ListAdaptiveKeysAdapter(AdminApiAdapter):
endpoint_id: Optional[str] = None
async def handle(self, context): # type: ignore[override]
# 自适应模式max_concurrent = NULL
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None))
if self.endpoint_id:
query = query.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
keys = query.all()
return [
KeyListItem(
id=key.id,
name=key.name,
endpoint_id=key.endpoint_id,
is_adaptive=key.max_concurrent is None,
max_concurrent=key.max_concurrent,
effective_limit=(
key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
),
learned_max_concurrent=key.learned_max_concurrent,
concurrent_429_count=key.concurrent_429_count or 0,
rpm_429_count=key.rpm_429_count or 0,
)
for key in keys
]
@dataclass
class ToggleAdaptiveModeAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise HTTPException(status_code=404, detail="Key not found")
payload = context.ensure_json_body()
try:
body = EnableAdaptiveRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
if body.enabled:
# 启用自适应模式:将 max_concurrent 设为 NULL
key.max_concurrent = None
message = "已切换为自适应模式,系统将自动学习并调整并发限制"
else:
# 禁用自适应模式:设置固定限制
if body.fixed_limit is None:
raise HTTPException(
status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
)
key.max_concurrent = body.fixed_limit
message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}"
context.db.commit()
context.db.refresh(key)
is_adaptive = key.max_concurrent is None
return {
"message": message,
"key_id": key.id,
"is_adaptive": is_adaptive,
"max_concurrent": key.max_concurrent,
"effective_limit": key.learned_max_concurrent if is_adaptive else key.max_concurrent,
}
@dataclass
class GetAdaptiveStatsAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise HTTPException(status_code=404, detail="Key not found")
adaptive_manager = get_adaptive_manager()
stats = adaptive_manager.get_adjustment_stats(key)
# 转换字段名以匹配响应模型
return AdaptiveStatsResponse(
adaptive_mode=stats["adaptive_mode"],
max_concurrent=stats["max_concurrent"],
effective_limit=stats["effective_limit"],
learned_limit=stats["learned_limit"],
concurrent_429_count=stats["concurrent_429_count"],
rpm_429_count=stats["rpm_429_count"],
last_429_at=stats["last_429_at"],
last_429_type=stats["last_429_type"],
adjustment_count=stats["adjustment_count"],
recent_adjustments=stats["recent_adjustments"],
)
@dataclass
class ResetAdaptiveLearningAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise HTTPException(status_code=404, detail="Key not found")
adaptive_manager = get_adaptive_manager()
adaptive_manager.reset_learning(context.db, key)
return {"message": "学习状态已重置", "key_id": key.id}
@dataclass
class SetConcurrentLimitAdapter(AdminApiAdapter):
key_id: str
limit: int
async def handle(self, context): # type: ignore[override]
key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise HTTPException(status_code=404, detail="Key not found")
was_adaptive = key.max_concurrent is None
key.max_concurrent = self.limit
context.db.commit()
context.db.refresh(key)
return {
"message": f"已设置为固定限制模式,并发限制为 {self.limit}",
"key_id": key.id,
"is_adaptive": False,
"max_concurrent": key.max_concurrent,
"previous_mode": "adaptive" if was_adaptive else "fixed",
}
class AdaptiveSummaryAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
# 自适应模式max_concurrent = NULL
adaptive_keys = (
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None)).all()
)
total_keys = len(adaptive_keys)
total_concurrent_429 = sum(key.concurrent_429_count or 0 for key in adaptive_keys)
total_rpm_429 = sum(key.rpm_429_count or 0 for key in adaptive_keys)
total_adjustments = sum(len(key.adjustment_history or []) for key in adaptive_keys)
recent_adjustments = []
for key in adaptive_keys:
if key.adjustment_history:
for adj in key.adjustment_history[-3:]:
recent_adjustments.append(
{
"key_id": key.id,
"key_name": key.name,
**adj,
}
)
recent_adjustments.sort(key=lambda item: item.get("timestamp", ""), reverse=True)
return {
"total_adaptive_keys": total_keys,
"total_concurrent_429_errors": total_concurrent_429,
"total_rpm_429_errors": total_rpm_429,
"total_adjustments": total_adjustments,
"recent_adjustments": recent_adjustments[:10],
}

View File

@@ -0,0 +1,5 @@
"""API key admin routes export."""
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,497 @@
"""管理员独立余额 API Key 管理路由。
独立余额Key不关联用户配额有独立余额限制用于给非注册用户使用。
"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.api import CreateApiKeyRequest
from src.models.database import ApiKey, User
from src.services.user.apikey import ApiKeyService
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
pipeline = ApiRequestPipeline()
@router.get("")
async def list_standalone_api_keys(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
is_active: Optional[bool] = None,
db: Session = Depends(get_db),
):
"""列出所有独立余额API Keys"""
adapter = AdminListStandaloneKeysAdapter(skip=skip, limit=limit, is_active=is_active)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("")
async def create_standalone_api_key(
request: Request,
key_data: CreateApiKeyRequest,
db: Session = Depends(get_db),
):
"""创建独立余额API Key必须设置余额限制"""
adapter = AdminCreateStandaloneKeyAdapter(key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{key_id}")
async def update_api_key(
key_id: str, request: Request, key_data: CreateApiKeyRequest, db: Session = Depends(get_db)
):
"""更新独立余额Key可修改名称、过期时间、余额限制等"""
adapter = AdminUpdateApiKeyAdapter(key_id=key_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{key_id}")
async def toggle_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
"""Toggle API key active status (PATCH with is_active in body)"""
adapter = AdminToggleApiKeyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{key_id}")
async def delete_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
adapter = AdminDeleteApiKeyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{key_id}/balance")
async def add_balance_to_key(
key_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""Adjust balance for standalone API key (positive to add, negative to deduct)"""
# 从请求体获取调整金额
body = await request.json()
amount_usd = body.get("amount_usd")
# 参数校验
if amount_usd is None:
raise HTTPException(status_code=400, detail="缺少必需参数: amount_usd")
if amount_usd == 0:
raise HTTPException(status_code=400, detail="调整金额不能为 0")
# 类型校验
try:
amount_usd = float(amount_usd)
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="调整金额必须是有效数字")
# 如果是扣除操作,检查Key是否存在以及余额是否充足
if amount_usd < 0:
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
if not api_key:
raise HTTPException(status_code=404, detail="API密钥不存在")
if not api_key.is_standalone:
raise HTTPException(status_code=400, detail="只能为独立余额Key调整余额")
if api_key.current_balance_usd is not None:
if abs(amount_usd) > api_key.current_balance_usd:
raise HTTPException(
status_code=400,
detail=f"扣除金额 ${abs(amount_usd):.2f} 超过当前余额 ${api_key.current_balance_usd:.2f}",
)
adapter = AdminAddBalanceAdapter(key_id=key_id, amount_usd=amount_usd)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{key_id}")
async def get_api_key_detail(
key_id: str,
request: Request,
include_key: bool = Query(False, description="Include full decrypted key in response"),
db: Session = Depends(get_db),
):
"""Get API key detail, optionally include full key"""
if include_key:
adapter = AdminGetFullKeyAdapter(key_id=key_id)
else:
# Return basic key info without full key
adapter = AdminGetKeyDetailAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminListStandaloneKeysAdapter(AdminApiAdapter):
"""列出独立余额Keys"""
def __init__(
self,
skip: int,
limit: int,
is_active: Optional[bool],
):
self.skip = skip
self.limit = limit
self.is_active = is_active
async def handle(self, context): # type: ignore[override]
db = context.db
# 只查询独立余额Keys
query = db.query(ApiKey).filter(ApiKey.is_standalone == True)
if self.is_active is not None:
query = query.filter(ApiKey.is_active == self.is_active)
total = query.count()
api_keys = (
query.order_by(ApiKey.created_at.desc()).offset(self.skip).limit(self.limit).all()
)
context.add_audit_metadata(
action="list_standalone_api_keys",
filter_is_active=self.is_active,
limit=self.limit,
skip=self.skip,
total=total,
)
return {
"api_keys": [
{
"id": api_key.id,
"user_id": api_key.user_id, # 创建者ID
"name": api_key.name,
"key_display": api_key.get_display_key(),
"is_active": api_key.is_active,
"is_standalone": api_key.is_standalone,
"current_balance_usd": api_key.current_balance_usd,
"balance_used_usd": float(api_key.balance_used_usd or 0),
"total_requests": api_key.total_requests,
"total_cost_usd": float(api_key.total_cost_usd or 0),
"rate_limit": api_key.rate_limit,
"allowed_providers": api_key.allowed_providers,
"allowed_api_formats": api_key.allowed_api_formats,
"allowed_models": api_key.allowed_models,
"last_used_at": (
api_key.last_used_at.isoformat() if api_key.last_used_at else None
),
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"created_at": api_key.created_at.isoformat(),
"updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
"auto_delete_on_expiry": api_key.auto_delete_on_expiry,
}
for api_key in api_keys
],
"total": total,
"limit": self.limit,
"skip": self.skip,
}
class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
"""创建独立余额Key"""
def __init__(self, key_data: CreateApiKeyRequest):
self.key_data = key_data
async def handle(self, context): # type: ignore[override]
db = context.db
# 独立Key必须设置初始余额
if not self.key_data.initial_balance_usd or self.key_data.initial_balance_usd <= 0:
raise HTTPException(
status_code=400,
detail="创建独立余额Key必须设置有效的初始余额initial_balance_usd > 0",
)
# 独立Key需要关联到管理员用户从context获取
admin_user_id = context.user.id
# 创建独立Key
api_key, plain_key = ApiKeyService.create_api_key(
db=db,
user_id=admin_user_id, # 关联到创建者
name=self.key_data.name,
allowed_providers=self.key_data.allowed_providers,
allowed_api_formats=self.key_data.allowed_api_formats,
allowed_models=self.key_data.allowed_models,
rate_limit=self.key_data.rate_limit or 100,
expire_days=self.key_data.expire_days,
initial_balance_usd=self.key_data.initial_balance_usd,
is_standalone=True, # 标记为独立Key
auto_delete_on_expiry=self.key_data.auto_delete_on_expiry,
)
logger.info(f"管理员创建独立余额Key: ID {api_key.id}, 初始余额 ${self.key_data.initial_balance_usd}")
context.add_audit_metadata(
action="create_standalone_api_key",
key_id=api_key.id,
initial_balance_usd=self.key_data.initial_balance_usd,
)
return {
"id": api_key.id,
"key": plain_key, # 只在创建时返回完整密钥
"name": api_key.name,
"key_display": api_key.get_display_key(),
"is_standalone": True,
"current_balance_usd": api_key.current_balance_usd,
"balance_used_usd": 0.0,
"rate_limit": api_key.rate_limit,
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"created_at": api_key.created_at.isoformat(),
"message": "独立余额Key创建成功请妥善保存完整密钥后续将无法查看",
}
class AdminUpdateApiKeyAdapter(AdminApiAdapter):
"""更新独立余额Key"""
def __init__(self, key_id: str, key_data: CreateApiKeyRequest):
self.key_id = key_id
self.key_data = key_data
async def handle(self, context): # type: ignore[override]
db = context.db
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
if not api_key:
raise NotFoundException("API密钥不存在", "api_key")
# 构建更新数据
update_data = {}
if self.key_data.name is not None:
update_data["name"] = self.key_data.name
if self.key_data.rate_limit is not None:
update_data["rate_limit"] = self.key_data.rate_limit
if (
hasattr(self.key_data, "auto_delete_on_expiry")
and self.key_data.auto_delete_on_expiry is not None
):
update_data["auto_delete_on_expiry"] = self.key_data.auto_delete_on_expiry
# 访问限制配置(允许设置为空数组来清除限制)
if hasattr(self.key_data, "allowed_providers"):
update_data["allowed_providers"] = self.key_data.allowed_providers
if hasattr(self.key_data, "allowed_api_formats"):
update_data["allowed_api_formats"] = self.key_data.allowed_api_formats
if hasattr(self.key_data, "allowed_models"):
update_data["allowed_models"] = self.key_data.allowed_models
# 处理过期时间
if self.key_data.expire_days is not None:
if self.key_data.expire_days > 0:
from datetime import timedelta
update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
days=self.key_data.expire_days
)
else:
# expire_days = 0 或负数表示永不过期
update_data["expires_at"] = None
elif hasattr(self.key_data, "expire_days") and self.key_data.expire_days is None:
# 明确传递 None设为永不过期
update_data["expires_at"] = None
# 使用 ApiKeyService 更新
updated_key = ApiKeyService.update_api_key(db, self.key_id, **update_data)
if not updated_key:
raise NotFoundException("更新失败", "api_key")
logger.info(f"管理员更新独立余额Key: ID {self.key_id}, 更新字段 {list(update_data.keys())}")
context.add_audit_metadata(
action="update_standalone_api_key",
key_id=self.key_id,
updated_fields=list(update_data.keys()),
)
return {
"id": updated_key.id,
"name": updated_key.name,
"key_display": updated_key.get_display_key(),
"is_active": updated_key.is_active,
"current_balance_usd": updated_key.current_balance_usd,
"balance_used_usd": float(updated_key.balance_used_usd or 0),
"rate_limit": updated_key.rate_limit,
"expires_at": updated_key.expires_at.isoformat() if updated_key.expires_at else None,
"updated_at": updated_key.updated_at.isoformat() if updated_key.updated_at else None,
"message": "API密钥已更新",
}
class AdminToggleApiKeyAdapter(AdminApiAdapter):
def __init__(self, key_id: str):
self.key_id = key_id
async def handle(self, context): # type: ignore[override]
db = context.db
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
if not api_key:
raise NotFoundException("API密钥不存在", "api_key")
api_key.is_active = not api_key.is_active
api_key.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(api_key)
logger.info(f"管理员切换API密钥状态: Key ID {self.key_id}, 新状态 {'启用' if api_key.is_active else '禁用'}")
context.add_audit_metadata(
action="toggle_api_key",
target_key_id=api_key.id,
user_id=api_key.user_id,
new_status="enabled" if api_key.is_active else "disabled",
)
return {
"id": api_key.id,
"is_active": api_key.is_active,
"message": f"API密钥已{'启用' if api_key.is_active else '禁用'}",
}
class AdminDeleteApiKeyAdapter(AdminApiAdapter):
def __init__(self, key_id: str):
self.key_id = key_id
async def handle(self, context): # type: ignore[override]
db = context.db
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
if not api_key:
raise HTTPException(status_code=404, detail="API密钥不存在")
user = api_key.user
db.delete(api_key)
db.commit()
logger.info(f"管理员删除API密钥: Key ID {self.key_id}, 用户 {user.email if user else '未知'}")
context.add_audit_metadata(
action="delete_api_key",
target_key_id=self.key_id,
user_id=user.id if user else None,
user_email=user.email if user else None,
)
return {"message": "API密钥已删除"}
class AdminAddBalanceAdapter(AdminApiAdapter):
"""为独立余额Key增加余额"""
def __init__(self, key_id: str, amount_usd: float):
self.key_id = key_id
self.amount_usd = amount_usd
async def handle(self, context): # type: ignore[override]
db = context.db
# 使用 ApiKeyService 增加余额
updated_key = ApiKeyService.add_balance(db, self.key_id, self.amount_usd)
if not updated_key:
raise NotFoundException("余额充值失败Key不存在或不是独立余额Key", "api_key")
logger.info(f"管理员为独立余额Key充值: ID {self.key_id}, 充值 ${self.amount_usd:.4f}")
context.add_audit_metadata(
action="add_balance_to_key",
key_id=self.key_id,
amount_usd=self.amount_usd,
new_current_balance=updated_key.current_balance_usd,
)
return {
"id": updated_key.id,
"name": updated_key.name,
"current_balance_usd": updated_key.current_balance_usd,
"balance_used_usd": float(updated_key.balance_used_usd or 0),
"message": f"余额充值成功,充值 ${self.amount_usd:.2f},当前余额 ${updated_key.current_balance_usd:.2f}",
}
class AdminGetFullKeyAdapter(AdminApiAdapter):
"""获取完整的API密钥"""
def __init__(self, key_id: str):
self.key_id = key_id
async def handle(self, context): # type: ignore[override]
from src.core.crypto import crypto_service
db = context.db
# 查找API密钥
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
if not api_key:
raise NotFoundException("API密钥不存在", "api_key")
# 解密完整密钥
if not api_key.key_encrypted:
raise HTTPException(status_code=400, detail="该密钥没有存储完整密钥信息")
try:
full_key = crypto_service.decrypt(api_key.key_encrypted)
except Exception as e:
logger.error(f"解密API密钥失败: Key ID {self.key_id}, 错误: {e}")
raise HTTPException(status_code=500, detail="解密密钥失败")
logger.info(f"管理员查看完整API密钥: Key ID {self.key_id}")
context.add_audit_metadata(
action="view_full_api_key",
key_id=self.key_id,
key_name=api_key.name,
)
return {
"key": full_key,
}
class AdminGetKeyDetailAdapter(AdminApiAdapter):
"""Get API key detail without full key"""
def __init__(self, key_id: str):
self.key_id = key_id
async def handle(self, context): # type: ignore[override]
db = context.db
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
if not api_key:
raise NotFoundException("API密钥不存在", "api_key")
context.add_audit_metadata(
action="get_api_key_detail",
key_id=self.key_id,
)
return {
"id": api_key.id,
"user_id": api_key.user_id,
"name": api_key.name,
"key_display": api_key.get_display_key(),
"is_active": api_key.is_active,
"is_standalone": api_key.is_standalone,
"current_balance_usd": api_key.current_balance_usd,
"balance_used_usd": float(api_key.balance_used_usd or 0),
"total_requests": api_key.total_requests,
"total_cost_usd": float(api_key.total_cost_usd or 0),
"rate_limit": api_key.rate_limit,
"allowed_providers": api_key.allowed_providers,
"allowed_api_formats": api_key.allowed_api_formats,
"allowed_models": api_key.allowed_models,
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"created_at": api_key.created_at.isoformat(),
"updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
}

View File

@@ -0,0 +1,24 @@
"""Endpoint management API routers."""
from fastapi import APIRouter
from .concurrency import router as concurrency_router
from .health import router as health_router
from .keys import router as keys_router
from .routes import router as routes_router
router = APIRouter(prefix="/api/admin/endpoints", tags=["Endpoint Management"])
# Endpoint CRUD
router.include_router(routes_router)
# Endpoint Keys management
router.include_router(keys_router)
# Health monitoring
router.include_router(health_router)
# Concurrency control
router.include_router(concurrency_router)
__all__ = ["router"]

View File

@@ -0,0 +1,116 @@
"""
Endpoint 并发控制管理 API
"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.database import get_db
from src.models.database import ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
ConcurrencyStatusResponse,
ResetConcurrencyRequest,
)
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
router = APIRouter(tags=["Concurrency Control"])
pipeline = ApiRequestPipeline()
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
async def get_endpoint_concurrency(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
"""获取 Endpoint 当前并发状态"""
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
async def get_key_concurrency(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
"""获取 Key 当前并发状态"""
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/concurrency")
async def reset_concurrency(
request: ResetConcurrencyRequest,
http_request: Request,
db: Session = Depends(get_db),
) -> dict:
"""Reset concurrency counters (admin function, use with caution)"""
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
endpoint_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
concurrency_manager = await get_concurrency_manager()
endpoint_count, _ = await concurrency_manager.get_current_concurrency(
endpoint_id=self.endpoint_id
)
return ConcurrencyStatusResponse(
endpoint_id=self.endpoint_id,
endpoint_current_concurrency=endpoint_count,
endpoint_max_concurrent=endpoint.max_concurrent,
)
@dataclass
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
concurrency_manager = await get_concurrency_manager()
_, key_count = await concurrency_manager.get_current_concurrency(key_id=self.key_id)
return ConcurrencyStatusResponse(
key_id=self.key_id,
key_current_concurrency=key_count,
key_max_concurrent=key.max_concurrent,
)
@dataclass
class AdminResetConcurrencyAdapter(AdminApiAdapter):
endpoint_id: Optional[str]
key_id: Optional[str]
async def handle(self, context): # type: ignore[override]
concurrency_manager = await get_concurrency_manager()
await concurrency_manager.reset_concurrency(
endpoint_id=self.endpoint_id, key_id=self.key_id
)
return {"message": "并发计数已重置"}

View File

@@ -0,0 +1,476 @@
"""
Endpoint 健康监控 API
"""
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate
from src.models.endpoint_models import (
ApiFormatHealthMonitor,
ApiFormatHealthMonitorResponse,
EndpointHealthEvent,
HealthStatusResponse,
HealthSummaryResponse,
)
from src.services.health.endpoint import EndpointHealthService
from src.services.health.monitor import health_monitor
router = APIRouter(tags=["Endpoint Health"])
pipeline = ApiRequestPipeline()
@router.get("/health/summary", response_model=HealthSummaryResponse)
async def get_health_summary(
request: Request,
db: Session = Depends(get_db),
) -> HealthSummaryResponse:
"""获取健康状态摘要"""
adapter = AdminHealthSummaryAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/health/status")
async def get_endpoint_health_status(
request: Request,
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
db: Session = Depends(get_db),
):
"""
获取端点健康状态(简化视图,与用户端点统一)
与 /health/api-formats 的区别:
- /health/status: 返回聚合的时间线状态50个时间段基于 Usage 表
- /health/api-formats: 返回详细的事件列表,基于 RequestCandidate 表
"""
adapter = AdminEndpointHealthStatusAdapter(lookback_hours=lookback_hours)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/health/api-formats", response_model=ApiFormatHealthMonitorResponse)
async def get_api_format_health_monitor(
request: Request,
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
per_format_limit: int = Query(60, ge=10, le=200, description="每个 API 格式的事件数量"),
db: Session = Depends(get_db),
) -> ApiFormatHealthMonitorResponse:
"""获取按 API 格式聚合的健康监控时间线(详细事件列表)"""
adapter = AdminApiFormatHealthMonitorAdapter(
lookback_hours=lookback_hours,
per_format_limit=per_format_limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/health/key/{key_id}", response_model=HealthStatusResponse)
async def get_key_health(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> HealthStatusResponse:
"""获取 Key 健康状态"""
adapter = AdminKeyHealthAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/health/keys/{key_id}")
async def recover_key_health(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""
Recover key health status
Resets health_score to 1.0, closes circuit breaker,
cancels auto-disable, and resets all failure counts.
Parameters:
- key_id: Key ID (path parameter)
"""
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/health/keys")
async def recover_all_keys_health(
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""
Batch recover all circuit-broken keys
Finds all keys with circuit_breaker_open=True and:
1. Resets health_score to 1.0
2. Closes circuit breaker
3. Resets failure counts
"""
adapter = AdminRecoverAllKeysHealthAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
class AdminHealthSummaryAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
summary = health_monitor.get_all_health_status(context.db)
return HealthSummaryResponse(**summary)
@dataclass
class AdminEndpointHealthStatusAdapter(AdminApiAdapter):
"""管理员端点健康状态适配器(与用户端点统一,但包含管理员字段)"""
lookback_hours: int
async def handle(self, context): # type: ignore[override]
from src.services.health.endpoint import EndpointHealthService
db = context.db
# 使用共享服务获取健康状态(管理员视图)
result = EndpointHealthService.get_endpoint_health_by_format(
db=db,
lookback_hours=self.lookback_hours,
include_admin_fields=True, # 包含管理员字段
use_cache=False, # 管理员不使用缓存,确保实时性
)
context.add_audit_metadata(
action="endpoint_health_status",
format_count=len(result),
lookback_hours=self.lookback_hours,
)
return result
@dataclass
class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
lookback_hours: int
per_format_limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
now = datetime.now(timezone.utc)
since = now - timedelta(hours=self.lookback_hours)
# 1. 获取所有活跃的 API 格式及其 Provider 数量
active_formats = (
db.query(
ProviderEndpoint.api_format,
func.count(func.distinct(ProviderEndpoint.provider_id)).label("provider_count"),
)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
)
.group_by(ProviderEndpoint.api_format)
.all()
)
# 构建所有格式的 provider_count 映射
all_formats: Dict[str, int] = {}
for api_format_enum, provider_count in active_formats:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
all_formats[api_format] = provider_count
# 1.1 获取所有活跃的 API 格式及其 API Key 数量
active_keys = (
db.query(
ProviderEndpoint.api_format,
func.count(ProviderAPIKey.id).label("key_count"),
)
.join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
)
.group_by(ProviderEndpoint.api_format)
.all()
)
# 构建所有格式的 key_count 映射
key_counts: Dict[str, int] = {}
for api_format_enum, key_count in active_keys:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
key_counts[api_format] = key_count
# 1.2 建立每个 API 格式对应的 Endpoint ID 列表,供 Usage 时间线生成使用
endpoint_rows = (
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
)
.all()
)
endpoint_map: Dict[str, List[str]] = defaultdict(list)
for api_format_enum, endpoint_id in endpoint_rows:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
endpoint_map[api_format].append(endpoint_id)
# 2. 统计窗口内每个 API 格式的请求状态分布(真实统计)
# 只统计最终状态success, failed, skipped
final_statuses = ["success", "failed", "skipped"]
status_counts_query = (
db.query(
ProviderEndpoint.api_format,
RequestCandidate.status,
func.count(RequestCandidate.id).label("count"),
)
.join(RequestCandidate, ProviderEndpoint.id == RequestCandidate.endpoint_id)
.filter(
RequestCandidate.created_at >= since,
RequestCandidate.status.in_(final_statuses),
)
.group_by(ProviderEndpoint.api_format, RequestCandidate.status)
.all()
)
# 构建每个格式的状态统计
status_counts: Dict[str, Dict[str, int]] = {}
for api_format_enum, status, count in status_counts_query:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
if api_format not in status_counts:
status_counts[api_format] = {"success": 0, "failed": 0, "skipped": 0}
status_counts[api_format][status] = count
# 3. 获取最近一段时间的 RequestCandidate限制数量
# 使用上面定义的 final_statuses排除中间状态
limit_rows = max(500, self.per_format_limit * 10)
rows = (
db.query(
RequestCandidate,
ProviderEndpoint.api_format,
ProviderEndpoint.provider_id,
)
.join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
.filter(
RequestCandidate.created_at >= since,
RequestCandidate.status.in_(final_statuses),
)
.order_by(RequestCandidate.created_at.desc())
.limit(limit_rows)
.all()
)
grouped_attempts: Dict[str, List[RequestCandidate]] = {}
for attempt, api_format_enum, provider_id in rows:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
if api_format not in grouped_attempts:
grouped_attempts[api_format] = []
# 只保留每个 API 格式最近 per_format_limit 条记录
if len(grouped_attempts[api_format]) < self.per_format_limit:
grouped_attempts[api_format].append(attempt)
# 4. 为所有活跃格式生成监控数据(包括没有请求记录的)
monitors: List[ApiFormatHealthMonitor] = []
for api_format in all_formats:
attempts = grouped_attempts.get(api_format, [])
# 获取窗口内的真实统计数据
# 只统计最终状态success, failed, skipped
# 中间状态available, pending, used, started不计入统计
format_stats = status_counts.get(api_format, {"success": 0, "failed": 0, "skipped": 0})
real_success_count = format_stats.get("success", 0)
real_failed_count = format_stats.get("failed", 0)
real_skipped_count = format_stats.get("skipped", 0)
# total_attempts 只包含最终状态的请求数
total_attempts = real_success_count + real_failed_count + real_skipped_count
# 时间线按时间正序
attempts_sorted = list(reversed(attempts))
events: List[EndpointHealthEvent] = []
for attempt in attempts_sorted:
event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
events.append(
EndpointHealthEvent(
timestamp=event_timestamp,
status=attempt.status,
status_code=attempt.status_code,
latency_ms=attempt.latency_ms,
error_type=attempt.error_type,
error_message=attempt.error_message,
)
)
# 成功率 = success / (success + failed)
# skipped 不算失败,不计入成功率分母
# 无实际完成请求时成功率为 1.0(灰色状态)
actual_completed = real_success_count + real_failed_count
success_rate = real_success_count / actual_completed if actual_completed > 0 else 1.0
last_event_at = events[-1].timestamp if events else None
# 生成 Usage 基于时间窗口的健康时间线
timeline_data = EndpointHealthService._generate_timeline_from_usage(
db=db,
endpoint_ids=endpoint_map.get(api_format, []),
now=now,
lookback_hours=self.lookback_hours,
)
monitors.append(
ApiFormatHealthMonitor(
api_format=api_format,
total_attempts=total_attempts, # 真实总请求数
success_count=real_success_count, # 真实成功数
failed_count=real_failed_count, # 真实失败数
skipped_count=real_skipped_count, # 真实跳过数
success_rate=success_rate, # 基于真实统计的成功率
provider_count=all_formats[api_format],
key_count=key_counts.get(api_format, 0),
last_event_at=last_event_at,
events=events, # 限制为 per_format_limit 条(用于时间线显示)
timeline=timeline_data.get("timeline", []),
time_range_start=timeline_data.get("time_range_start"),
time_range_end=timeline_data.get("time_range_end"),
)
)
response = ApiFormatHealthMonitorResponse(
generated_at=now,
formats=monitors,
)
context.add_audit_metadata(
action="api_format_health_monitor",
format_count=len(monitors),
lookback_hours=self.lookback_hours,
per_format_limit=self.per_format_limit,
)
return response
@dataclass
class AdminKeyHealthAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
health_data = health_monitor.get_key_health(context.db, self.key_id)
if not health_data:
raise NotFoundException(f"Key {self.key_id} 不存在")
return HealthStatusResponse(
key_id=health_data["key_id"],
key_health_score=health_data["health_score"],
key_consecutive_failures=health_data["consecutive_failures"],
key_last_failure_at=health_data["last_failure_at"],
key_is_active=health_data["is_active"],
key_statistics=health_data["statistics"],
circuit_breaker_open=health_data["circuit_breaker_open"],
circuit_breaker_open_at=health_data["circuit_breaker_open_at"],
next_probe_at=health_data["next_probe_at"],
)
@dataclass
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
key.health_score = 1.0
key.consecutive_failures = 0
key.last_failure_at = None
key.circuit_breaker_open = False
key.circuit_breaker_open_at = None
key.next_probe_at = None
if not key.is_active:
key.is_active = True
db.commit()
admin_name = context.user.username if context.user else "admin"
logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")
return {
"message": "Key已完全恢复",
"details": {
"health_score": 1.0,
"circuit_breaker_open": False,
"is_active": True,
},
}
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
"""批量恢复所有熔断 Key 的健康状态"""
async def handle(self, context): # type: ignore[override]
db = context.db
# 查找所有熔断的 Key
circuit_open_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open == True).all()
)
if not circuit_open_keys:
return {
"message": "没有需要恢复的 Key",
"recovered_count": 0,
"recovered_keys": [],
}
recovered_keys = []
for key in circuit_open_keys:
key.health_score = 1.0
key.consecutive_failures = 0
key.last_failure_at = None
key.circuit_breaker_open = False
key.circuit_breaker_open_at = None
key.next_probe_at = None
recovered_keys.append(
{
"key_id": key.id,
"key_name": key.name,
"endpoint_id": key.endpoint_id,
}
)
db.commit()
# 重置健康监控器的计数
from src.services.health.monitor import HealthMonitor, health_open_circuits
HealthMonitor._open_circuit_keys = 0
health_open_circuits.set(0)
admin_name = context.user.username if context.user else "admin"
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
return {
"message": f"已恢复 {len(recovered_keys)} 个 Key",
"recovered_count": len(recovered_keys),
"recovered_keys": recovered_keys,
}

View File

@@ -0,0 +1,425 @@
"""
Endpoint API Keys 管理
"""
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.crypto import crypto_service
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.key_capabilities import get_capability
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate,
EndpointAPIKeyResponse,
EndpointAPIKeyUpdate,
)
router = APIRouter(tags=["Endpoint Keys"])
pipeline = ApiRequestPipeline()
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
async def list_endpoint_keys(
endpoint_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
"""获取 Endpoint 的所有 Keys"""
adapter = AdminListEndpointKeysAdapter(
endpoint_id=endpoint_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_endpoint_key(
endpoint_id: str,
key_data: EndpointAPIKeyCreate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""为 Endpoint 添加 Key"""
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
async def update_endpoint_key(
key_id: str,
key_data: EndpointAPIKeyUpdate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""更新 Endpoint Key"""
adapter = AdminUpdateEndpointKeyAdapter(key_id=key_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/keys/grouped-by-format")
async def get_keys_grouped_by_format(
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""获取按 API 格式分组的所有 Keys用于全局优先级管理"""
adapter = AdminGetKeysGroupedByFormatAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/keys/{key_id}")
async def delete_endpoint_key(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""删除 Endpoint Key"""
adapter = AdminDeleteEndpointKeyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{endpoint_id}/keys/batch-priority")
async def batch_update_key_priority(
endpoint_id: str,
request: Request,
priority_data: BatchUpdateKeyPriorityRequest,
db: Session = Depends(get_db),
) -> dict:
"""批量更新 Endpoint 下 Keys 的优先级(用于拖动排序)"""
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminListEndpointKeysAdapter(AdminApiAdapter):
endpoint_id: str
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.offset(self.skip)
.limit(self.limit)
.all()
)
result: List[EndpointAPIKeyResponse] = []
for key in keys:
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.max_concurrent is None
key_dict = key.__dict__.copy()
key_dict.pop("_sa_instance_state", None)
key_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": None,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
key.learned_max_concurrent if is_adaptive else key.max_concurrent
),
}
)
result.append(EndpointAPIKeyResponse(**key_dict))
return result
@dataclass
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
endpoint_id: str
key_data: EndpointAPIKeyCreate
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
if self.key_data.endpoint_id != self.endpoint_id:
raise InvalidRequestException("endpoint_id 不匹配")
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
now = datetime.now(timezone.utc)
# max_concurrent=NULL 表示自适应模式,数字表示固定限制
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=self.endpoint_id,
api_key=encrypted_key,
name=self.key_data.name,
note=self.key_data.note,
rate_multiplier=self.key_data.rate_multiplier,
internal_priority=self.key_data.internal_priority,
max_concurrent=self.key_data.max_concurrent, # NULL=自适应模式
rate_limit=self.key_data.rate_limit,
daily_limit=self.key_data.daily_limit,
monthly_limit=self.key_data.monthly_limit,
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
request_count=0,
success_count=0,
error_count=0,
total_response_time_ms=0,
is_active=True,
last_used_at=None,
created_at=now,
updated_at=now,
)
db.add(new_key)
db.commit()
db.refresh(new_key)
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
is_adaptive = new_key.max_concurrent is None
response_dict = new_key.__dict__.copy()
response_dict.pop("_sa_instance_state", None)
response_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": self.key_data.api_key,
"success_rate": 0.0,
"avg_response_time_ms": 0.0,
"is_adaptive": is_adaptive,
"effective_limit": (
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
),
}
)
return EndpointAPIKeyResponse(**response_dict)
@dataclass
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
key_id: str
key_data: EndpointAPIKeyUpdate
async def handle(self, context): # type: ignore[override]
db = context.db
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
update_data = self.key_data.model_dump(exclude_unset=True)
if "api_key" in update_data:
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
for field, value in update_data.items():
setattr(key, field, value)
key.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(key)
logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.max_concurrent is None
response_dict = key.__dict__.copy()
response_dict.pop("_sa_instance_state", None)
response_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": None,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
key.learned_max_concurrent if is_adaptive else key.max_concurrent
),
}
)
return EndpointAPIKeyResponse(**response_dict)
@dataclass
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
endpoint_id = key.endpoint_id
try:
db.delete(key)
db.commit()
except Exception as exc:
db.rollback()
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
raise
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
return {"message": f"Key {self.key_id} 已删除"}
class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
keys = (
db.query(ProviderAPIKey, ProviderEndpoint, Provider)
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderAPIKey.is_active.is_(True),
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
)
.order_by(
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
)
.all()
)
grouped: Dict[str, List[dict]] = {}
for key, endpoint, provider in keys:
api_format = endpoint.api_format
if api_format not in grouped:
grouped[api_format] = []
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
# 计算健康度指标
success_rate = key.success_count / key.request_count if key.request_count > 0 else None
avg_response_time_ms = (
round(key.total_response_time_ms / key.success_count, 2)
if key.success_count > 0
else None
)
# 将 capabilities dict 转换为启用的能力简短名称列表
caps_list = []
if key.capabilities:
for cap_name, enabled in key.capabilities.items():
if enabled:
cap_def = get_capability(cap_name)
caps_list.append(cap_def.short_name if cap_def else cap_name)
grouped[api_format].append(
{
"id": key.id,
"name": key.name,
"api_key_masked": masked_key,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"rate_multiplier": key.rate_multiplier,
"is_active": key.is_active,
"circuit_breaker_open": key.circuit_breaker_open,
"provider_name": provider.display_name or provider.name,
"endpoint_base_url": endpoint.base_url,
"api_format": api_format,
"capabilities": caps_list,
"success_rate": success_rate,
"avg_response_time_ms": avg_response_time_ms,
"request_count": key.request_count,
}
)
# 直接返回分组对象,供前端使用
return grouped
@dataclass
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
endpoint_id: str
priority_data: BatchUpdateKeyPriorityRequest
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
# 获取所有需要更新的 Key ID
key_ids = [item.key_id for item in self.priority_data.priorities]
# 验证所有 Key 都属于该 Endpoint
keys = (
db.query(ProviderAPIKey)
.filter(
ProviderAPIKey.id.in_(key_ids),
ProviderAPIKey.endpoint_id == self.endpoint_id,
)
.all()
)
if len(keys) != len(key_ids):
found_ids = {k.id for k in keys}
missing_ids = set(key_ids) - found_ids
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
# 批量更新优先级
key_map = {k.id: k for k in keys}
updated_count = 0
for item in self.priority_data.priorities:
key = key_map.get(item.key_id)
if key and key.internal_priority != item.internal_priority:
key.internal_priority = item.internal_priority
key.updated_at = datetime.now(timezone.utc)
updated_count += 1
db.commit()
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}

View File

@@ -0,0 +1,345 @@
"""
ProviderEndpoint CRUD 管理 API
"""
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
ProviderEndpointCreate,
ProviderEndpointResponse,
ProviderEndpointUpdate,
)
router = APIRouter(tags=["Endpoint Management"])
pipeline = ApiRequestPipeline()
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
async def list_provider_endpoints(
provider_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[ProviderEndpointResponse]:
"""获取指定 Provider 的所有 Endpoints"""
adapter = AdminListProviderEndpointsAdapter(
provider_id=provider_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/providers/{provider_id}/endpoints", response_model=ProviderEndpointResponse)
async def create_provider_endpoint(
provider_id: str,
endpoint_data: ProviderEndpointCreate,
request: Request,
db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
"""为 Provider 创建新的 Endpoint"""
adapter = AdminCreateProviderEndpointAdapter(
provider_id=provider_id,
endpoint_data=endpoint_data,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def get_endpoint(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
"""获取 Endpoint 详情"""
adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def update_endpoint(
endpoint_id: str,
endpoint_data: ProviderEndpointUpdate,
request: Request,
db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
"""更新 Endpoint"""
adapter = AdminUpdateProviderEndpointAdapter(
endpoint_id=endpoint_id,
endpoint_data=endpoint_data,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{endpoint_id}")
async def delete_endpoint(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""删除 Endpoint级联删除所有关联的 Keys"""
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminListProviderEndpointsAdapter(AdminApiAdapter):
provider_id: str
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException(f"Provider {self.provider_id} 不存在")
endpoints = (
db.query(ProviderEndpoint)
.filter(ProviderEndpoint.provider_id == self.provider_id)
.order_by(ProviderEndpoint.created_at.desc())
.offset(self.skip)
.limit(self.limit)
.all()
)
endpoint_ids = [ep.id for ep in endpoints]
total_keys_map = {}
active_keys_map = {}
if endpoint_ids:
total_rows = (
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
.group_by(ProviderAPIKey.endpoint_id)
.all()
)
total_keys_map = {row.endpoint_id: row.total for row in total_rows}
active_rows = (
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
.filter(
and_(
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
ProviderAPIKey.is_active.is_(True),
)
)
.group_by(ProviderAPIKey.endpoint_id)
.all()
)
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
result: List[ProviderEndpointResponse] = []
for endpoint in endpoints:
endpoint_dict = {
**endpoint.__dict__,
"provider_name": provider.name,
"api_format": endpoint.api_format,
"total_keys": total_keys_map.get(endpoint.id, 0),
"active_keys": active_keys_map.get(endpoint.id, 0),
}
endpoint_dict.pop("_sa_instance_state", None)
result.append(ProviderEndpointResponse(**endpoint_dict))
return result
@dataclass
class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
provider_id: str
endpoint_data: ProviderEndpointCreate
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException(f"Provider {self.provider_id} 不存在")
if self.endpoint_data.provider_id != self.provider_id:
raise InvalidRequestException("provider_id 不匹配")
existing = (
db.query(ProviderEndpoint)
.filter(
and_(
ProviderEndpoint.provider_id == self.provider_id,
ProviderEndpoint.api_format == self.endpoint_data.api_format,
)
)
.first()
)
if existing:
raise InvalidRequestException(
f"Provider {provider.name} 已存在 {self.endpoint_data.api_format} 格式的 Endpoint"
)
now = datetime.now(timezone.utc)
new_endpoint = ProviderEndpoint(
id=str(uuid.uuid4()),
provider_id=self.provider_id,
api_format=self.endpoint_data.api_format,
base_url=self.endpoint_data.base_url,
headers=self.endpoint_data.headers,
timeout=self.endpoint_data.timeout,
max_retries=self.endpoint_data.max_retries,
max_concurrent=self.endpoint_data.max_concurrent,
rate_limit=self.endpoint_data.rate_limit,
is_active=True,
config=self.endpoint_data.config,
created_at=now,
updated_at=now,
)
db.add(new_endpoint)
db.commit()
db.refresh(new_endpoint)
logger.info(f"[OK] 创建 Endpoint: Provider={provider.name}, Format={self.endpoint_data.api_format}, ID={new_endpoint.id}")
endpoint_dict = {
k: v
for k, v in new_endpoint.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name,
api_format=new_endpoint.api_format,
total_keys=0,
active_keys=0,
)
@dataclass
class AdminGetProviderEndpointAdapter(AdminApiAdapter):
endpoint_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint, Provider)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(ProviderEndpoint.id == self.endpoint_id)
.first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
endpoint_obj, provider = endpoint
total_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
)
active_keys = (
db.query(ProviderAPIKey)
.filter(
and_(
ProviderAPIKey.endpoint_id == self.endpoint_id,
ProviderAPIKey.is_active.is_(True),
)
)
.count()
)
endpoint_dict = {
k: v
for k, v in endpoint_obj.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name,
api_format=endpoint_obj.api_format,
total_keys=total_keys,
active_keys=active_keys,
)
@dataclass
class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
endpoint_id: str
endpoint_data: ProviderEndpointUpdate
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
update_data = self.endpoint_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(endpoint, field, value)
endpoint.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(endpoint)
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
total_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
)
active_keys = (
db.query(ProviderAPIKey)
.filter(
and_(
ProviderAPIKey.endpoint_id == self.endpoint_id,
ProviderAPIKey.is_active.is_(True),
)
)
.count()
)
endpoint_dict = {
k: v
for k, v in endpoint.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name if provider else "Unknown",
api_format=endpoint.api_format,
total_keys=total_keys,
active_keys=active_keys,
)
@dataclass
class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
endpoint_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
keys_count = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
)
db.delete(endpoint)
db.commit()
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}

View File

@@ -0,0 +1,16 @@
"""
模型管理相关 Admin API
"""
from fastapi import APIRouter
from .catalog import router as catalog_router
from .global_models import router as global_models_router
from .mappings import router as mappings_router
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
# 挂载子路由
router.include_router(catalog_router)
router.include_router(global_models_router)
router.include_router(mappings_router)

View File

@@ -0,0 +1,432 @@
"""
统一模型目录 Admin API
阶段一:基于 ModelMapping 和 Model 的聚合视图
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import func, or_
from sqlalchemy.orm import Session, joinedload
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.database import GlobalModel, Model, ModelMapping, Provider
from src.models.pydantic_models import (
BatchAssignError,
BatchAssignModelMappingRequest,
BatchAssignModelMappingResponse,
BatchAssignProviderResult,
DeleteModelMappingResponse,
ModelCapabilities,
ModelCatalogItem,
ModelCatalogProviderDetail,
ModelCatalogResponse,
ModelPriceRange,
OrphanedModel,
UpdateModelMappingRequest,
UpdateModelMappingResponse,
)
from src.services.cache.invalidation import get_cache_invalidation_service
from src.services.model.service import ModelService
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
pipeline = ApiRequestPipeline()
@router.get("", response_model=ModelCatalogResponse)
async def get_model_catalog(
request: Request,
db: Session = Depends(get_db),
) -> ModelCatalogResponse:
adapter = AdminGetModelCatalogAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
async def batch_assign_model_mappings(
request: Request,
payload: BatchAssignModelMappingRequest,
db: Session = Depends(get_db),
) -> BatchAssignModelMappingResponse:
adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@dataclass
class AdminGetModelCatalogAdapter(AdminApiAdapter):
"""管理员查询统一模型目录
新架构说明:
1. 以 GlobalModel 为中心聚合数据
2. ModelMapping 表提供别名信息provider_id=NULL 表示全局)
3. Model 表提供关联提供商和价格
"""
async def handle(self, context): # type: ignore[override]
db: Session = context.db
# 1. 获取所有活跃的 GlobalModel
global_models: List[GlobalModel] = (
db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
)
# 2. 获取所有活跃的别名(含全局和 Provider 特定)
aliases_rows: List[ModelMapping] = (
db.query(ModelMapping)
.options(joinedload(ModelMapping.target_global_model))
.filter(
ModelMapping.is_active == True,
ModelMapping.provider_id.is_(None),
)
.all()
)
# 按 GlobalModel ID 组织别名
aliases_by_global_model: Dict[str, List[str]] = {}
for alias_row in aliases_rows:
if not alias_row.target_global_model_id:
continue
gm_id = alias_row.target_global_model_id
if gm_id not in aliases_by_global_model:
aliases_by_global_model[gm_id] = []
if alias_row.source_model not in aliases_by_global_model[gm_id]:
aliases_by_global_model[gm_id].append(alias_row.source_model)
# 3. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
models: List[Model] = (
db.query(Model)
.options(joinedload(Model.provider), joinedload(Model.global_model))
.filter(Model.is_active == True)
.all()
)
# 按 GlobalModel ID 组织关联提供商
models_by_global_model: Dict[str, List[Model]] = {}
for model in models:
if model.global_model_id:
models_by_global_model.setdefault(model.global_model_id, []).append(model)
# 4. 为每个 GlobalModel 构建 catalog item
catalog_items: List[ModelCatalogItem] = []
for gm in global_models:
gm_id = gm.id
provider_entries: List[ModelCatalogProviderDetail] = []
capability_flags = {
"supports_vision": gm.default_supports_vision or False,
"supports_function_calling": gm.default_supports_function_calling or False,
"supports_streaming": gm.default_supports_streaming or False,
}
# 遍历该 GlobalModel 的所有关联提供商
for model in models_by_global_model.get(gm_id, []):
provider = model.provider
if not provider:
continue
# 使用有效价格(考虑 GlobalModel 默认值)
effective_input = model.get_effective_input_price()
effective_output = model.get_effective_output_price()
effective_tiered = model.get_effective_tiered_pricing()
tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1
# 使用有效能力值
capability_flags["supports_vision"] = (
capability_flags["supports_vision"] or model.get_effective_supports_vision()
)
capability_flags["supports_function_calling"] = (
capability_flags["supports_function_calling"]
or model.get_effective_supports_function_calling()
)
capability_flags["supports_streaming"] = (
capability_flags["supports_streaming"]
or model.get_effective_supports_streaming()
)
provider_entries.append(
ModelCatalogProviderDetail(
provider_id=provider.id,
provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id,
target_model=model.provider_model_name,
# 显示有效价格
input_price_per_1m=effective_input,
output_price_per_1m=effective_output,
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
cache_read_price_per_1m=model.get_effective_cache_read_price(),
cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
price_per_request=model.get_effective_price_per_request(),
effective_tiered_pricing=effective_tiered,
tier_count=tier_count,
supports_vision=model.get_effective_supports_vision(),
supports_function_calling=model.get_effective_supports_function_calling(),
supports_streaming=model.get_effective_supports_streaming(),
is_active=bool(model.is_active),
mapping_id=None, # 新架构中不再有 mapping_id
)
)
# 模型目录显示 GlobalModel 的第一个阶梯价格(不是 Provider 聚合价格)
tiered = gm.default_tiered_pricing or {}
first_tier = tiered.get("tiers", [{}])[0] if tiered.get("tiers") else {}
price_range = ModelPriceRange(
min_input=first_tier.get("input_price_per_1m", 0),
max_input=first_tier.get("input_price_per_1m", 0),
min_output=first_tier.get("output_price_per_1m", 0),
max_output=first_tier.get("output_price_per_1m", 0),
)
catalog_items.append(
ModelCatalogItem(
global_model_name=gm.name,
display_name=gm.display_name,
description=gm.description,
aliases=aliases_by_global_model.get(gm_id, []),
providers=provider_entries,
price_range=price_range,
total_providers=len(provider_entries),
capabilities=ModelCapabilities(**capability_flags),
)
)
# 5. 查找孤立的别名(别名指向的 GlobalModel 不存在或不活跃)
orphaned_rows = (
db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
.outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
.filter(
ModelMapping.is_active == True,
ModelMapping.provider_id.is_(None),
or_(GlobalModel.id == None, GlobalModel.is_active == False),
)
.group_by(ModelMapping.source_model, GlobalModel.name)
.all()
)
orphaned_models = [
OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
for row in orphaned_rows
if row[0]
]
return ModelCatalogResponse(
models=catalog_items,
total=len(catalog_items),
orphaned_models=orphaned_models,
)
@dataclass
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
payload: BatchAssignModelMappingRequest
async def handle(self, context): # type: ignore[override]
db: Session = context.db
created: List[BatchAssignProviderResult] = []
errors: List[BatchAssignError] = []
for provider_config in self.payload.providers:
provider_id = provider_config.provider_id
try:
provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
errors.append(
BatchAssignError(provider_id=provider_id, error="Provider 不存在")
)
continue
model_id: Optional[str] = None
created_model = False
if provider_config.create_model:
model_data = provider_config.model_data
if not model_data:
errors.append(
BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
)
continue
existing_model = ModelService.get_model_by_name(
db, provider_id, model_data.provider_model_name
)
if existing_model:
model_id = existing_model.id
logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
model_data.provider_model_name,
provider.name,
)
else:
model = ModelService.create_model(db, provider_id, model_data)
model_id = model.id
created_model = True
else:
model_id = provider_config.model_id
if not model_id:
errors.append(
BatchAssignError(provider_id=provider_id, error="缺少 model_id")
)
continue
model = (
db.query(Model)
.filter(Model.id == model_id, Model.provider_id == provider_id)
.first()
)
if not model:
errors.append(
BatchAssignError(
provider_id=provider_id, error="模型不存在或不属于当前 Provider")
)
continue
# 批量分配功能需要适配 GlobalModel 架构
# 参见 docs/optimization-backlog.md 中的待办项
errors.append(
BatchAssignError(
provider_id=provider_id,
error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
)
)
continue
except Exception as exc:
db.rollback()
logger.error("批量添加模型映射失败(需要适配新架构)")
errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))
return BatchAssignModelMappingResponse(
success=len(created) > 0,
created_mappings=created,
errors=errors,
)
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
async def update_model_mapping(
request: Request,
mapping_id: str,
payload: UpdateModelMappingRequest,
db: Session = Depends(get_db),
) -> UpdateModelMappingResponse:
"""更新模型映射"""
adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
async def delete_model_mapping(
request: Request,
mapping_id: str,
db: Session = Depends(get_db),
) -> DeleteModelMappingResponse:
"""删除模型映射"""
adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@dataclass
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
"""更新模型映射"""
mapping_id: str
payload: UpdateModelMappingRequest
async def handle(self, context): # type: ignore[override]
db: Session = context.db
mapping: Optional[ModelMapping] = (
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
)
if not mapping:
raise HTTPException(status_code=404, detail="映射不存在")
update_data = self.payload.model_dump(exclude_unset=True)
if "provider_id" in update_data:
new_provider_id = update_data["provider_id"]
if new_provider_id:
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider 不存在")
mapping.provider_id = new_provider_id
if "target_global_model_id" in update_data:
target_model = (
db.query(GlobalModel)
.filter(
GlobalModel.id == update_data["target_global_model_id"],
GlobalModel.is_active == True,
)
.first()
)
if not target_model:
raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
mapping.target_global_model_id = update_data["target_global_model_id"]
if "source_model" in update_data:
new_source = update_data["source_model"].strip()
if not new_source:
raise HTTPException(status_code=400, detail="source_model 不能为空")
mapping.source_model = new_source
if "is_active" in update_data:
mapping.is_active = update_data["is_active"]
duplicate = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == mapping.source_model,
ModelMapping.provider_id == mapping.provider_id,
ModelMapping.id != mapping.id,
)
.first()
)
if duplicate:
raise HTTPException(status_code=400, detail="映射已存在")
db.commit()
db.refresh(mapping)
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
return UpdateModelMappingResponse(
success=True,
mapping_id=mapping.id,
message="映射更新成功",
)
@dataclass
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
"""删除模型映射"""
mapping_id: str
async def handle(self, context): # type: ignore[override]
db: Session = context.db
mapping: Optional[ModelMapping] = (
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
)
if not mapping:
raise HTTPException(status_code=404, detail="映射不存在")
source_model = mapping.source_model
provider_id = mapping.provider_id
db.delete(mapping)
db.commit()
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(source_model, provider_id)
return DeleteModelMappingResponse(
success=True,
message=f"映射 {self.mapping_id} 已删除",
)

View File

@@ -0,0 +1,292 @@
"""
GlobalModel Admin API
提供 GlobalModel 的 CRUD 操作接口
"""
from dataclasses import dataclass
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.pydantic_models import (
BatchAssignToProvidersRequest,
BatchAssignToProvidersResponse,
GlobalModelCreate,
GlobalModelListResponse,
GlobalModelResponse,
GlobalModelUpdate,
GlobalModelWithStats,
)
from src.services.model.global_model import GlobalModelService
router = APIRouter(prefix="/global", tags=["Admin - Global Models"])
pipeline = ApiRequestPipeline()
@router.get("", response_model=GlobalModelListResponse)
async def list_global_models(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
is_active: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
db: Session = Depends(get_db),
) -> GlobalModelListResponse:
"""获取 GlobalModel 列表"""
adapter = AdminListGlobalModelsAdapter(
skip=skip,
limit=limit,
is_active=is_active,
search=search,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{global_model_id}", response_model=GlobalModelWithStats)
async def get_global_model(
request: Request,
global_model_id: str,
db: Session = Depends(get_db),
) -> GlobalModelWithStats:
"""获取单个 GlobalModel 详情(含统计信息)"""
adapter = AdminGetGlobalModelAdapter(global_model_id=global_model_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("", response_model=GlobalModelResponse, status_code=201)
async def create_global_model(
request: Request,
payload: GlobalModelCreate,
db: Session = Depends(get_db),
) -> GlobalModelResponse:
"""创建 GlobalModel"""
adapter = AdminCreateGlobalModelAdapter(payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{global_model_id}", response_model=GlobalModelResponse)
async def update_global_model(
request: Request,
global_model_id: str,
payload: GlobalModelUpdate,
db: Session = Depends(get_db),
) -> GlobalModelResponse:
"""更新 GlobalModel"""
adapter = AdminUpdateGlobalModelAdapter(global_model_id=global_model_id, payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{global_model_id}", status_code=204)
async def delete_global_model(
request: Request,
global_model_id: str,
db: Session = Depends(get_db),
):
"""删除 GlobalModel级联删除所有关联的 Provider 模型实现)"""
adapter = AdminDeleteGlobalModelAdapter(global_model_id=global_model_id)
await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
return None
@router.post(
"/{global_model_id}/assign-to-providers", response_model=BatchAssignToProvidersResponse
)
async def batch_assign_to_providers(
request: Request,
global_model_id: str,
payload: BatchAssignToProvidersRequest,
db: Session = Depends(get_db),
) -> BatchAssignToProvidersResponse:
"""批量为多个 Provider 添加 GlobalModel 实现"""
adapter = AdminBatchAssignToProvidersAdapter(global_model_id=global_model_id, payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ========== Adapters ==========
@dataclass
class AdminListGlobalModelsAdapter(AdminApiAdapter):
"""列出 GlobalModel"""
skip: int
limit: int
is_active: Optional[bool]
search: Optional[str]
async def handle(self, context): # type: ignore[override]
from sqlalchemy import func
from src.models.database import Model, ModelMapping
models = GlobalModelService.list_global_models(
db=context.db,
skip=self.skip,
limit=self.limit,
is_active=self.is_active,
search=self.search,
)
# 为每个 GlobalModel 添加统计数据
model_responses = []
for gm in models:
# 统计关联的 Model 数量(去重 Provider
provider_count = (
context.db.query(func.count(func.distinct(Model.provider_id)))
.filter(Model.global_model_id == gm.id)
.scalar()
or 0
)
# 统计别名数量
alias_count = (
context.db.query(func.count(ModelMapping.id))
.filter(ModelMapping.target_global_model_id == gm.id)
.scalar()
or 0
)
response = GlobalModelResponse.model_validate(gm)
response.provider_count = provider_count
response.alias_count = alias_count
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
model_responses.append(response)
return GlobalModelListResponse(
models=model_responses,
total=len(models),
)
@dataclass
class AdminGetGlobalModelAdapter(AdminApiAdapter):
"""获取单个 GlobalModel"""
global_model_id: str
async def handle(self, context): # type: ignore[override]
global_model = GlobalModelService.get_global_model(context.db, self.global_model_id)
stats = GlobalModelService.get_global_model_stats(context.db, self.global_model_id)
return GlobalModelWithStats(
**GlobalModelResponse.model_validate(global_model).model_dump(),
total_models=stats["total_models"],
total_providers=stats["total_providers"],
price_range=stats["price_range"],
)
@dataclass
class AdminCreateGlobalModelAdapter(AdminApiAdapter):
"""创建 GlobalModel"""
payload: GlobalModelCreate
async def handle(self, context): # type: ignore[override]
# 将 TieredPricingConfig 转换为 dict
tiered_pricing_dict = self.payload.default_tiered_pricing.model_dump()
global_model = GlobalModelService.create_global_model(
db=context.db,
name=self.payload.name,
display_name=self.payload.display_name,
description=self.payload.description,
official_url=self.payload.official_url,
icon_url=self.payload.icon_url,
is_active=self.payload.is_active,
# 按次计费配置
default_price_per_request=self.payload.default_price_per_request,
# 阶梯计费配置
default_tiered_pricing=tiered_pricing_dict,
# 默认能力配置
default_supports_vision=self.payload.default_supports_vision,
default_supports_function_calling=self.payload.default_supports_function_calling,
default_supports_streaming=self.payload.default_supports_streaming,
default_supports_extended_thinking=self.payload.default_supports_extended_thinking,
# Key 能力配置
supported_capabilities=self.payload.supported_capabilities,
)
logger.info(f"GlobalModel 已创建: id={global_model.id} name={global_model.name}")
return GlobalModelResponse.model_validate(global_model)
@dataclass
class AdminUpdateGlobalModelAdapter(AdminApiAdapter):
"""更新 GlobalModel"""
global_model_id: str
payload: GlobalModelUpdate
async def handle(self, context): # type: ignore[override]
global_model = GlobalModelService.update_global_model(
db=context.db,
global_model_id=self.global_model_id,
update_data=self.payload,
)
logger.info(f"GlobalModel 已更新: id={global_model.id} name={global_model.name}")
# 失效相关缓存
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.on_global_model_changed(global_model.name)
return GlobalModelResponse.model_validate(global_model)
@dataclass
class AdminDeleteGlobalModelAdapter(AdminApiAdapter):
"""删除 GlobalModel级联删除所有关联的 Provider 模型实现)"""
global_model_id: str
async def handle(self, context): # type: ignore[override]
# 先获取 GlobalModel 信息(用于失效缓存)
from src.models.database import GlobalModel
global_model = (
context.db.query(GlobalModel).filter(GlobalModel.id == self.global_model_id).first()
)
model_name = global_model.name if global_model else None
GlobalModelService.delete_global_model(context.db, self.global_model_id)
logger.info(f"GlobalModel 已删除: id={self.global_model_id}")
# 失效相关缓存
if model_name:
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.on_global_model_changed(model_name)
return None
@dataclass
class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
"""批量为 Provider 添加 GlobalModel 实现"""
global_model_id: str
payload: BatchAssignToProvidersRequest
async def handle(self, context): # type: ignore[override]
result = GlobalModelService.batch_assign_to_providers(
db=context.db,
global_model_id=self.global_model_id,
provider_ids=self.payload.provider_ids,
create_models=self.payload.create_models,
)
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
return BatchAssignToProvidersResponse(**result)

View File

@@ -0,0 +1,303 @@
"""模型映射管理 API
提供模型映射的 CRUD 操作。
模型映射Mapping用于将源模型映射到目标模型例如
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
- 用于处理 Provider 不支持请求模型的情况
映射必须关联到特定的 Provider。
"""
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session, joinedload
from src.core.logger import logger
from src.database import get_db
from src.models.api import (
ModelMappingCreate,
ModelMappingResponse,
ModelMappingUpdate,
)
from src.models.database import GlobalModel, ModelMapping, Provider, User
from src.services.cache.invalidation import get_cache_invalidation_service
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
target = mapping.target_global_model
provider = mapping.provider
scope = "provider" if mapping.provider_id else "global"
return ModelMappingResponse(
id=mapping.id,
source_model=mapping.source_model,
target_global_model_id=mapping.target_global_model_id,
target_global_model_name=target.name if target else None,
target_global_model_display_name=target.display_name if target else None,
provider_id=mapping.provider_id,
provider_name=provider.name if provider else None,
scope=scope,
mapping_type=mapping.mapping_type,
is_active=mapping.is_active,
created_at=mapping.created_at,
updated_at=mapping.updated_at,
)
@router.get("", response_model=List[ModelMappingResponse])
async def list_mappings(
provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
source_model: Optional[str] = Query(None, description="按源模型名筛选"),
target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
scope: Optional[str] = Query(None, description="global 或 provider"),
mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
is_active: Optional[bool] = Query(None, description="按状态筛选"),
skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
db: Session = Depends(get_db),
):
"""获取模型映射列表"""
query = db.query(ModelMapping).options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
if provider_id is not None:
query = query.filter(ModelMapping.provider_id == provider_id)
if scope == "global":
query = query.filter(ModelMapping.provider_id.is_(None))
elif scope == "provider":
query = query.filter(ModelMapping.provider_id.isnot(None))
if mapping_type is not None:
query = query.filter(ModelMapping.mapping_type == mapping_type)
if source_model:
query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%"))
if target_global_model_id is not None:
query = query.filter(ModelMapping.target_global_model_id == target_global_model_id)
if is_active is not None:
query = query.filter(ModelMapping.is_active == is_active)
mappings = query.offset(skip).limit(limit).all()
return [_serialize_mapping(mapping) for mapping in mappings]
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
async def get_mapping(
mapping_id: str,
db: Session = Depends(get_db),
):
"""获取单个模型映射"""
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping_id)
.first()
)
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
return _serialize_mapping(mapping)
@router.post("", response_model=ModelMappingResponse, status_code=201)
async def create_mapping(
data: ModelMappingCreate,
db: Session = Depends(get_db),
):
"""创建模型映射"""
source_model = data.source_model.strip()
if not source_model:
raise HTTPException(status_code=400, detail="source_model 不能为空")
# 验证 mapping_type
if data.mapping_type not in ("alias", "mapping"):
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias''mapping'")
# 验证目标 GlobalModel 存在
target_model = (
db.query(GlobalModel)
.filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True)
.first()
)
if not target_model:
raise HTTPException(
status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
)
# 验证 Provider 存在
provider = None
provider_id = data.provider_id
if provider_id:
provider = db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")
# 检查映射是否已存在(全局或同一 Provider 下不可重复)
existing = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id == provider_id,
)
.first()
)
if existing:
raise HTTPException(status_code=400, detail="映射已存在")
# 创建映射
mapping = ModelMapping(
id=str(uuid.uuid4()),
source_model=source_model,
target_global_model_id=data.target_global_model_id,
provider_id=provider_id,
mapping_type=data.mapping_type,
is_active=data.is_active,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db.add(mapping)
db.commit()
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping.id)
.first()
)
logger.info(f"创建模型映射: {source_model} -> {target_model.name} "
f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})")
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(source_model, provider_id)
return _serialize_mapping(mapping)
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
async def update_mapping(
mapping_id: str,
data: ModelMappingUpdate,
db: Session = Depends(get_db),
):
"""更新模型映射"""
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
update_data = data.model_dump(exclude_unset=True)
# 更新 Provider
if "provider_id" in update_data:
new_provider_id = update_data["provider_id"]
if new_provider_id:
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")
mapping.provider_id = new_provider_id
# 更新目标模型
if "target_global_model_id" in update_data:
target_model = (
db.query(GlobalModel)
.filter(
GlobalModel.id == update_data["target_global_model_id"],
GlobalModel.is_active == True,
)
.first()
)
if not target_model:
raise HTTPException(
status_code=404,
detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
)
mapping.target_global_model_id = update_data["target_global_model_id"]
# 更新源模型名
if "source_model" in update_data:
new_source = update_data["source_model"].strip()
if not new_source:
raise HTTPException(status_code=400, detail="source_model 不能为空")
mapping.source_model = new_source
# 检查唯一约束
duplicate = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == mapping.source_model,
ModelMapping.provider_id == mapping.provider_id,
ModelMapping.id != mapping_id,
)
.first()
)
if duplicate:
raise HTTPException(status_code=400, detail="映射已存在")
# 更新映射类型
if "mapping_type" in update_data:
if update_data["mapping_type"] not in ("alias", "mapping"):
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias''mapping'")
mapping.mapping_type = update_data["mapping_type"]
# 更新状态
if "is_active" in update_data:
mapping.is_active = update_data["is_active"]
mapping.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(mapping)
logger.info(f"更新模型映射 (ID: {mapping.id})")
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping.id)
.first()
)
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
return _serialize_mapping(mapping)
@router.delete("/{mapping_id}", status_code=204)
async def delete_mapping(
mapping_id: str,
db: Session = Depends(get_db),
):
"""删除模型映射"""
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
source_model = mapping.source_model
provider_id = mapping.provider_id
logger.info(f"删除模型映射: {source_model} -> {mapping.target_global_model_id} (ID: {mapping.id})")
db.delete(mapping)
db.commit()
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(source_model, provider_id)
return None

View File

@@ -0,0 +1,14 @@
"""Admin monitoring router合集。"""
from fastapi import APIRouter
from .audit import router as audit_router
from .cache import router as cache_router
from .trace import router as trace_router
router = APIRouter()
router.include_router(audit_router)
router.include_router(cache_router)
router.include_router(trace_router)
__all__ = ["router"]

View File

@@ -0,0 +1,399 @@
"""管理员监控与审计端点。"""
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_query
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.database import (
ApiKey,
AuditEventType,
AuditLog,
Provider,
Usage,
)
from src.models.database import User as DBUser
from src.services.health.monitor import HealthMonitor
from src.services.system.audit import audit_service
router = APIRouter(prefix="/api/admin/monitoring", tags=["Admin - Monitoring"])
pipeline = ApiRequestPipeline()
@router.get("/audit-logs")
async def get_audit_logs(
request: Request,
user_id: Optional[str] = Query(None, description="用户ID筛选 (支持UUID)"),
event_type: Optional[str] = Query(None, description="事件类型筛选"),
days: int = Query(7, description="查询天数"),
limit: int = Query(100, description="返回数量限制"),
offset: int = Query(0, description="偏移量"),
db: Session = Depends(get_db),
):
adapter = AdminGetAuditLogsAdapter(
user_id=user_id,
event_type=event_type,
days=days,
limit=limit,
offset=offset,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/system-status")
async def get_system_status(request: Request, db: Session = Depends(get_db)):
adapter = AdminSystemStatusAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/suspicious-activities")
async def get_suspicious_activities(
request: Request,
hours: int = Query(24, description="时间范围(小时)"),
db: Session = Depends(get_db),
):
adapter = AdminSuspiciousActivitiesAdapter(hours=hours)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/user-behavior/{user_id}")
async def analyze_user_behavior(
user_id: str,
request: Request,
days: int = Query(30, description="分析天数"),
db: Session = Depends(get_db),
):
adapter = AdminUserBehaviorAdapter(user_id=user_id, days=days)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/resilience-status")
async def get_resilience_status(request: Request, db: Session = Depends(get_db)):
adapter = AdminResilienceStatusAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/resilience/error-stats")
async def reset_error_stats(request: Request, db: Session = Depends(get_db)):
"""Reset resilience error statistics"""
adapter = AdminResetErrorStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/resilience/circuit-history")
async def get_circuit_history(
request: Request,
limit: int = Query(50, ge=1, le=200),
db: Session = Depends(get_db),
):
adapter = AdminCircuitHistoryAdapter(limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@dataclass
class AdminGetAuditLogsAdapter(AdminApiAdapter):
user_id: Optional[str]
event_type: Optional[str]
days: int
limit: int
offset: int
async def handle(self, context): # type: ignore[override]
db = context.db
cutoff_time = datetime.now(timezone.utc) - timedelta(days=self.days)
base_query = (
db.query(AuditLog, DBUser)
.outerjoin(DBUser, AuditLog.user_id == DBUser.id)
.filter(AuditLog.created_at >= cutoff_time)
)
if self.user_id:
base_query = base_query.filter(AuditLog.user_id == self.user_id)
if self.event_type:
base_query = base_query.filter(AuditLog.event_type == self.event_type)
ordered_query = base_query.order_by(AuditLog.created_at.desc())
total, logs_with_users = paginate_query(ordered_query, self.limit, self.offset)
items = [
{
"id": log.id,
"event_type": log.event_type,
"user_id": log.user_id,
"user_email": user.email if user else None,
"user_username": user.username if user else None,
"description": log.description,
"ip_address": log.ip_address,
"status_code": log.status_code,
"error_message": log.error_message,
"metadata": log.event_metadata,
"created_at": log.created_at.isoformat() if log.created_at else None,
}
for log, user in logs_with_users
]
meta = PaginationMeta(
total=total,
limit=self.limit,
offset=self.offset,
count=len(items),
)
payload = build_pagination_payload(
items,
meta,
filters={
"user_id": self.user_id,
"event_type": self.event_type,
"days": self.days,
},
)
context.add_audit_metadata(
action="monitor_audit_logs",
filter_user_id=self.user_id,
filter_event_type=self.event_type,
days=self.days,
limit=self.limit,
offset=self.offset,
total=total,
result_count=meta.count,
)
return payload
class AdminSystemStatusAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
total_users = db.query(func.count(DBUser.id)).scalar()
active_users = db.query(func.count(DBUser.id)).filter(DBUser.is_active.is_(True)).scalar()
total_providers = db.query(func.count(Provider.id)).scalar()
active_providers = (
db.query(func.count(Provider.id)).filter(Provider.is_active.is_(True)).scalar()
)
total_api_keys = db.query(func.count(ApiKey.id)).scalar()
active_api_keys = (
db.query(func.count(ApiKey.id)).filter(ApiKey.is_active.is_(True)).scalar()
)
today_start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
today_requests = (
db.query(func.count(Usage.id)).filter(Usage.created_at >= today_start).scalar()
)
today_tokens = (
db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= today_start).scalar()
or 0
)
today_cost = (
db.query(func.sum(Usage.total_cost_usd))
.filter(Usage.created_at >= today_start)
.scalar()
or 0
)
recent_errors = (
db.query(AuditLog)
.filter(
AuditLog.event_type.in_(
[
AuditEventType.REQUEST_FAILED.value,
AuditEventType.SUSPICIOUS_ACTIVITY.value,
]
),
AuditLog.created_at >= datetime.now(timezone.utc) - timedelta(hours=1),
)
.count()
)
context.add_audit_metadata(
action="system_status_snapshot",
total_users=int(total_users or 0),
active_users=int(active_users or 0),
total_providers=int(total_providers or 0),
active_providers=int(active_providers or 0),
total_api_keys=int(total_api_keys or 0),
active_api_keys=int(active_api_keys or 0),
today_requests=int(today_requests or 0),
today_tokens=int(today_tokens or 0),
today_cost=float(today_cost or 0.0),
recent_errors=int(recent_errors or 0),
)
return {
"timestamp": datetime.now(timezone.utc).isoformat(),
"users": {"total": total_users, "active": active_users},
"providers": {"total": total_providers, "active": active_providers},
"api_keys": {"total": total_api_keys, "active": active_api_keys},
"today_stats": {
"requests": today_requests,
"tokens": today_tokens,
"cost_usd": f"${today_cost:.4f}",
},
"recent_errors": recent_errors,
}
@dataclass
class AdminSuspiciousActivitiesAdapter(AdminApiAdapter):
hours: int
async def handle(self, context): # type: ignore[override]
db = context.db
activities = audit_service.get_suspicious_activities(db=db, hours=self.hours, limit=100)
response = {
"activities": [
{
"id": activity.id,
"event_type": activity.event_type,
"user_id": activity.user_id,
"description": activity.description,
"ip_address": activity.ip_address,
"metadata": activity.event_metadata,
"created_at": activity.created_at.isoformat() if activity.created_at else None,
}
for activity in activities
],
"count": len(activities),
"time_range_hours": self.hours,
}
context.add_audit_metadata(
action="monitor_suspicious_activity",
hours=self.hours,
result_count=len(activities),
)
return response
@dataclass
class AdminUserBehaviorAdapter(AdminApiAdapter):
user_id: str
days: int
async def handle(self, context): # type: ignore[override]
result = audit_service.analyze_user_behavior(
db=context.db,
user_id=self.user_id,
days=self.days,
)
context.add_audit_metadata(
action="monitor_user_behavior",
target_user_id=self.user_id,
days=self.days,
contains_summary=bool(result),
)
return result
class AdminResilienceStatusAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
try:
from src.core.resilience import resilience_manager
except ImportError as exc:
raise HTTPException(status_code=503, detail="韧性管理系统未启用") from exc
error_stats = resilience_manager.get_error_stats()
recent_errors = [
{
"error_id": info["error_id"],
"error_type": info["error_type"],
"operation": info["operation"],
"timestamp": info["timestamp"].isoformat(),
"context": info.get("context", {}),
}
for info in resilience_manager.last_errors[-10:]
]
total_errors = error_stats.get("total_errors", 0)
circuit_breakers = error_stats.get("circuit_breakers", {})
circuit_breakers_open = sum(
1 for status in circuit_breakers.values() if status.get("state") == "open"
)
health_score = max(0, 100 - (total_errors * 2) - (circuit_breakers_open * 20))
response = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"health_score": health_score,
"status": (
"healthy" if health_score > 80 else "degraded" if health_score > 50 else "critical"
),
"error_statistics": error_stats,
"recent_errors": recent_errors,
"recommendations": _get_health_recommendations(error_stats, health_score),
}
context.add_audit_metadata(
action="resilience_status",
health_score=health_score,
error_total=error_stats.get("total_errors") if isinstance(error_stats, dict) else None,
open_circuit_breakers=circuit_breakers_open,
)
return response
class AdminResetErrorStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
try:
from src.core.resilience import resilience_manager
except ImportError as exc:
raise HTTPException(status_code=503, detail="韧性管理系统未启用") from exc
old_stats = resilience_manager.get_error_stats()
resilience_manager.error_stats.clear()
resilience_manager.last_errors.clear()
logger.info(f"管理员 {context.user.email if context.user else 'unknown'} 重置了错误统计")
context.add_audit_metadata(
action="reset_error_stats",
previous_total_errors=(
old_stats.get("total_errors") if isinstance(old_stats, dict) else None
),
)
return {
"message": "错误统计已重置",
"previous_stats": old_stats,
"reset_by": context.user.email if context.user else None,
"reset_at": datetime.now(timezone.utc).isoformat(),
}
class AdminCircuitHistoryAdapter(AdminApiAdapter):
def __init__(self, limit: int = 50):
super().__init__()
self.limit = limit
async def handle(self, context): # type: ignore[override]
history = HealthMonitor.get_circuit_history(self.limit)
context.add_audit_metadata(
action="circuit_history",
limit=self.limit,
result_count=len(history),
)
return {"items": history, "count": len(history)}
def _get_health_recommendations(error_stats: dict, health_score: int) -> List[str]:
recommendations: List[str] = []
if health_score < 50:
recommendations.append("系统健康状况严重,请立即检查错误日志")
if error_stats.get("total_errors", 0) > 100:
recommendations.append("错误频率过高,建议检查系统配置和外部依赖")
circuit_breakers = error_stats.get("circuit_breakers", {})
open_breakers = [k for k, v in circuit_breakers.items() if v.get("state") == "open"]
if open_breakers:
recommendations.append(f"以下服务熔断器已打开:{', '.join(open_breakers)}")
if health_score > 90:
recommendations.append("系统运行良好")
return recommendations

View File

@@ -0,0 +1,871 @@
"""
缓存监控端点
提供缓存亲和性统计、管理和监控功能
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import PlainTextResponse
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
from src.api.base.pipeline import ApiRequestPipeline
from src.clients.redis_client import get_redis_client_sync
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, User
from src.services.cache.affinity_manager import get_affinity_manager
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
pipeline = ApiRequestPipeline()
def mask_api_key(api_key: Optional[str], prefix_len: int = 8, suffix_len: int = 4) -> Optional[str]:
"""
脱敏 API Key显示前缀 + 星号 + 后缀
例如: sk-jhiId-xxxxxxxxxxxAABB -> sk-jhiId-********AABB
Args:
api_key: 原始 API Key
prefix_len: 显示的前缀长度,默认 8
suffix_len: 显示的后缀长度,默认 4
"""
if not api_key:
return None
total_visible = prefix_len + suffix_len
if len(api_key) <= total_visible:
# Key 太短,直接返回部分内容 + 星号
return api_key[:prefix_len] + "********"
return f"{api_key[:prefix_len]}********{api_key[-suffix_len:]}"
def decrypt_and_mask(encrypted_key: Optional[str], prefix_len: int = 8) -> Optional[str]:
"""
解密 API Key 后脱敏显示
Args:
encrypted_key: 加密后的 API Key
prefix_len: 显示的前缀长度
"""
if not encrypted_key:
return None
try:
decrypted = crypto_service.decrypt(encrypted_key)
return mask_api_key(decrypted, prefix_len)
except Exception:
# 解密失败时返回 None
return None
def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
"""
将用户标识符username/email/user_id/api_key_id解析为 user_id
支持的输入格式:
1. User UUID (36位带横杠)
2. Username (用户名)
3. Email (邮箱)
4. API Key ID (36位UUID)
返回:
- user_id (UUID字符串) 或 None
"""
identifier = identifier.strip()
# 1. 先尝试作为 User UUID 查询
user = db.query(User).filter(User.id == identifier).first()
if user:
logger.debug(f"通过User ID解析: {identifier[:8]}... -> {user.username}")
return user.id
# 2. 尝试作为 Username 查询
user = db.query(User).filter(User.username == identifier).first()
if user:
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...")
return user.id
# 3. 尝试作为 Email 查询
user = db.query(User).filter(User.email == identifier).first()
if user:
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...")
return user.id
# 4. 尝试作为 API Key ID 查询
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
if api_key:
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...")
return api_key.user_id
# 无法识别
logger.debug(f"无法识别的用户标识符: {identifier}")
return None
@router.get("/stats")
async def get_cache_stats(
request: Request,
db: Session = Depends(get_db),
):
"""
获取缓存亲和性统计信息
返回:
- 缓存命中率
- 缓存用户数
- Provider切换次数
- Key切换次数
- 缓存预留配置
"""
adapter = AdminCacheStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/affinity/{user_identifier}")
async def get_user_affinity(
user_identifier: str,
request: Request,
db: Session = Depends(get_db),
):
"""
查询指定用户的所有缓存亲和性
参数:
- user_identifier: 用户标识符,支持以下格式:
* 用户名 (username),如: yuanhonghu
* 邮箱 (email),如: user@example.com
* 用户UUID (user_id),如: 550e8400-e29b-41d4-a716-446655440000
* API Key ID如: 660e8400-e29b-41d4-a716-446655440000
返回:
- 用户信息
- 所有端点的缓存亲和性列表(每个端点一条记录)
"""
adapter = AdminGetUserAffinityAdapter(user_identifier=user_identifier)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/affinities")
async def list_affinities(
request: Request,
keyword: Optional[str] = None,
limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
offset: int = Query(0, ge=0, description="偏移量"),
db: Session = Depends(get_db),
):
"""
获取所有缓存亲和性列表,可选按关键词过滤
参数:
- keyword: 可选,支持用户名/邮箱/User ID/API Key ID 或模糊匹配
"""
adapter = AdminListAffinitiesAdapter(keyword=keyword, limit=limit, offset=offset)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/users/{user_identifier}")
async def clear_user_cache(
user_identifier: str,
request: Request,
db: Session = Depends(get_db),
):
"""
Clear cache affinity for a specific user
Parameters:
- user_identifier: User identifier (username, email, user_id, or API Key ID)
"""
adapter = AdminClearUserCacheAdapter(user_identifier=user_identifier)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("")
async def clear_all_cache(
request: Request,
db: Session = Depends(get_db),
):
"""
Clear all cache affinities
Warning: This affects all users, use with caution
"""
adapter = AdminClearAllCacheAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/providers/{provider_id}")
async def clear_provider_cache(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""
Clear cache affinities for a specific provider
Parameters:
- provider_id: Provider ID
"""
adapter = AdminClearProviderCacheAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/config")
async def get_cache_config(
request: Request,
db: Session = Depends(get_db),
):
"""
获取缓存相关配置
返回:
- 缓存TTL
- 缓存预留比例
"""
adapter = AdminCacheConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/metrics", response_class=PlainTextResponse)
async def get_cache_metrics(
request: Request,
db: Session = Depends(get_db),
):
"""
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
"""
adapter = AdminCacheMetricsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- 缓存监控适配器 --------
class AdminCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
stats = await scheduler.get_stats()
logger.info("缓存统计信息查询成功")
context.add_audit_metadata(
action="cache_stats",
scheduler=stats.get("scheduler"),
total_affinities=stats.get("total_affinities"),
cache_hit_rate=stats.get("cache_hit_rate"),
provider_switches=stats.get("provider_switches"),
)
return {"status": "ok", "data": stats}
except Exception as exc:
logger.exception(f"获取缓存统计信息失败: {exc}")
raise HTTPException(status_code=500, detail=f"获取缓存统计失败: {exc}")
class AdminCacheMetricsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
stats = await scheduler.get_stats()
payload = self._format_prometheus(stats)
context.add_audit_metadata(
action="cache_metrics_export",
scheduler=stats.get("scheduler"),
metrics_lines=payload.count("\n"),
)
return PlainTextResponse(payload)
except Exception as exc:
logger.exception(f"导出缓存指标失败: {exc}")
raise HTTPException(status_code=500, detail=f"导出缓存指标失败: {exc}")
def _format_prometheus(self, stats: Dict[str, Any]) -> str:
"""
将 scheduler/affinity 指标转换为 Prometheus 文本格式。
"""
scheduler_metrics = stats.get("scheduler_metrics", {})
affinity_stats = stats.get("affinity_stats", {})
metric_map: List[Tuple[str, str, float]] = [
(
"cache_scheduler_total_batches",
"Total batches pulled from provider list",
float(scheduler_metrics.get("total_batches", 0)),
),
(
"cache_scheduler_last_batch_size",
"Size of the latest candidate batch",
float(scheduler_metrics.get("last_batch_size", 0)),
),
(
"cache_scheduler_total_candidates",
"Total candidates enumerated by scheduler",
float(scheduler_metrics.get("total_candidates", 0)),
),
(
"cache_scheduler_last_candidate_count",
"Number of candidates in the most recent batch",
float(scheduler_metrics.get("last_candidate_count", 0)),
),
(
"cache_scheduler_cache_hits",
"Cache hits counted during scheduling",
float(scheduler_metrics.get("cache_hits", 0)),
),
(
"cache_scheduler_cache_misses",
"Cache misses counted during scheduling",
float(scheduler_metrics.get("cache_misses", 0)),
),
(
"cache_scheduler_cache_hit_rate",
"Cache hit rate during scheduling",
float(scheduler_metrics.get("cache_hit_rate", 0.0)),
),
(
"cache_scheduler_concurrency_denied",
"Times candidate rejected due to concurrency limits",
float(scheduler_metrics.get("concurrency_denied", 0)),
),
(
"cache_scheduler_avg_candidates_per_batch",
"Average candidates per batch",
float(scheduler_metrics.get("avg_candidates_per_batch", 0.0)),
),
]
affinity_map: List[Tuple[str, str, float]] = [
(
"cache_affinity_total",
"Total cache affinities stored",
float(affinity_stats.get("total_affinities", 0)),
),
(
"cache_affinity_hits",
"Affinity cache hits",
float(affinity_stats.get("cache_hits", 0)),
),
(
"cache_affinity_misses",
"Affinity cache misses",
float(affinity_stats.get("cache_misses", 0)),
),
(
"cache_affinity_hit_rate",
"Affinity cache hit rate",
float(affinity_stats.get("cache_hit_rate", 0.0)),
),
(
"cache_affinity_invalidations",
"Affinity invalidations",
float(affinity_stats.get("cache_invalidations", 0)),
),
(
"cache_affinity_provider_switches",
"Affinity provider switches",
float(affinity_stats.get("provider_switches", 0)),
),
(
"cache_affinity_key_switches",
"Affinity key switches",
float(affinity_stats.get("key_switches", 0)),
),
]
lines = []
for name, help_text, value in metric_map + affinity_map:
lines.append(f"# HELP {name} {help_text}")
lines.append(f"# TYPE {name} gauge")
lines.append(f"{name} {value}")
scheduler_name = stats.get("scheduler", "cache_aware")
lines.append(f'cache_scheduler_info{{scheduler="{scheduler_name}"}} 1')
return "\n".join(lines) + "\n"
@dataclass
class AdminGetUserAffinityAdapter(AdminApiAdapter):
user_identifier: str
async def handle(self, context): # type: ignore[override]
db = context.db
try:
user_id = resolve_user_identifier(db, self.user_identifier)
if not user_id:
raise HTTPException(
status_code=404,
detail=f"无法识别的用户标识符: {self.user_identifier}。支持用户名、邮箱、User ID或API Key ID",
)
user = db.query(User).filter(User.id == user_id).first()
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
# 获取该用户的所有缓存亲和性
all_affinities = await affinity_mgr.list_affinities()
user_affinities = [aff for aff in all_affinities if aff.get("user_id") == user_id]
if not user_affinities:
response = {
"status": "not_found",
"message": f"用户 {user.username} ({user.email}) 没有缓存亲和性",
"user_info": {
"user_id": user_id,
"username": user.username,
"email": user.email,
},
"affinities": [],
}
context.add_audit_metadata(
action="cache_user_affinity",
user_identifier=self.user_identifier,
resolved_user_id=user_id,
affinity_count=0,
status="not_found",
)
return response
response = {
"status": "ok",
"user_info": {
"user_id": user_id,
"username": user.username,
"email": user.email,
},
"affinities": [
{
"provider_id": aff["provider_id"],
"endpoint_id": aff["endpoint_id"],
"key_id": aff["key_id"],
"api_format": aff.get("api_format"),
"model_name": aff.get("model_name"),
"created_at": aff["created_at"],
"expire_at": aff["expire_at"],
"request_count": aff["request_count"],
}
for aff in user_affinities
],
"total_endpoints": len(user_affinities),
}
context.add_audit_metadata(
action="cache_user_affinity",
user_identifier=self.user_identifier,
resolved_user_id=user_id,
affinity_count=len(user_affinities),
status="ok",
)
return response
except HTTPException:
raise
except Exception as exc:
logger.exception(f"查询用户缓存亲和性失败: {exc}")
raise HTTPException(status_code=500, detail=f"查询失败: {exc}")
@dataclass
class AdminListAffinitiesAdapter(AdminApiAdapter):
keyword: Optional[str]
limit: int
offset: int
async def handle(self, context): # type: ignore[override]
db = context.db
redis_client = get_redis_client_sync()
if not redis_client:
raise HTTPException(status_code=503, detail="Redis未初始化无法获取缓存亲和性")
affinity_mgr = await get_affinity_manager(redis_client)
matched_user_id = None
matched_api_key_id = None
raw_affinities: List[Dict[str, Any]] = []
if self.keyword:
# 首先检查是否是 API Key IDaffinity_key
api_key = db.query(ApiKey).filter(ApiKey.id == self.keyword).first()
if api_key:
# 直接通过 affinity_key 过滤
matched_api_key_id = str(api_key.id)
matched_user_id = str(api_key.user_id)
all_affinities = await affinity_mgr.list_affinities()
raw_affinities = [
aff for aff in all_affinities if aff.get("affinity_key") == matched_api_key_id
]
else:
# 尝试解析为用户标识
user_id = resolve_user_identifier(db, self.keyword)
if user_id:
matched_user_id = user_id
# 获取该用户所有的 API Key ID
user_api_keys = db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
user_api_key_ids = {str(k.id) for k in user_api_keys}
# 过滤出该用户所有 API Key 的亲和性
all_affinities = await affinity_mgr.list_affinities()
raw_affinities = [
aff for aff in all_affinities if aff.get("affinity_key") in user_api_key_ids
]
else:
# 关键词不是有效标识,返回所有亲和性(后续会进行模糊匹配)
raw_affinities = await affinity_mgr.list_affinities()
else:
raw_affinities = await affinity_mgr.list_affinities()
# 收集所有 affinity_key (API Key ID)
affinity_keys = {
item.get("affinity_key") for item in raw_affinities if item.get("affinity_key")
}
# 批量查询用户 API Key 信息
user_api_key_map: Dict[str, ApiKey] = {}
if affinity_keys:
user_api_keys = db.query(ApiKey).filter(ApiKey.id.in_(list(affinity_keys))).all()
user_api_key_map = {str(k.id): k for k in user_api_keys}
# 收集所有 user_id
user_ids = {str(k.user_id) for k in user_api_key_map.values()}
user_map: Dict[str, User] = {}
if user_ids:
users = db.query(User).filter(User.id.in_(list(user_ids))).all()
user_map = {str(user.id): user for user in users}
# 收集所有provider_id、endpoint_id、key_id
provider_ids = {
item.get("provider_id") for item in raw_affinities if item.get("provider_id")
}
endpoint_ids = {
item.get("endpoint_id") for item in raw_affinities if item.get("endpoint_id")
}
key_ids = {item.get("key_id") for item in raw_affinities if item.get("key_id")}
# 批量查询Provider、Endpoint、Key信息
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
provider_map = {}
if provider_ids:
providers = db.query(Provider).filter(Provider.id.in_(list(provider_ids))).all()
provider_map = {p.id: p for p in providers}
endpoint_map = {}
if endpoint_ids:
endpoints = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(list(endpoint_ids))).all()
)
endpoint_map = {e.id: e for e in endpoints}
key_map = {}
if key_ids:
keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(list(key_ids))).all()
key_map = {k.id: k for k in keys}
# 收集所有 model_name实际存储的是 global_model_id并批量查询 GlobalModel
from src.models.database import GlobalModel
global_model_ids = {
item.get("model_name") for item in raw_affinities if item.get("model_name")
}
global_model_map: Dict[str, GlobalModel] = {}
if global_model_ids:
# model_name 可能是 UUID 格式的 global_model_id也可能是原始模型名称
global_models = db.query(GlobalModel).filter(
GlobalModel.id.in_(list(global_model_ids))
).all()
global_model_map = {str(gm.id): gm for gm in global_models}
keyword_lower = self.keyword.lower() if self.keyword else None
items = []
for affinity in raw_affinities:
affinity_key = affinity.get("affinity_key")
if not affinity_key:
continue
# 通过 affinity_keyAPI Key ID找到用户 API Key 和用户
user_api_key = user_api_key_map.get(affinity_key)
user = user_map.get(str(user_api_key.user_id)) if user_api_key else None
user_id = str(user_api_key.user_id) if user_api_key else None
provider_id = affinity.get("provider_id")
endpoint_id = affinity.get("endpoint_id")
key_id = affinity.get("key_id")
provider = provider_map.get(provider_id)
endpoint = endpoint_map.get(endpoint_id)
key = key_map.get(key_id)
# 用户 API Key 脱敏显示(解密 key_encrypted 后脱敏)
user_api_key_masked = None
if user_api_key and user_api_key.key_encrypted:
user_api_key_masked = decrypt_and_mask(user_api_key.key_encrypted)
# Provider Key 脱敏显示(解密 api_key 后脱敏)
provider_key_masked = None
if key and key.api_key:
provider_key_masked = decrypt_and_mask(key.api_key)
item = {
"affinity_key": affinity_key,
"user_api_key_name": user_api_key.name if user_api_key else None,
"user_api_key_prefix": user_api_key_masked,
"is_standalone": user_api_key.is_standalone if user_api_key else False,
"user_id": user_id,
"username": user.username if user else None,
"email": user.email if user else None,
"provider_id": provider_id,
"provider_name": provider.display_name if provider else None,
"endpoint_id": endpoint_id,
"endpoint_api_format": (
endpoint.api_format if endpoint and endpoint.api_format else None
),
"endpoint_url": endpoint.base_url if endpoint else None,
"key_id": key_id,
"key_name": key.name if key else None,
"key_prefix": provider_key_masked,
"rate_multiplier": key.rate_multiplier if key else 1.0,
"model_name": (
global_model_map.get(affinity.get("model_name")).name
if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
else affinity.get("model_name") # 如果找不到 GlobalModel显示原始值
),
"model_display_name": (
global_model_map.get(affinity.get("model_name")).display_name
if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
else None
),
"api_format": affinity.get("api_format"),
"created_at": affinity.get("created_at"),
"expire_at": affinity.get("expire_at"),
"request_count": affinity.get("request_count", 0),
}
if keyword_lower and not matched_user_id and not matched_api_key_id:
searchable = [
item["affinity_key"],
item["user_api_key_name"] or "",
item["user_id"] or "",
item["username"] or "",
item["email"] or "",
item["provider_id"] or "",
item["key_id"] or "",
]
if not any(keyword_lower in str(value).lower() for value in searchable if value):
continue
items.append(item)
items.sort(key=lambda x: x.get("expire_at") or 0, reverse=True)
paged_items, meta = paginate_sequence(items, self.limit, self.offset)
payload = build_pagination_payload(
paged_items,
meta,
matched_user_id=matched_user_id,
)
response = {
"status": "ok",
"data": payload,
}
result_count = meta.count if hasattr(meta, "count") else len(paged_items)
context.add_audit_metadata(
action="cache_affinity_list",
keyword=self.keyword,
matched_user_id=matched_user_id,
matched_api_key_id=matched_api_key_id,
limit=self.limit,
offset=self.offset,
result_count=result_count,
)
return response
@dataclass
class AdminClearUserCacheAdapter(AdminApiAdapter):
user_identifier: str
async def handle(self, context): # type: ignore[override]
db = context.db
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
# 首先检查是否直接是 API Key ID (affinity_key)
api_key = db.query(ApiKey).filter(ApiKey.id == self.user_identifier).first()
if api_key:
# 直接按 affinity_key 清除
affinity_key = str(api_key.id)
user = db.query(User).filter(User.id == api_key.user_id).first()
all_affinities = await affinity_mgr.list_affinities()
target_affinities = [
aff for aff in all_affinities if aff.get("affinity_key") == affinity_key
]
count = 0
for aff in target_affinities:
api_format = aff.get("api_format")
model_name = aff.get("model_name")
endpoint_id = aff.get("endpoint_id")
if api_format and model_name:
await affinity_mgr.invalidate_affinity(
affinity_key, api_format, model_name, endpoint_id=endpoint_id
)
count += 1
logger.info(f"已清除API Key缓存亲和性: api_key_name={api_key.name}, affinity_key={affinity_key[:8]}..., 清除数量={count}")
response = {
"status": "ok",
"message": f"已清除 API Key {api_key.name} 的缓存亲和性",
"user_info": {
"user_id": str(api_key.user_id),
"username": user.username if user else None,
"email": user.email if user else None,
"api_key_id": affinity_key,
"api_key_name": api_key.name,
},
}
context.add_audit_metadata(
action="cache_clear_api_key",
user_identifier=self.user_identifier,
resolved_api_key_id=affinity_key,
cleared_count=count,
)
return response
# 如果不是 API Key ID尝试解析为用户标识
user_id = resolve_user_identifier(db, self.user_identifier)
if not user_id:
raise HTTPException(
status_code=404,
detail=f"无法识别的标识符: {self.user_identifier}。支持用户名、邮箱、User ID或API Key ID",
)
user = db.query(User).filter(User.id == user_id).first()
# 获取该用户所有的 API Key
user_api_keys = db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
user_api_key_ids = {str(k.id) for k in user_api_keys}
# 获取该用户所有 API Key 的缓存亲和性并逐个失效
all_affinities = await affinity_mgr.list_affinities()
user_affinities = [
aff for aff in all_affinities if aff.get("affinity_key") in user_api_key_ids
]
count = 0
for aff in user_affinities:
affinity_key = aff.get("affinity_key")
api_format = aff.get("api_format")
model_name = aff.get("model_name")
endpoint_id = aff.get("endpoint_id")
if affinity_key and api_format and model_name:
await affinity_mgr.invalidate_affinity(
affinity_key, api_format, model_name, endpoint_id=endpoint_id
)
count += 1
logger.info(f"已清除用户缓存亲和性: username={user.username}, user_id={user_id[:8]}..., 清除数量={count}")
response = {
"status": "ok",
"message": f"已清除用户 {user.username} 的所有缓存亲和性",
"user_info": {"user_id": user_id, "username": user.username, "email": user.email},
}
context.add_audit_metadata(
action="cache_clear_user",
user_identifier=self.user_identifier,
resolved_user_id=user_id,
cleared_count=count,
)
return response
except HTTPException:
raise
except Exception as exc:
logger.exception(f"清除用户缓存亲和性失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
class AdminClearAllCacheAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
count = await affinity_mgr.clear_all()
logger.warning(f"已清除所有缓存亲和性(管理员操作): {count}")
context.add_audit_metadata(
action="cache_clear_all",
cleared_count=count,
)
return {"status": "ok", "message": "已清除所有缓存亲和性", "count": count}
except Exception as exc:
logger.exception(f"清除所有缓存亲和性失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
@dataclass
class AdminClearProviderCacheAdapter(AdminApiAdapter):
provider_id: str
async def handle(self, context): # type: ignore[override]
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
count = await affinity_mgr.invalidate_all_for_provider(self.provider_id)
logger.info(f"已清除Provider缓存亲和性: provider_id={self.provider_id[:8]}..., count={count}")
context.add_audit_metadata(
action="cache_clear_provider",
provider_id=self.provider_id,
cleared_count=count,
)
return {
"status": "ok",
"message": "已清除Provider的缓存亲和性",
"provider_id": self.provider_id,
"count": count,
}
except Exception as exc:
logger.exception(f"清除Provider缓存亲和性失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
class AdminCacheConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
from src.services.cache.affinity_manager import CacheAffinityManager
from src.services.cache.aware_scheduler import CacheAwareScheduler
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
# 获取动态预留管理器的配置
reservation_manager = get_adaptive_reservation_manager()
reservation_stats = reservation_manager.get_stats()
response = {
"status": "ok",
"data": {
"cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
"cache_reservation_ratio": CacheAwareScheduler.CACHE_RESERVATION_RATIO,
"dynamic_reservation": {
"enabled": True,
"config": reservation_stats["config"],
"description": {
"probe_phase_requests": "探测阶段请求数阈值",
"probe_reservation": "探测阶段预留比例",
"stable_min_reservation": "稳定阶段最小预留比例",
"stable_max_reservation": "稳定阶段最大预留比例",
"low_load_threshold": "低负载阈值(低于此值使用最小预留)",
"high_load_threshold": "高负载阈值(高于此值根据置信度使用较高预留)",
},
},
"description": {
"cache_ttl": "缓存亲和性有效期(秒)",
"cache_reservation_ratio": "静态预留比例(已被动态预留替代)",
"dynamic_reservation": "动态预留机制配置",
},
},
}
context.add_audit_metadata(
action="cache_config",
cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
cache_reservation_ratio=CacheAwareScheduler.CACHE_RESERVATION_RATIO,
dynamic_reservation_enabled=True,
)
return response

View File

@@ -0,0 +1,280 @@
"""
请求链路追踪 API 端点
"""
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.database import get_db
from src.models.database import Provider, ProviderEndpoint, ProviderAPIKey
from src.services.request.candidate import RequestCandidateService
router = APIRouter(prefix="/api/admin/monitoring/trace", tags=["Admin - Monitoring: Trace"])
pipeline = ApiRequestPipeline()
class CandidateResponse(BaseModel):
"""候选记录响应"""
id: str
request_id: str
candidate_index: int
retry_index: int = 0 # 重试序号从0开始
provider_id: Optional[str] = None
provider_name: Optional[str] = None
provider_website: Optional[str] = None # Provider 官网
endpoint_id: Optional[str] = None
endpoint_name: Optional[str] = None # 端点显示名称api_format
key_id: Optional[str] = None
key_name: Optional[str] = None # 密钥名称
key_preview: Optional[str] = None # 密钥脱敏预览(如 sk-***abc
key_capabilities: Optional[dict] = None # Key 支持的能力
required_capabilities: Optional[dict] = None # 请求实际需要的能力标签
status: str # 'pending', 'success', 'failed', 'skipped'
skip_reason: Optional[str] = None
is_cached: bool = False
# 执行结果字段
status_code: Optional[int] = None
error_type: Optional[str] = None
error_message: Optional[str] = None
latency_ms: Optional[int] = None
concurrent_requests: Optional[int] = None
extra_data: Optional[dict] = None
created_at: datetime
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
class Config:
from_attributes = True
class RequestTraceResponse(BaseModel):
"""请求追踪完整响应"""
request_id: str
total_candidates: int
final_status: str # 'success', 'failed', 'streaming', 'pending'
total_latency_ms: int
candidates: List[CandidateResponse]
@router.get("/{request_id}", response_model=RequestTraceResponse)
async def get_request_trace(
request_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""获取特定请求的完整追踪信息"""
adapter = AdminGetRequestTraceAdapter(request_id=request_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats/provider/{provider_id}")
async def get_provider_failure_rate(
provider_id: str,
request: Request,
limit: int = Query(100, ge=1, le=1000, description="统计最近的尝试数量"),
db: Session = Depends(get_db),
):
"""
获取某个 Provider 的失败率统计
需要管理员权限
"""
adapter = AdminProviderFailureRateAdapter(provider_id=provider_id, limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- 请求追踪适配器 --------
@dataclass
class AdminGetRequestTraceAdapter(AdminApiAdapter):
request_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
# 只查询 candidates
candidates = RequestCandidateService.get_candidates_by_request_id(db, self.request_id)
# 如果没有数据,返回 404
if not candidates:
raise HTTPException(status_code=404, detail="Request not found")
# 计算总延迟只统计已完成的候选success 或 failed
# 使用显式的 is not None 检查,避免过滤掉 0ms 的快速响应
total_latency = sum(
c.latency_ms
for c in candidates
if c.status in ("success", "failed") and c.latency_ms is not None
)
# 判断最终状态:
# 1. status="success" 即视为成功(无论 status_code 是什么)
# - 流式请求即使客户端断开499只要 Provider 成功返回数据,也算成功
# 2. 同时检查 status_code 在 200-299 范围,作为额外的成功判断条件
# - 用于兼容非流式请求或未正确设置 status 的旧数据
# 3. status="streaming" 表示流式请求正在进行中
# 4. status="pending" 表示请求尚未开始执行
has_success = any(
c.status == "success"
or (c.status_code is not None and 200 <= c.status_code < 300)
for c in candidates
)
has_streaming = any(c.status == "streaming" for c in candidates)
has_pending = any(c.status == "pending" for c in candidates)
if has_success:
final_status = "success"
elif has_streaming:
# 有候选正在流式传输中
final_status = "streaming"
elif has_pending:
# 有候选正在等待执行
final_status = "pending"
else:
final_status = "failed"
# 批量加载 provider 信息,避免 N+1 查询
provider_ids = {c.provider_id for c in candidates if c.provider_id}
provider_map = {}
provider_website_map = {}
if provider_ids:
providers = db.query(Provider).filter(Provider.id.in_(provider_ids)).all()
for p in providers:
provider_map[p.id] = p.name
provider_website_map[p.id] = p.website
# 批量加载 endpoint 信息
endpoint_ids = {c.endpoint_id for c in candidates if c.endpoint_id}
endpoint_map = {}
if endpoint_ids:
endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(endpoint_ids)).all()
endpoint_map = {e.id: e.api_format for e in endpoints}
# 批量加载 key 信息
key_ids = {c.key_id for c in candidates if c.key_id}
key_map = {}
key_preview_map = {}
key_capabilities_map = {}
if key_ids:
keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(key_ids)).all()
for k in keys:
key_map[k.id] = k.name
key_capabilities_map[k.id] = k.capabilities
# 生成脱敏预览保留前缀和最后4位
api_key = k.api_key or ""
if len(api_key) > 8:
# 检测常见前缀模式
prefix_end = 0
for prefix in ["sk-", "key-", "api-", "ak-"]:
if api_key.lower().startswith(prefix):
prefix_end = len(prefix)
break
if prefix_end > 0:
key_preview_map[k.id] = f"{api_key[:prefix_end]}***{api_key[-4:]}"
else:
key_preview_map[k.id] = f"{api_key[:3]}***{api_key[-4:]}"
elif len(api_key) > 4:
key_preview_map[k.id] = f"***{api_key[-4:]}"
else:
key_preview_map[k.id] = "***"
# 构建 candidate 响应列表
candidate_responses: List[CandidateResponse] = []
for candidate in candidates:
provider_name = (
provider_map.get(candidate.provider_id) if candidate.provider_id else None
)
provider_website = (
provider_website_map.get(candidate.provider_id) if candidate.provider_id else None
)
endpoint_name = (
endpoint_map.get(candidate.endpoint_id) if candidate.endpoint_id else None
)
key_name = (
key_map.get(candidate.key_id) if candidate.key_id else None
)
key_preview = (
key_preview_map.get(candidate.key_id) if candidate.key_id else None
)
key_capabilities = (
key_capabilities_map.get(candidate.key_id) if candidate.key_id else None
)
candidate_responses.append(
CandidateResponse(
id=candidate.id,
request_id=candidate.request_id,
candidate_index=candidate.candidate_index,
retry_index=candidate.retry_index,
provider_id=candidate.provider_id,
provider_name=provider_name,
provider_website=provider_website,
endpoint_id=candidate.endpoint_id,
endpoint_name=endpoint_name,
key_id=candidate.key_id,
key_name=key_name,
key_preview=key_preview,
key_capabilities=key_capabilities,
required_capabilities=candidate.required_capabilities,
status=candidate.status,
skip_reason=candidate.skip_reason,
is_cached=candidate.is_cached,
status_code=candidate.status_code,
error_type=candidate.error_type,
error_message=candidate.error_message,
latency_ms=candidate.latency_ms,
concurrent_requests=candidate.concurrent_requests,
extra_data=candidate.extra_data,
created_at=candidate.created_at,
started_at=candidate.started_at,
finished_at=candidate.finished_at,
)
)
response = RequestTraceResponse(
request_id=self.request_id,
total_candidates=len(candidates),
final_status=final_status,
total_latency_ms=total_latency,
candidates=candidate_responses,
)
context.add_audit_metadata(
action="trace_request_detail",
request_id=self.request_id,
total_candidates=len(candidates),
final_status=final_status,
total_latency_ms=total_latency,
)
return response
@dataclass
class AdminProviderFailureRateAdapter(AdminApiAdapter):
provider_id: str
limit: int
async def handle(self, context): # type: ignore[override]
result = RequestCandidateService.get_candidate_stats_by_provider(
db=context.db,
provider_id=self.provider_id,
limit=self.limit,
)
context.add_audit_metadata(
action="trace_provider_failure_rate",
provider_id=self.provider_id,
limit=self.limit,
total_attempts=result.get("total_attempts"),
failure_rate=result.get("failure_rate"),
)
return result

View File

@@ -0,0 +1,410 @@
"""
Provider Query API 端点
用于查询提供商的余额、使用记录等信息
"""
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
# 初始化适配器注册
from src.plugins.provider_query import init # noqa
from src.plugins.provider_query import get_query_registry
from src.plugins.provider_query.base import QueryCapability
from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
# ============ Request/Response Models ============
class BalanceQueryRequest(BaseModel):
"""余额查询请求"""
provider_id: str
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
class UsageSummaryQueryRequest(BaseModel):
"""使用汇总查询请求"""
provider_id: str
api_key_id: Optional[str] = None
period: str = "month" # day, week, month, year
class ModelsQueryRequest(BaseModel):
"""模型列表查询请求"""
provider_id: str
api_key_id: Optional[str] = None
# ============ API Endpoints ============
@router.get("/adapters")
async def list_adapters(
current_user: User = Depends(get_current_user),
):
"""
获取所有可用的查询适配器
Returns:
适配器列表
"""
registry = get_query_registry()
adapters = registry.list_adapters()
return {"success": True, "data": adapters}
@router.get("/capabilities/{provider_id}")
async def get_provider_capabilities(
provider_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取提供商支持的查询能力
Args:
provider_id: 提供商 ID
Returns:
支持的查询能力列表
"""
# 获取提供商
from sqlalchemy import select
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
capabilities = registry.get_capabilities_for_provider(provider.name)
if capabilities is None:
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [],
"has_adapter": False,
"message": "No query adapter available for this provider",
},
}
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [c.name for c in capabilities],
"has_adapter": True,
},
}
@router.post("/balance")
async def query_balance(
request: BalanceQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商余额
Args:
request: 查询请求
Returns:
余额信息
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key
api_key_value = None
endpoint_config = None
if request.api_key_id:
# 查找指定的 API Key
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# 使用第一个可用的 API Key
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询余额
registry = get_query_registry()
query_result = await registry.query_provider_balance(
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
)
if not query_result.success:
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.post("/usage-summary")
async def query_usage_summary(
request: UsageSummaryQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商使用汇总
Args:
request: 查询请求
Returns:
使用汇总信息
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key逻辑同上
api_key_value = None
endpoint_config = None
if request.api_key_id:
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询使用汇总
registry = get_query_registry()
query_result = await registry.query_provider_usage(
provider_type=provider.name,
api_key=api_key_value,
period=request.period,
endpoint_config=endpoint_config,
)
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.post("/models")
async def query_available_models(
request: ModelsQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商可用模型
Args:
request: 查询请求
Returns:
模型列表
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key
api_key_value = None
endpoint_config = None
if request.api_key_id:
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询模型
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
if not adapter:
raise HTTPException(
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
)
query_result = await adapter.query_available_models(
api_key=api_key_value, endpoint_config=endpoint_config
)
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
provider_id: str,
api_key_id: Optional[str] = None,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
清除查询缓存
Args:
provider_id: 提供商 ID
api_key_id: 可选,指定清除某个 API Key 的缓存
Returns:
清除结果
"""
from sqlalchemy import select
# 获取提供商
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
if adapter:
if api_key_id:
# 获取 API Key 值来清除缓存
from sqlalchemy.orm import selectinload
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
api_key = result.scalar_one_or_none()
if api_key:
adapter.clear_cache(api_key.api_key)
else:
adapter.clear_cache()
return {"success": True, "message": "Cache cleared successfully"}

View File

@@ -0,0 +1,272 @@
"""
提供商策略管理 API 端点
"""
from datetime import datetime, timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking
router = APIRouter(prefix="/api/admin/provider-strategy", tags=["Provider Strategy"])
pipeline = ApiRequestPipeline()
class ProviderBillingUpdate(BaseModel):
billing_type: ProviderBillingType
monthly_quota_usd: Optional[float] = None
quota_reset_day: int = Field(default=30, ge=1, le=365) # 重置周期(天数)
quota_last_reset_at: Optional[str] = None # 当前周期开始时间
quota_expires_at: Optional[str] = None
rpm_limit: Optional[int] = Field(default=None, ge=0)
provider_priority: int = Field(default=100, ge=0, le=200)
@router.put("/providers/{provider_id}/billing")
async def update_provider_billing(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
):
adapter = AdminProviderBillingAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/providers/{provider_id}/stats")
async def get_provider_stats(
provider_id: str,
request: Request,
hours: int = 24,
db: Session = Depends(get_db),
):
adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/providers/{provider_id}/quota")
async def reset_provider_quota(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""Reset provider quota usage to zero"""
adapter = AdminProviderResetQuotaAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/strategies")
async def list_available_strategies(request: Request, db: Session = Depends(get_db)):
adapter = AdminListStrategiesAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminProviderBillingAdapter(AdminApiAdapter):
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context):
db = context.db
payload = context.ensure_json_body()
try:
config = ProviderBillingUpdate.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
provider.billing_type = config.billing_type
provider.monthly_quota_usd = config.monthly_quota_usd
provider.quota_reset_day = config.quota_reset_day
provider.rpm_limit = config.rpm_limit
provider.provider_priority = config.provider_priority
from dateutil import parser
from sqlalchemy import func
from src.models.database import Usage
if config.quota_last_reset_at:
new_reset_at = parser.parse(config.quota_last_reset_at)
provider.quota_last_reset_at = new_reset_at
# 自动同步该周期内的历史使用量
period_usage = (
db.query(func.coalesce(func.sum(Usage.total_cost_usd), 0))
.filter(
Usage.provider_id == self.provider_id,
Usage.created_at >= new_reset_at,
)
.scalar()
)
provider.monthly_used_usd = float(period_usage or 0)
logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
if config.quota_expires_at:
provider.quota_expires_at = parser.parse(config.quota_expires_at)
db.commit()
db.refresh(provider)
logger.info(f"Updated billing config for provider {provider.name}")
return JSONResponse(
{
"message": "Provider billing config updated successfully",
"provider": {
"id": provider.id,
"name": provider.name,
"billing_type": provider.billing_type.value,
"provider_priority": provider.provider_priority,
},
}
)
class AdminProviderStatsAdapter(AdminApiAdapter):
def __init__(self, provider_id: str, hours: int):
self.provider_id = provider_id
self.hours = hours
async def handle(self, context):
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
since = datetime.now() - timedelta(hours=self.hours)
stats = (
db.query(ProviderUsageTracking)
.filter(
ProviderUsageTracking.provider_id == self.provider_id,
ProviderUsageTracking.window_start >= since,
)
.all()
)
total_requests = sum(s.total_requests for s in stats)
total_success = sum(s.successful_requests for s in stats)
total_failures = sum(s.failed_requests for s in stats)
avg_response_time = sum(s.avg_response_time_ms for s in stats) / len(stats) if stats else 0
total_cost = sum(s.total_cost_usd for s in stats)
return JSONResponse(
{
"provider_id": self.provider_id,
"provider_name": provider.name,
"period_hours": self.hours,
"billing_info": {
"billing_type": provider.billing_type.value,
"monthly_quota_usd": provider.monthly_quota_usd,
"monthly_used_usd": provider.monthly_used_usd,
"quota_remaining_usd": (
provider.monthly_quota_usd - provider.monthly_used_usd
if provider.monthly_quota_usd is not None
else None
),
"quota_expires_at": (
provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
),
},
"rpm_info": {
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": (
provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
),
},
"usage_stats": {
"total_requests": total_requests,
"successful_requests": total_success,
"failed_requests": total_failures,
"success_rate": total_success / total_requests if total_requests > 0 else 0,
"avg_response_time_ms": round(avg_response_time, 2),
"total_cost_usd": round(total_cost, 4),
},
}
)
class AdminProviderResetQuotaAdapter(AdminApiAdapter):
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context):
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
if provider.billing_type != ProviderBillingType.MONTHLY_QUOTA:
raise HTTPException(status_code=400, detail="Only monthly quota providers can be reset")
old_used = provider.monthly_used_usd
provider.monthly_used_usd = 0.0
provider.rpm_used = 0
provider.rpm_reset_at = None
db.commit()
logger.info(f"Manually reset quota for provider {provider.name}")
return JSONResponse(
{
"message": "Provider quota reset successfully",
"provider_name": provider.name,
"previous_used": old_used,
"current_used": 0.0,
}
)
class AdminListStrategiesAdapter(AdminApiAdapter):
async def handle(self, context):
from src.plugins.manager import get_plugin_manager
plugin_manager = get_plugin_manager()
lb_plugins = plugin_manager.plugins.get("load_balancer", {})
strategies = []
for name, plugin in lb_plugins.items():
try:
strategies.append(
{
"name": getattr(plugin, "name", name),
"priority": getattr(plugin, "priority", 0),
"version": (
getattr(plugin.metadata, "version", "1.0.0")
if hasattr(plugin, "metadata")
else "1.0.0"
),
"description": (
getattr(plugin.metadata, "description", "")
if hasattr(plugin, "metadata")
else ""
),
"author": (
getattr(plugin.metadata, "author", "Unknown")
if hasattr(plugin, "metadata")
else "Unknown"
),
}
)
except Exception as exc: # pragma: no cover
logger.error(f"Error accessing plugin {name}: {exc}")
continue
strategies.sort(key=lambda x: x["priority"], reverse=True)
return JSONResponse({"strategies": strategies, "total": len(strategies)})

View File

@@ -0,0 +1,20 @@
"""Provider admin routes export."""
from fastapi import APIRouter
from .models import router as models_router
from .routes import router as routes_router
from .summary import router as summary_router
router = APIRouter(prefix="/api/admin/providers", tags=["Admin - Providers"])
# Provider CRUD
router.include_router(routes_router)
# Provider summary & health monitor
router.include_router(summary_router)
# Provider models management
router.include_router(models_router)
__all__ = ["router"]

View File

@@ -0,0 +1,443 @@
"""
Provider 模型管理 API
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, Request
from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.api import (
ModelCreate,
ModelResponse,
ModelUpdate,
)
from src.models.pydantic_models import (
BatchAssignModelsToProviderRequest,
BatchAssignModelsToProviderResponse,
)
from src.models.database import (
GlobalModel,
Model,
ModelMapping,
Provider,
)
from src.models.pydantic_models import (
ProviderAvailableSourceModel,
ProviderAvailableSourceModelsResponse,
)
from src.services.model.service import ModelService
router = APIRouter(tags=["Model Management"])
pipeline = ApiRequestPipeline()
@router.get("/{provider_id}/models", response_model=List[ModelResponse])
async def list_provider_models(
provider_id: str,
request: Request,
is_active: Optional[bool] = None,
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
) -> List[ModelResponse]:
"""获取提供商的所有模型(管理员)"""
adapter = AdminListProviderModelsAdapter(
provider_id=provider_id,
is_active=is_active,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/{provider_id}/models", response_model=ModelResponse)
async def create_provider_model(
provider_id: str,
model_data: ModelCreate,
request: Request,
db: Session = Depends(get_db),
) -> ModelResponse:
"""创建模型(管理员)"""
adapter = AdminCreateProviderModelAdapter(provider_id=provider_id, model_data=model_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{provider_id}/models/{model_id}", response_model=ModelResponse)
async def get_provider_model(
provider_id: str,
model_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ModelResponse:
"""获取模型详情(管理员)"""
adapter = AdminGetProviderModelAdapter(provider_id=provider_id, model_id=model_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{provider_id}/models/{model_id}", response_model=ModelResponse)
async def update_provider_model(
provider_id: str,
model_id: str,
model_data: ModelUpdate,
request: Request,
db: Session = Depends(get_db),
) -> ModelResponse:
"""更新模型(管理员)"""
adapter = AdminUpdateProviderModelAdapter(
provider_id=provider_id,
model_id=model_id,
model_data=model_data,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{provider_id}/models/{model_id}")
async def delete_provider_model(
provider_id: str,
model_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""删除模型(管理员)"""
adapter = AdminDeleteProviderModelAdapter(provider_id=provider_id, model_id=model_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/{provider_id}/models/batch", response_model=List[ModelResponse])
async def batch_create_provider_models(
provider_id: str,
models_data: List[ModelCreate],
request: Request,
db: Session = Depends(get_db),
) -> List[ModelResponse]:
"""批量创建模型(管理员)"""
adapter = AdminBatchCreateModelsAdapter(provider_id=provider_id, models_data=models_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get(
"/{provider_id}/available-source-models",
response_model=ProviderAvailableSourceModelsResponse,
)
async def get_provider_available_source_models(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""
获取该 Provider 支持的所有统一模型名source_model
包括:
1. 通过 ModelMapping 映射的模型
2. 直连模型Model.provider_model_name 直接作为统一模型名)
"""
adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post(
"/{provider_id}/assign-global-models",
response_model=BatchAssignModelsToProviderResponse,
)
async def batch_assign_global_models_to_provider(
provider_id: str,
payload: BatchAssignModelsToProviderRequest,
request: Request,
db: Session = Depends(get_db),
) -> BatchAssignModelsToProviderResponse:
"""批量为 Provider 关联 GlobalModels自动继承价格和能力配置"""
adapter = AdminBatchAssignModelsToProviderAdapter(
provider_id=provider_id, payload=payload
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminListProviderModelsAdapter(AdminApiAdapter):
provider_id: str
is_active: Optional[bool]
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("Provider not found", "provider")
models = ModelService.get_models_by_provider(
db, self.provider_id, self.skip, self.limit, self.is_active
)
return [ModelService.convert_to_response(model) for model in models]
@dataclass
class AdminCreateProviderModelAdapter(AdminApiAdapter):
provider_id: str
model_data: ModelCreate
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("Provider not found", "provider")
try:
model = ModelService.create_model(db, self.provider_id, self.model_data)
logger.info(f"Model created: {model.provider_model_name} for provider {provider.name} by {context.user.username}")
return ModelService.convert_to_response(model)
except Exception as exc:
raise InvalidRequestException(str(exc))
@dataclass
class AdminGetProviderModelAdapter(AdminApiAdapter):
provider_id: str
model_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
model = (
db.query(Model)
.filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
.first()
)
if not model:
raise NotFoundException("Model not found", "model")
return ModelService.convert_to_response(model)
@dataclass
class AdminUpdateProviderModelAdapter(AdminApiAdapter):
provider_id: str
model_id: str
model_data: ModelUpdate
async def handle(self, context): # type: ignore[override]
db = context.db
model = (
db.query(Model)
.filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
.first()
)
if not model:
raise NotFoundException("Model not found", "model")
try:
updated_model = ModelService.update_model(db, self.model_id, self.model_data)
logger.info(f"Model updated: {updated_model.provider_model_name} by {context.user.username}")
return ModelService.convert_to_response(updated_model)
except Exception as exc:
raise InvalidRequestException(str(exc))
@dataclass
class AdminDeleteProviderModelAdapter(AdminApiAdapter):
provider_id: str
model_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
model = (
db.query(Model)
.filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
.first()
)
if not model:
raise NotFoundException("Model not found", "model")
model_name = model.provider_model_name
try:
ModelService.delete_model(db, self.model_id)
logger.info(f"Model deleted: {model_name} by {context.user.username}")
return {"message": f"Model '{model_name}' deleted successfully"}
except Exception as exc:
raise InvalidRequestException(str(exc))
@dataclass
class AdminBatchCreateModelsAdapter(AdminApiAdapter):
provider_id: str
models_data: List[ModelCreate]
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("Provider not found", "provider")
try:
models = ModelService.batch_create_models(db, self.provider_id, self.models_data)
logger.info(f"Batch created {len(models)} models for provider {provider.name} by {context.user.username}")
return [ModelService.convert_to_response(model) for model in models]
except Exception as exc:
raise InvalidRequestException(str(exc))
@dataclass
class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
provider_id: str
async def handle(self, context): # type: ignore[override]
"""
返回 Provider 支持的所有 GlobalModel
方案 A 逻辑:
1. 查询该 Provider 的所有 Model
2. 通过 Model.global_model_id 获取 GlobalModel
3. 查询所有指向该 GlobalModel 的别名ModelMapping.alias
"""
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("Provider not found", "provider")
# 1. 查询该 Provider 的所有活跃 Model预加载 GlobalModel
models = (
db.query(Model)
.options(joinedload(Model.global_model))
.filter(Model.provider_id == self.provider_id, Model.is_active == True)
.all()
)
# 2. 构建以 GlobalModel 为主键的字典
global_models_dict: Dict[str, Dict[str, Any]] = {}
for model in models:
global_model = model.global_model
if not global_model or not global_model.is_active:
continue
global_model_name = global_model.name
# 如果该 GlobalModel 还未处理,初始化
if global_model_name not in global_models_dict:
# 查询指向该 GlobalModel 的所有别名/映射
alias_rows = (
db.query(ModelMapping.source_model)
.filter(
ModelMapping.target_global_model_id == global_model.id,
ModelMapping.is_active == True,
or_(
ModelMapping.provider_id == self.provider_id,
ModelMapping.provider_id.is_(None),
),
)
.all()
)
alias_list = [alias[0] for alias in alias_rows]
global_models_dict[global_model_name] = {
"global_model_name": global_model_name,
"display_name": global_model.display_name,
"provider_model_name": model.provider_model_name,
"has_alias": len(alias_list) > 0,
"aliases": alias_list,
"model_id": model.id,
"price": {
"input_price_per_1m": model.get_effective_input_price(),
"output_price_per_1m": model.get_effective_output_price(),
"cache_creation_price_per_1m": model.get_effective_cache_creation_price(),
"cache_read_price_per_1m": model.get_effective_cache_read_price(),
"price_per_request": model.get_effective_price_per_request(),
},
"capabilities": {
"supports_vision": bool(model.supports_vision),
"supports_function_calling": bool(model.supports_function_calling),
"supports_streaming": bool(model.supports_streaming),
},
"is_active": bool(model.is_active),
}
models_list = [
ProviderAvailableSourceModel(**global_models_dict[name])
for name in sorted(global_models_dict.keys())
]
return ProviderAvailableSourceModelsResponse(models=models_list, total=len(models_list))
@dataclass
class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
"""批量为 Provider 关联 GlobalModels"""
provider_id: str
payload: BatchAssignModelsToProviderRequest
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("Provider not found", "provider")
success = []
errors = []
for global_model_id in self.payload.global_model_ids:
try:
global_model = (
db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()
)
if not global_model:
errors.append(
{"global_model_id": global_model_id, "error": "GlobalModel not found"}
)
continue
# 检查是否已存在关联
existing = (
db.query(Model)
.filter(
Model.provider_id == self.provider_id,
Model.global_model_id == global_model_id,
)
.first()
)
if existing:
errors.append(
{
"global_model_id": global_model_id,
"global_model_name": global_model.name,
"error": "Already associated",
}
)
continue
# 创建新的 Model 记录,继承 GlobalModel 的配置
new_model = Model(
provider_id=self.provider_id,
global_model_id=global_model_id,
provider_model_name=global_model.name,
is_active=True,
)
db.add(new_model)
db.flush()
success.append(
{
"global_model_id": global_model_id,
"global_model_name": global_model.name,
"model_id": new_model.id,
}
)
except Exception as e:
errors.append({"global_model_id": global_model_id, "error": str(e)})
db.commit()
logger.info(
f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
)
return BatchAssignModelsToProviderResponse(success=success, errors=errors)

View File

@@ -0,0 +1,249 @@
"""管理员 Provider 管理路由。"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, Query, Request
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider
router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline()
@router.get("/")
async def list_providers(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
is_active: Optional[bool] = None,
db: Session = Depends(get_db),
):
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/")
async def create_provider(request: Request, db: Session = Depends(get_db)):
adapter = AdminCreateProviderAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{provider_id}")
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminListProvidersAdapter(AdminApiAdapter):
def __init__(self, skip: int, limit: int, is_active: Optional[bool]):
self.skip = skip
self.limit = limit
self.is_active = is_active
async def handle(self, context): # type: ignore[override]
db = context.db
query = db.query(Provider)
if self.is_active is not None:
query = query.filter(Provider.is_active == self.is_active)
providers = query.offset(self.skip).limit(self.limit).all()
data = []
for provider in providers:
api_format = getattr(provider, "api_format", None)
base_url = getattr(provider, "base_url", None)
api_key = getattr(provider, "api_key", None)
priority = getattr(provider, "priority", provider.provider_priority)
data.append(
{
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"api_format": api_format.value if api_format else None,
"base_url": base_url,
"api_key": "***" if api_key else None,
"priority": priority,
"is_active": provider.is_active,
"created_at": provider.created_at.isoformat(),
"updated_at": provider.updated_at.isoformat() if provider.updated_at else None,
}
)
context.add_audit_metadata(
action="list_providers",
filter_is_active=self.is_active,
limit=self.limit,
skip=self.skip,
result_count=len(data),
)
return data
class AdminCreateProviderAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
validated_data = CreateProviderRequest.model_validate(payload)
except ValidationError as exc:
# 将 Pydantic 验证错误转换为友好的错误信息
errors = []
for error in exc.errors():
field = " -> ".join(str(x) for x in error["loc"])
errors.append(f"{field}: {error['msg']}")
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
try:
# 检查名称是否已存在
existing = db.query(Provider).filter(Provider.name == validated_data.name).first()
if existing:
raise InvalidRequestException(f"提供商名称 '{validated_data.name}' 已存在")
# 将验证后的数据转换为枚举类型
billing_type = (
ProviderBillingType(validated_data.billing_type)
if validated_data.billing_type
else ProviderBillingType.PAY_AS_YOU_GO
)
# 创建 Provider 对象
provider = Provider(
name=validated_data.name,
display_name=validated_data.display_name,
description=validated_data.description,
website=validated_data.website,
billing_type=billing_type,
monthly_quota_usd=validated_data.monthly_quota_usd,
quota_reset_day=validated_data.quota_reset_day,
quota_last_reset_at=validated_data.quota_last_reset_at,
quota_expires_at=validated_data.quota_expires_at,
rpm_limit=validated_data.rpm_limit,
provider_priority=validated_data.provider_priority,
is_active=validated_data.is_active,
rate_limit=validated_data.rate_limit,
concurrent_limit=validated_data.concurrent_limit,
config=validated_data.config,
)
db.add(provider)
db.commit()
db.refresh(provider)
context.add_audit_metadata(
action="create_provider",
provider_id=provider.id,
provider_name=provider.name,
billing_type=provider.billing_type.value if provider.billing_type else None,
is_active=provider.is_active,
provider_priority=provider.provider_priority,
)
return {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"message": "提供商创建成功",
}
except InvalidRequestException:
db.rollback()
raise
except Exception:
db.rollback()
raise
class AdminUpdateProviderAdapter(AdminApiAdapter):
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
# 查找 Provider
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("提供商不存在", "provider")
try:
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
validated_data = UpdateProviderRequest.model_validate(payload)
except ValidationError as exc:
# 将 Pydantic 验证错误转换为友好的错误信息
errors = []
for error in exc.errors():
field = " -> ".join(str(x) for x in error["loc"])
errors.append(f"{field}: {error['msg']}")
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
try:
# 更新字段(只更新非 None 的字段)
update_data = validated_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "billing_type" and value is not None:
# billing_type 需要转换为枚举
setattr(provider, field, ProviderBillingType(value))
else:
setattr(provider, field, value)
provider.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(provider)
context.add_audit_metadata(
action="update_provider",
provider_id=provider.id,
changed_fields=list(update_data.keys()),
is_active=provider.is_active,
provider_priority=provider.provider_priority,
)
return {
"id": provider.id,
"name": provider.name,
"is_active": provider.is_active,
"message": "提供商更新成功",
}
except InvalidRequestException:
db.rollback()
raise
except Exception:
db.rollback()
raise
class AdminDeleteProviderAdapter(AdminApiAdapter):
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("提供商不存在", "provider")
context.add_audit_metadata(
action="delete_provider",
provider_id=provider.id,
provider_name=provider.name,
)
db.delete(provider)
db.commit()
return {"message": "提供商已删除"}

View File

@@ -0,0 +1,348 @@
"""
Provider 摘要与健康监控 API
"""
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.database import (
Model,
Provider,
ProviderAPIKey,
ProviderEndpoint,
RequestCandidate,
)
from src.models.endpoint_models import (
EndpointHealthEvent,
EndpointHealthMonitor,
ProviderEndpointHealthMonitorResponse,
ProviderUpdateRequest,
ProviderWithEndpointsSummary,
)
router = APIRouter(tags=["Provider Summary"])
pipeline = ApiRequestPipeline()
@router.get("/summary", response_model=List[ProviderWithEndpointsSummary])
async def get_providers_summary(
request: Request,
db: Session = Depends(get_db),
) -> List[ProviderWithEndpointsSummary]:
"""获取所有 Providers 的摘要信息(包含 Endpoints 和 Keys 统计)"""
adapter = AdminProviderSummaryAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{provider_id}/summary", response_model=ProviderWithEndpointsSummary)
async def get_provider_summary(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ProviderWithEndpointsSummary:
"""获取单个 Provider 的摘要信息(包含 Endpoints 和 Keys 统计)"""
provider = db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
raise NotFoundException(f"Provider {provider_id} not found")
return _build_provider_summary(db, provider)
@router.get("/{provider_id}/health-monitor", response_model=ProviderEndpointHealthMonitorResponse)
async def get_provider_health_monitor(
provider_id: str,
request: Request,
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
per_endpoint_limit: int = Query(48, ge=10, le=200, description="每个端点的事件数量"),
db: Session = Depends(get_db),
) -> ProviderEndpointHealthMonitorResponse:
"""获取 Provider 下所有端点的健康监控时间线"""
adapter = AdminProviderHealthMonitorAdapter(
provider_id=provider_id,
lookback_hours=lookback_hours,
per_endpoint_limit=per_endpoint_limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{provider_id}", response_model=ProviderWithEndpointsSummary)
async def update_provider_settings(
provider_id: str,
update_data: ProviderUpdateRequest,
request: Request,
db: Session = Depends(get_db),
) -> ProviderWithEndpointsSummary:
"""更新 Provider 基础配置display_name, description, priority, weight 等)"""
adapter = AdminUpdateProviderSettingsAdapter(provider_id=provider_id, update_data=update_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndpointsSummary:
endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.provider_id == provider.id).all()
total_endpoints = len(endpoints)
active_endpoints = sum(1 for e in endpoints if e.is_active)
endpoint_ids = [e.id for e in endpoints]
# Key 统计(合并为单个查询)
total_keys = 0
active_keys = 0
if endpoint_ids:
key_stats = db.query(
func.count(ProviderAPIKey.id).label("total"),
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).first()
total_keys = key_stats.total or 0
active_keys = int(key_stats.active or 0)
# Model 统计(合并为单个查询)
model_stats = db.query(
func.count(Model.id).label("total"),
func.sum(case((Model.is_active == True, 1), else_=0)).label("active"),
).filter(Model.provider_id == provider.id).first()
total_models = model_stats.total or 0
active_models = int(model_stats.active or 0)
api_formats = [e.api_format for e in endpoints]
# 优化: 一次性加载所有 endpoint 的 keys避免 N+1 查询
all_keys = []
if endpoint_ids:
all_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
)
# 按 endpoint_id 分组 keys
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {}
for key in all_keys:
if key.endpoint_id not in keys_by_endpoint:
keys_by_endpoint[key.endpoint_id] = []
keys_by_endpoint[key.endpoint_id].append(key)
endpoint_health_map: dict[str, float] = {}
for endpoint in endpoints:
keys = keys_by_endpoint.get(endpoint.id, [])
if keys:
health_scores = [k.health_score for k in keys if k.health_score is not None]
avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0
endpoint_health_map[endpoint.id] = avg_health
else:
endpoint_health_map[endpoint.id] = 1.0
all_health_scores = list(endpoint_health_map.values())
avg_health_score = sum(all_health_scores) / len(all_health_scores) if all_health_scores else 1.0
unhealthy_endpoints = sum(1 for score in all_health_scores if score < 0.5)
# 计算每个端点的活跃密钥数量
active_keys_by_endpoint: dict[str, int] = {}
for endpoint_id, keys in keys_by_endpoint.items():
active_keys_by_endpoint[endpoint_id] = sum(1 for k in keys if k.is_active)
endpoint_health_details = [
{
"api_format": e.api_format,
"health_score": endpoint_health_map.get(e.id, 1.0),
"is_active": e.is_active,
"active_keys": active_keys_by_endpoint.get(e.id, 0),
}
for e in endpoints
]
return ProviderWithEndpointsSummary(
id=provider.id,
name=provider.name,
display_name=provider.display_name,
description=provider.description,
website=provider.website,
provider_priority=provider.provider_priority,
is_active=provider.is_active,
billing_type=provider.billing_type.value if provider.billing_type else None,
monthly_quota_usd=provider.monthly_quota_usd,
monthly_used_usd=provider.monthly_used_usd,
quota_reset_day=provider.quota_reset_day,
quota_last_reset_at=provider.quota_last_reset_at,
quota_expires_at=provider.quota_expires_at,
rpm_limit=provider.rpm_limit,
rpm_used=provider.rpm_used,
rpm_reset_at=provider.rpm_reset_at,
total_endpoints=total_endpoints,
active_endpoints=active_endpoints,
total_keys=total_keys,
active_keys=active_keys,
total_models=total_models,
active_models=active_models,
avg_health_score=avg_health_score,
unhealthy_endpoints=unhealthy_endpoints,
api_formats=api_formats,
endpoint_health_details=endpoint_health_details,
created_at=provider.created_at,
updated_at=provider.updated_at,
)
# -------- Adapters --------
@dataclass
class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
provider_id: str
lookback_hours: int
per_endpoint_limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException(f"Provider {self.provider_id} 不存在")
endpoints = (
db.query(ProviderEndpoint)
.filter(ProviderEndpoint.provider_id == self.provider_id)
.all()
)
now = datetime.now(timezone.utc)
since = now - timedelta(hours=self.lookback_hours)
endpoint_ids = [endpoint.id for endpoint in endpoints]
if not endpoint_ids:
response = ProviderEndpointHealthMonitorResponse(
provider_id=provider.id,
provider_name=provider.display_name or provider.name,
generated_at=now,
endpoints=[],
)
context.add_audit_metadata(
action="provider_health_monitor",
provider_id=self.provider_id,
endpoint_count=0,
lookback_hours=self.lookback_hours,
)
return response
limit_rows = max(200, self.per_endpoint_limit * max(1, len(endpoint_ids)) * 2)
attempts_query = (
db.query(RequestCandidate)
.filter(
RequestCandidate.endpoint_id.in_(endpoint_ids),
RequestCandidate.created_at >= since,
)
.order_by(RequestCandidate.created_at.desc())
)
attempts = attempts_query.limit(limit_rows).all()
buffered_attempts: Dict[str, List[RequestCandidate]] = {eid: [] for eid in endpoint_ids}
counters: Dict[str, int] = {eid: 0 for eid in endpoint_ids}
for attempt in attempts:
if not attempt.endpoint_id or attempt.endpoint_id not in buffered_attempts:
continue
if counters[attempt.endpoint_id] >= self.per_endpoint_limit:
continue
buffered_attempts[attempt.endpoint_id].append(attempt)
counters[attempt.endpoint_id] += 1
endpoint_monitors: List[EndpointHealthMonitor] = []
for endpoint in endpoints:
attempt_list = list(reversed(buffered_attempts.get(endpoint.id, [])))
events: List[EndpointHealthEvent] = []
for attempt in attempt_list:
event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
events.append(
EndpointHealthEvent(
timestamp=event_timestamp,
status=attempt.status,
status_code=attempt.status_code,
latency_ms=attempt.latency_ms,
error_type=attempt.error_type,
error_message=attempt.error_message,
)
)
success_count = sum(1 for event in events if event.status == "success")
failed_count = sum(1 for event in events if event.status == "failed")
skipped_count = sum(1 for event in events if event.status == "skipped")
total_attempts = len(events)
success_rate = success_count / total_attempts if total_attempts else 1.0
last_event_at = events[-1].timestamp if events else None
endpoint_monitors.append(
EndpointHealthMonitor(
endpoint_id=endpoint.id,
api_format=endpoint.api_format,
is_active=endpoint.is_active,
total_attempts=total_attempts,
success_count=success_count,
failed_count=failed_count,
skipped_count=skipped_count,
success_rate=success_rate,
last_event_at=last_event_at,
events=events,
)
)
response = ProviderEndpointHealthMonitorResponse(
provider_id=provider.id,
provider_name=provider.display_name or provider.name,
generated_at=now,
endpoints=endpoint_monitors,
)
context.add_audit_metadata(
action="provider_health_monitor",
provider_id=self.provider_id,
endpoint_count=len(endpoint_monitors),
lookback_hours=self.lookback_hours,
per_endpoint_limit=self.per_endpoint_limit,
)
return response
class AdminProviderSummaryAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
providers = (
db.query(Provider)
.order_by(Provider.provider_priority.asc(), Provider.created_at.asc())
.all()
)
return [_build_provider_summary(db, provider) for provider in providers]
@dataclass
class AdminUpdateProviderSettingsAdapter(AdminApiAdapter):
provider_id: str
update_data: ProviderUpdateRequest
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("Provider not found", "provider")
update_dict = self.update_data.model_dump(exclude_unset=True)
if "billing_type" in update_dict and update_dict["billing_type"] is not None:
update_dict["billing_type"] = ProviderBillingType(update_dict["billing_type"])
for key, value in update_dict.items():
setattr(provider, key, value)
db.commit()
db.refresh(provider)
admin_name = context.user.username if context.user else "admin"
logger.info(f"Provider {provider.name} updated by {admin_name}: {update_dict}")
return _build_provider_summary(db, provider)

View File

@@ -0,0 +1,14 @@
"""
安全管理 API
提供 IP 黑白名单管理等安全功能
"""
from fastapi import APIRouter
from .ip_management import router as ip_router
router = APIRouter()
router.include_router(ip_router)
__all__ = ["router"]

View File

@@ -0,0 +1,202 @@
"""
IP 安全管理接口
提供 IP 黑白名单管理和速率限制统计
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.adapter import ApiMode
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.services.rate_limit.ip_limiter import IPRateLimiter
router = APIRouter(prefix="/api/admin/security/ip", tags=["IP Security"])
pipeline = ApiRequestPipeline()
# ========== Pydantic 模型 ==========
class AddIPToBlacklistRequest(BaseModel):
"""添加 IP 到黑名单请求"""
ip_address: str = Field(..., description="IP 地址")
reason: str = Field(..., min_length=1, max_length=200, description="加入黑名单的原因")
ttl: Optional[int] = Field(None, gt=0, description="过期时间None 表示永久")
class RemoveIPFromBlacklistRequest(BaseModel):
"""从黑名单移除 IP 请求"""
ip_address: str = Field(..., description="IP 地址")
class AddIPToWhitelistRequest(BaseModel):
"""添加 IP 到白名单请求"""
ip_address: str = Field(..., description="IP 地址或 CIDR 格式(如 192.168.1.0/24")
class RemoveIPFromWhitelistRequest(BaseModel):
"""从白名单移除 IP 请求"""
ip_address: str = Field(..., description="IP 地址")
# ========== API 端点 ==========
@router.post("/blacklist")
async def add_to_blacklist(request: Request, db: Session = Depends(get_db)):
"""Add IP to blacklist"""
adapter = AddToBlacklistAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.delete("/blacklist/{ip_address}")
async def remove_from_blacklist(ip_address: str, request: Request, db: Session = Depends(get_db)):
"""Remove IP from blacklist"""
adapter = RemoveFromBlacklistAdapter(ip_address=ip_address)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.get("/blacklist/stats")
async def get_blacklist_stats(request: Request, db: Session = Depends(get_db)):
"""Get blacklist statistics"""
adapter = GetBlacklistStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.post("/whitelist")
async def add_to_whitelist(request: Request, db: Session = Depends(get_db)):
"""Add IP to whitelist"""
adapter = AddToWhitelistAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.delete("/whitelist/{ip_address}")
async def remove_from_whitelist(ip_address: str, request: Request, db: Session = Depends(get_db)):
"""Remove IP from whitelist"""
adapter = RemoveFromWhitelistAdapter(ip_address=ip_address)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.get("/whitelist")
async def get_whitelist(request: Request, db: Session = Depends(get_db)):
"""Get whitelist"""
adapter = GetWhitelistAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
# ========== 适配器实现 ==========
class AddToBlacklistAdapter(AuthenticatedApiAdapter):
"""添加 IP 到黑名单适配器"""
async def handle(self, context): # type: ignore[override]
payload = context.ensure_json_body()
try:
req = AddIPToBlacklistRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
success = await IPRateLimiter.add_to_blacklist(req.ip_address, req.reason, req.ttl)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="添加 IP 到黑名单失败Redis 不可用)",
)
return {
"success": True,
"message": f"IP {req.ip_address} 已加入黑名单",
"reason": req.reason,
"ttl": req.ttl or "永久",
}
class RemoveFromBlacklistAdapter(AuthenticatedApiAdapter):
"""从黑名单移除 IP 适配器"""
def __init__(self, ip_address: str):
self.ip_address = ip_address
async def handle(self, context): # type: ignore[override]
success = await IPRateLimiter.remove_from_blacklist(self.ip_address)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"IP {self.ip_address} 不在黑名单中"
)
return {"success": True, "message": f"IP {self.ip_address} 已从黑名单移除"}
class GetBlacklistStatsAdapter(AuthenticatedApiAdapter):
"""获取黑名单统计适配器"""
async def handle(self, context): # type: ignore[override]
stats = await IPRateLimiter.get_blacklist_stats()
return stats
class AddToWhitelistAdapter(AuthenticatedApiAdapter):
"""添加 IP 到白名单适配器"""
async def handle(self, context): # type: ignore[override]
payload = context.ensure_json_body()
try:
req = AddIPToWhitelistRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
success = await IPRateLimiter.add_to_whitelist(req.ip_address)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"添加 IP 到白名单失败(无效的 IP 格式或 Redis 不可用)",
)
return {"success": True, "message": f"IP {req.ip_address} 已加入白名单"}
class RemoveFromWhitelistAdapter(AuthenticatedApiAdapter):
"""从白名单移除 IP 适配器"""
def __init__(self, ip_address: str):
self.ip_address = ip_address
async def handle(self, context): # type: ignore[override]
success = await IPRateLimiter.remove_from_whitelist(self.ip_address)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"IP {self.ip_address} 不在白名单中"
)
return {"success": True, "message": f"IP {self.ip_address} 已从白名单移除"}
class GetWhitelistAdapter(AuthenticatedApiAdapter):
"""获取白名单适配器"""
async def handle(self, context): # type: ignore[override]
whitelist = await IPRateLimiter.get_whitelist()
return {"whitelist": list(whitelist), "total": len(whitelist)}

312
src/api/admin/system.py Normal file
View File

@@ -0,0 +1,312 @@
"""系统设置API端点。"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
from src.database import get_db
from src.models.api import SystemSettingsRequest, SystemSettingsResponse
from src.models.database import ApiKey, Provider, Usage, User
from src.services.system.config import SystemConfigService
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
pipeline = ApiRequestPipeline()
@router.get("/settings")
async def get_system_settings(request: Request, db: Session = Depends(get_db)):
"""获取系统设置(管理员)"""
adapter = AdminGetSystemSettingsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/settings")
async def update_system_settings(http_request: Request, db: Session = Depends(get_db)):
"""更新系统设置(管理员)"""
adapter = AdminUpdateSystemSettingsAdapter()
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
@router.get("/configs")
async def get_all_system_configs(request: Request, db: Session = Depends(get_db)):
"""获取所有系统配置(管理员)"""
adapter = AdminGetAllConfigsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/configs/{key}")
async def get_system_config(key: str, request: Request, db: Session = Depends(get_db)):
"""获取特定系统配置(管理员)"""
adapter = AdminGetSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/configs/{key}")
async def set_system_config(
key: str,
request: Request,
db: Session = Depends(get_db),
):
"""设置系统配置(管理员)"""
adapter = AdminSetSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/configs/{key}")
async def delete_system_config(key: str, request: Request, db: Session = Depends(get_db)):
"""删除系统配置(管理员)"""
adapter = AdminDeleteSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats")
async def get_system_stats(request: Request, db: Session = Depends(get_db)):
adapter = AdminSystemStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/cleanup")
async def trigger_cleanup(request: Request, db: Session = Depends(get_db)):
"""Manually trigger usage record cleanup task"""
adapter = AdminTriggerCleanupAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/api-formats")
async def get_api_formats(request: Request, db: Session = Depends(get_db)):
"""获取所有可用的API格式列表"""
adapter = AdminGetApiFormatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- 系统设置适配器 --------
class AdminGetSystemSettingsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
default_provider = SystemConfigService.get_default_provider(db)
default_model = SystemConfigService.get_config(db, "default_model")
enable_usage_tracking = (
SystemConfigService.get_config(db, "enable_usage_tracking", "true") == "true"
)
return SystemSettingsResponse(
default_provider=default_provider,
default_model=default_model,
enable_usage_tracking=enable_usage_tracking,
)
class AdminUpdateSystemSettingsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
settings_request = SystemSettingsRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
if settings_request.default_provider is not None:
provider = (
db.query(Provider)
.filter(
Provider.name == settings_request.default_provider,
Provider.is_active.is_(True),
)
.first()
)
if not provider and settings_request.default_provider != "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"提供商 '{settings_request.default_provider}' 不存在或未启用",
)
if settings_request.default_provider:
SystemConfigService.set_default_provider(db, settings_request.default_provider)
else:
SystemConfigService.delete_config(db, "default_provider")
if settings_request.default_model is not None:
if settings_request.default_model:
SystemConfigService.set_config(db, "default_model", settings_request.default_model)
else:
SystemConfigService.delete_config(db, "default_model")
if settings_request.enable_usage_tracking is not None:
SystemConfigService.set_config(
db,
"enable_usage_tracking",
str(settings_request.enable_usage_tracking).lower(),
)
return {"message": "系统设置更新成功"}
class AdminGetAllConfigsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
return SystemConfigService.get_all_configs(context.db)
@dataclass
class AdminGetSystemConfigAdapter(AdminApiAdapter):
key: str
async def handle(self, context): # type: ignore[override]
value = SystemConfigService.get_config(context.db, self.key)
if value is None:
raise NotFoundException(f"配置项 '{self.key}' 不存在")
return {"key": self.key, "value": value}
@dataclass
class AdminSetSystemConfigAdapter(AdminApiAdapter):
key: str
async def handle(self, context): # type: ignore[override]
payload = context.ensure_json_body()
config = SystemConfigService.set_config(
context.db,
self.key,
payload.get("value"),
payload.get("description"),
)
return {
"key": config.key,
"value": config.value,
"description": config.description,
"updated_at": config.updated_at.isoformat(),
}
@dataclass
class AdminDeleteSystemConfigAdapter(AdminApiAdapter):
key: str
async def handle(self, context): # type: ignore[override]
deleted = SystemConfigService.delete_config(context.db, self.key)
if not deleted:
raise NotFoundException(f"配置项 '{self.key}' 不存在")
return {"message": f"配置项 '{self.key}' 已删除"}
class AdminSystemStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
total_users = db.query(User).count()
active_users = db.query(User).filter(User.is_active.is_(True)).count()
total_providers = db.query(Provider).count()
active_providers = db.query(Provider).filter(Provider.is_active.is_(True)).count()
total_api_keys = db.query(ApiKey).count()
total_requests = db.query(Usage).count()
return {
"users": {"total": total_users, "active": active_users},
"providers": {"total": total_providers, "active": active_providers},
"api_keys": total_api_keys,
"requests": total_requests,
}
class AdminTriggerCleanupAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""手动触发清理任务"""
from datetime import datetime, timedelta, timezone
from sqlalchemy import func
from src.services.system.cleanup_scheduler import get_cleanup_scheduler
db = context.db
# 获取清理前的统计信息
total_before = db.query(Usage).count()
with_body_before = (
db.query(Usage)
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
.count()
)
with_headers_before = (
db.query(Usage)
.filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
.count()
)
# 触发清理
cleanup_scheduler = get_cleanup_scheduler()
await cleanup_scheduler._perform_cleanup()
# 获取清理后的统计信息
total_after = db.query(Usage).count()
with_body_after = (
db.query(Usage)
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
.count()
)
with_headers_after = (
db.query(Usage)
.filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
.count()
)
return {
"message": "清理任务执行完成",
"stats": {
"total_records": {
"before": total_before,
"after": total_after,
"deleted": total_before - total_after,
},
"body_fields": {
"before": with_body_before,
"after": with_body_after,
"cleaned": with_body_before - with_body_after,
},
"header_fields": {
"before": with_headers_before,
"after": with_headers_after,
"cleaned": with_headers_before - with_headers_after,
},
},
"timestamp": datetime.now(timezone.utc).isoformat(),
}
class AdminGetApiFormatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""获取所有可用的API格式"""
from src.core.api_format_metadata import API_FORMAT_DEFINITIONS
from src.core.enums import APIFormat
_ = context # 参数保留以符合接口规范
formats = []
for api_format in APIFormat:
definition = API_FORMAT_DEFINITIONS.get(api_format)
formats.append(
{
"value": api_format.value,
"label": api_format.value,
"default_path": definition.default_path if definition else "/",
"aliases": list(definition.aliases) if definition else [],
}
)
return {"formats": formats}

View File

@@ -0,0 +1,5 @@
"""Usage admin routes export."""
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,818 @@
"""管理员使用情况统计路由。"""
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.database import get_db
from src.models.database import (
ApiKey,
Provider,
ProviderAPIKey,
ProviderEndpoint,
RequestCandidate,
Usage,
User,
)
from src.services.usage.service import UsageService
router = APIRouter(prefix="/api/admin/usage", tags=["Admin - Usage"])
pipeline = ApiRequestPipeline()
# ==================== RESTful Routes ====================
@router.get("/aggregation/stats")
async def get_usage_aggregation(
request: Request,
group_by: str = Query(..., description="Aggregation dimension: model, user, provider, or api_format"),
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
):
"""
Get usage aggregation by specified dimension.
- group_by=model: Aggregate by model
- group_by=user: Aggregate by user
- group_by=provider: Aggregate by provider
- group_by=api_format: Aggregate by API format
"""
if group_by == "model":
adapter = AdminUsageByModelAdapter(start_date=start_date, end_date=end_date, limit=limit)
elif group_by == "user":
adapter = AdminUsageByUserAdapter(start_date=start_date, end_date=end_date, limit=limit)
elif group_by == "provider":
adapter = AdminUsageByProviderAdapter(start_date=start_date, end_date=end_date, limit=limit)
elif group_by == "api_format":
adapter = AdminUsageByApiFormatAdapter(start_date=start_date, end_date=end_date, limit=limit)
else:
raise HTTPException(
status_code=400,
detail=f"Invalid group_by value: {group_by}. Must be one of: model, user, provider, api_format"
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats")
async def get_usage_stats(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
db: Session = Depends(get_db),
):
adapter = AdminUsageStatsAdapter(start_date=start_date, end_date=end_date)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/records")
async def get_usage_records(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
user_id: Optional[str] = None,
username: Optional[str] = None,
model: Optional[str] = None,
provider: Optional[str] = None,
status: Optional[str] = None, # stream, standard, error
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
):
adapter = AdminUsageRecordsAdapter(
start_date=start_date,
end_date=end_date,
user_id=user_id,
username=username,
model=model,
provider=provider,
status=status,
limit=limit,
offset=offset,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/active")
async def get_active_requests(
request: Request,
ids: Optional[str] = Query(None, description="逗号分隔的请求 ID 列表,用于查询特定请求的状态"),
db: Session = Depends(get_db),
):
"""
获取活跃请求的状态(轻量级接口,用于前端轮询)
- 如果提供 ids 参数,只返回这些 ID 对应请求的最新状态
- 如果不提供 ids返回所有 pending/streaming 状态的请求
"""
adapter = AdminActiveRequestsAdapter(ids=ids)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# NOTE: This route must be defined AFTER all other routes to avoid matching
# routes like /stats, /records, /active, etc.
@router.get("/{usage_id}")
async def get_usage_detail(
usage_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""
Get detailed information of a specific usage record.
Includes request/response headers and body.
"""
adapter = AdminUsageDetailAdapter(usage_id=usage_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminUsageStatsAdapter(AdminApiAdapter):
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime]):
self.start_date = start_date
self.end_date = end_date
async def handle(self, context): # type: ignore[override]
db = context.db
query = db.query(Usage)
if self.start_date:
query = query.filter(Usage.created_at >= self.start_date)
if self.end_date:
query = query.filter(Usage.created_at <= self.end_date)
total_stats = query.with_entities(
func.count(Usage.id).label("total_requests"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
func.sum(Usage.actual_total_cost_usd).label("total_actual_cost"),
func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
).first()
# 缓存统计
cache_stats = query.with_entities(
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
).first()
# 错误统计
error_count = query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
activity_heatmap = UsageService.get_daily_activity(
db=db,
window_days=365,
include_actual_cost=True,
)
context.add_audit_metadata(
action="usage_stats",
start_date=self.start_date.isoformat() if self.start_date else None,
end_date=self.end_date.isoformat() if self.end_date else None,
)
total_requests = total_stats.total_requests if total_stats else 0
avg_response_time_ms = float(total_stats.avg_response_time_ms or 0) if total_stats else 0
avg_response_time = avg_response_time_ms / 1000.0
return {
"total_requests": total_requests,
"total_tokens": int(total_stats.total_tokens or 0),
"total_cost": float(total_stats.total_cost or 0),
"total_actual_cost": float(total_stats.total_actual_cost or 0),
"avg_response_time": round(avg_response_time, 2),
"error_count": error_count,
"error_rate": (
round((error_count / total_requests) * 100, 2) if total_requests > 0 else 0
),
"cache_stats": {
"cache_creation_tokens": (
int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
),
"cache_read_tokens": int(cache_stats.cache_read_tokens or 0) if cache_stats else 0,
"cache_creation_cost": (
float(cache_stats.cache_creation_cost or 0) if cache_stats else 0
),
"cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
},
"activity_heatmap": activity_heatmap,
}
class AdminUsageByModelAdapter(AdminApiAdapter):
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
self.start_date = start_date
self.end_date = end_date
self.limit = limit
async def handle(self, context): # type: ignore[override]
db = context.db
query = db.query(
Usage.model,
func.count(Usage.id).label("request_count"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
)
# 过滤掉 pending/streaming 状态的请求(尚未完成的请求不应计入统计)
query = query.filter(Usage.status.notin_(["pending", "streaming"]))
# 过滤掉 unknown/pending provider请求未到达任何提供商
query = query.filter(Usage.provider.notin_(["unknown", "pending"]))
if self.start_date:
query = query.filter(Usage.created_at >= self.start_date)
if self.end_date:
query = query.filter(Usage.created_at <= self.end_date)
query = query.group_by(Usage.model).order_by(func.count(Usage.id).desc()).limit(self.limit)
stats = query.all()
context.add_audit_metadata(
action="usage_by_model",
start_date=self.start_date.isoformat() if self.start_date else None,
end_date=self.end_date.isoformat() if self.end_date else None,
limit=self.limit,
result_count=len(stats),
)
return [
{
"model": model,
"request_count": count,
"total_tokens": int(tokens or 0),
"total_cost": float(cost or 0),
"actual_cost": float(actual_cost or 0),
}
for model, count, tokens, cost, actual_cost in stats
]
class AdminUsageByUserAdapter(AdminApiAdapter):
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
self.start_date = start_date
self.end_date = end_date
self.limit = limit
async def handle(self, context): # type: ignore[override]
db = context.db
query = (
db.query(
User.id,
User.email,
User.username,
func.count(Usage.id).label("request_count"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
)
.join(Usage, Usage.user_id == User.id)
.group_by(User.id, User.email, User.username)
)
if self.start_date:
query = query.filter(Usage.created_at >= self.start_date)
if self.end_date:
query = query.filter(Usage.created_at <= self.end_date)
query = query.order_by(func.count(Usage.id).desc()).limit(self.limit)
stats = query.all()
context.add_audit_metadata(
action="usage_by_user",
start_date=self.start_date.isoformat() if self.start_date else None,
end_date=self.end_date.isoformat() if self.end_date else None,
limit=self.limit,
result_count=len(stats),
)
return [
{
"user_id": user_id,
"email": email,
"username": username,
"request_count": count,
"total_tokens": int(tokens or 0),
"total_cost": float(cost or 0),
}
for user_id, email, username, count, tokens, cost in stats
]
class AdminUsageByProviderAdapter(AdminApiAdapter):
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
self.start_date = start_date
self.end_date = end_date
self.limit = limit
async def handle(self, context): # type: ignore[override]
db = context.db
# 从 request_candidates 表统计每个 Provider 的尝试次数和成功率
# 这样可以正确统计 Fallback 场景(一个请求可能尝试多个 Provider
from sqlalchemy import case, Integer
attempt_query = db.query(
RequestCandidate.provider_id,
func.count(RequestCandidate.id).label("attempt_count"),
func.sum(
case((RequestCandidate.status == "success", 1), else_=0)
).label("success_count"),
func.sum(
case((RequestCandidate.status == "failed", 1), else_=0)
).label("failed_count"),
func.avg(RequestCandidate.latency_ms).label("avg_latency_ms"),
).filter(
RequestCandidate.provider_id.isnot(None),
# 只统计实际执行的尝试(排除 available/skipped 状态)
RequestCandidate.status.in_(["success", "failed"]),
)
if self.start_date:
attempt_query = attempt_query.filter(RequestCandidate.created_at >= self.start_date)
if self.end_date:
attempt_query = attempt_query.filter(RequestCandidate.created_at <= self.end_date)
attempt_stats = (
attempt_query.group_by(RequestCandidate.provider_id)
.order_by(func.count(RequestCandidate.id).desc())
.limit(self.limit)
.all()
)
# 从 Usage 表获取 token 和费用统计(基于成功的请求)
usage_query = db.query(
Usage.provider_id,
func.count(Usage.id).label("request_count"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
).filter(
Usage.provider_id.isnot(None),
# 过滤掉 pending/streaming 状态的请求
Usage.status.notin_(["pending", "streaming"]),
)
if self.start_date:
usage_query = usage_query.filter(Usage.created_at >= self.start_date)
if self.end_date:
usage_query = usage_query.filter(Usage.created_at <= self.end_date)
usage_stats = usage_query.group_by(Usage.provider_id).all()
usage_map = {str(u.provider_id): u for u in usage_stats}
# 获取所有相关的 Provider ID
provider_ids = set()
for stat in attempt_stats:
if stat.provider_id:
provider_ids.add(stat.provider_id)
for stat in usage_stats:
if stat.provider_id:
provider_ids.add(stat.provider_id)
# 获取 Provider 名称映射
provider_map = {}
if provider_ids:
providers_data = (
db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
)
provider_map = {str(p.id): p.name for p in providers_data}
context.add_audit_metadata(
action="usage_by_provider",
start_date=self.start_date.isoformat() if self.start_date else None,
end_date=self.end_date.isoformat() if self.end_date else None,
limit=self.limit,
result_count=len(attempt_stats),
)
result = []
for stat in attempt_stats:
provider_id_str = str(stat.provider_id) if stat.provider_id else None
attempt_count = stat.attempt_count or 0
success_count = int(stat.success_count or 0)
failed_count = int(stat.failed_count or 0)
success_rate = (success_count / attempt_count * 100) if attempt_count > 0 else 0
# 从 usage_map 获取 token 和费用信息
usage_stat = usage_map.get(provider_id_str)
result.append({
"provider_id": provider_id_str,
"provider": provider_map.get(provider_id_str, "Unknown"),
"request_count": attempt_count, # 尝试次数
"total_tokens": int(usage_stat.total_tokens or 0) if usage_stat else 0,
"total_cost": float(usage_stat.total_cost or 0) if usage_stat else 0,
"actual_cost": float(usage_stat.actual_cost or 0) if usage_stat else 0,
"avg_response_time_ms": float(stat.avg_latency_ms or 0),
"success_rate": round(success_rate, 2),
"error_count": failed_count,
})
return result
class AdminUsageByApiFormatAdapter(AdminApiAdapter):
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
self.start_date = start_date
self.end_date = end_date
self.limit = limit
async def handle(self, context): # type: ignore[override]
db = context.db
query = db.query(
Usage.api_format,
func.count(Usage.id).label("request_count"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
)
# 过滤掉 pending/streaming 状态的请求
query = query.filter(Usage.status.notin_(["pending", "streaming"]))
# 过滤掉 unknown/pending provider
query = query.filter(Usage.provider.notin_(["unknown", "pending"]))
# 只统计有 api_format 的记录
query = query.filter(Usage.api_format.isnot(None))
if self.start_date:
query = query.filter(Usage.created_at >= self.start_date)
if self.end_date:
query = query.filter(Usage.created_at <= self.end_date)
query = (
query.group_by(Usage.api_format)
.order_by(func.count(Usage.id).desc())
.limit(self.limit)
)
stats = query.all()
context.add_audit_metadata(
action="usage_by_api_format",
start_date=self.start_date.isoformat() if self.start_date else None,
end_date=self.end_date.isoformat() if self.end_date else None,
limit=self.limit,
result_count=len(stats),
)
return [
{
"api_format": api_format or "unknown",
"request_count": count,
"total_tokens": int(tokens or 0),
"total_cost": float(cost or 0),
"actual_cost": float(actual_cost or 0),
"avg_response_time_ms": float(avg_response_time or 0),
}
for api_format, count, tokens, cost, actual_cost, avg_response_time in stats
]
class AdminUsageRecordsAdapter(AdminApiAdapter):
def __init__(
self,
start_date: Optional[datetime],
end_date: Optional[datetime],
user_id: Optional[str],
username: Optional[str],
model: Optional[str],
provider: Optional[str],
status: Optional[str],
limit: int,
offset: int,
):
self.start_date = start_date
self.end_date = end_date
self.user_id = user_id
self.username = username
self.model = model
self.provider = provider
self.status = status
self.limit = limit
self.offset = offset
async def handle(self, context): # type: ignore[override]
db = context.db
query = (
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
.outerjoin(User, Usage.user_id == User.id)
.outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
.outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
)
if self.user_id:
query = query.filter(Usage.user_id == self.user_id)
if self.username:
# 支持用户名模糊搜索
query = query.filter(User.username.ilike(f"%{self.username}%"))
if self.model:
# 支持模型名模糊搜索
query = query.filter(Usage.model.ilike(f"%{self.model}%"))
if self.provider:
# 支持提供商名称搜索(通过 Provider 表)
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
if self.status:
# 状态筛选
# 旧的筛选值(基于 is_stream 和 status_codestream, standard, error
# 新的筛选值(基于 status 字段pending, streaming, completed, failed, active
if self.status == "stream":
query = query.filter(Usage.is_stream == True) # noqa: E712
elif self.status == "standard":
query = query.filter(Usage.is_stream == False) # noqa: E712
elif self.status == "error":
query = query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
)
elif self.status in ("pending", "streaming", "completed", "failed"):
# 新的状态筛选:直接按 status 字段过滤
query = query.filter(Usage.status == self.status)
elif self.status == "active":
# 活跃请求pending 或 streaming 状态
query = query.filter(Usage.status.in_(["pending", "streaming"]))
if self.start_date:
query = query.filter(Usage.created_at >= self.start_date)
if self.end_date:
query = query.filter(Usage.created_at <= self.end_date)
total = query.count()
records = (
query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
)
request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
fallback_map = {}
if request_ids:
# 只统计实际执行的候选success 或 failed不包括 skipped/pending/available
executed_counts = (
db.query(RequestCandidate.request_id, func.count(RequestCandidate.id))
.filter(
RequestCandidate.request_id.in_(request_ids),
RequestCandidate.status.in_(["success", "failed"]),
)
.group_by(RequestCandidate.request_id)
.all()
)
# 如果实际执行的候选数 > 1说明发生了 Provider 切换
fallback_map = {req_id: count > 1 for req_id, count in executed_counts}
context.add_audit_metadata(
action="usage_records",
start_date=self.start_date.isoformat() if self.start_date else None,
end_date=self.end_date.isoformat() if self.end_date else None,
user_id=self.user_id,
username=self.username,
model=self.model,
provider=self.provider,
status=self.status,
limit=self.limit,
offset=self.offset,
total=total,
)
# 构建 provider_id -> Provider 名称的映射,避免 N+1 查询
provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
provider_map = {}
if provider_ids:
providers_data = (
db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
)
provider_map = {str(p.id): p.name for p in providers_data}
data = []
for usage, user, endpoint, api_key in records:
actual_cost = (
float(usage.actual_total_cost_usd)
if usage.actual_total_cost_usd is not None
else 0.0
)
rate_multiplier = (
float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0
)
# 提供商名称优先级:关联的 Provider 表 > usage.provider 字段
provider_name = usage.provider
if usage.provider_id and str(usage.provider_id) in provider_map:
provider_name = provider_map[str(usage.provider_id)]
data.append(
{
"id": usage.id,
"user_id": user.id if user else None,
"user_email": user.email if user else "已删除用户",
"username": user.username if user else "已删除用户",
"provider": provider_name,
"model": usage.model,
"target_model": usage.target_model, # 映射后的目标模型名
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"total_tokens": usage.total_tokens,
"cost": float(usage.total_cost_usd),
"actual_cost": actual_cost,
"rate_multiplier": rate_multiplier,
"response_time_ms": usage.response_time_ms,
"created_at": usage.created_at.isoformat(),
"is_stream": usage.is_stream,
"input_price_per_1m": usage.input_price_per_1m,
"output_price_per_1m": usage.output_price_per_1m,
"cache_creation_price_per_1m": usage.cache_creation_price_per_1m,
"cache_read_price_per_1m": usage.cache_read_price_per_1m,
"status_code": usage.status_code,
"error_message": usage.error_message,
"status": usage.status, # 请求状态: pending, streaming, completed, failed
"has_fallback": fallback_map.get(usage.request_id, False),
"api_format": usage.api_format
or (endpoint.api_format if endpoint and endpoint.api_format else None),
"api_key_name": api_key.name if api_key else None,
"request_metadata": usage.request_metadata, # Provider 响应元数据
}
)
return {
"records": data,
"total": total,
"limit": self.limit,
"offset": self.offset,
}
class AdminActiveRequestsAdapter(AdminApiAdapter):
"""轻量级活跃请求状态查询适配器"""
def __init__(self, ids: Optional[str]):
self.ids = ids
async def handle(self, context): # type: ignore[override]
db = context.db
if self.ids:
# 查询指定 ID 的请求状态
id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
if not id_list:
return {"requests": []}
records = (
db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
.filter(Usage.id.in_(id_list))
.all()
)
else:
# 查询所有活跃请求pending 或 streaming
records = (
db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
.filter(Usage.status.in_(["pending", "streaming"]))
.order_by(Usage.created_at.desc())
.limit(50)
.all()
)
return {
"requests": [
{
"id": r.id,
"status": r.status,
"input_tokens": r.input_tokens,
"output_tokens": r.output_tokens,
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
"response_time_ms": r.response_time_ms,
}
for r in records
]
}
@dataclass
class AdminUsageDetailAdapter(AdminApiAdapter):
"""Get detailed usage record with request/response body"""
usage_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
usage_record = db.query(Usage).filter(Usage.id == self.usage_id).first()
if not usage_record:
raise HTTPException(status_code=404, detail="Usage record not found")
user = db.query(User).filter(User.id == usage_record.user_id).first()
api_key = db.query(ApiKey).filter(ApiKey.id == usage_record.api_key_id).first()
# 获取阶梯计费信息
tiered_pricing_info = await self._get_tiered_pricing_info(db, usage_record)
context.add_audit_metadata(
action="usage_detail",
usage_id=self.usage_id,
)
return {
"id": usage_record.id,
"request_id": usage_record.request_id,
"user": {
"id": user.id if user else None,
"username": user.username if user else "Unknown",
"email": user.email if user else None,
},
"api_key": {
"id": api_key.id if api_key else None,
"name": api_key.name if api_key else None,
"display": api_key.get_display_key() if api_key else None,
},
"provider": usage_record.provider,
"api_format": usage_record.api_format,
"model": usage_record.model,
"target_model": usage_record.target_model,
"tokens": {
"input": usage_record.input_tokens,
"output": usage_record.output_tokens,
"total": usage_record.total_tokens,
},
"cost": {
"input": usage_record.input_cost_usd,
"output": usage_record.output_cost_usd,
"total": usage_record.total_cost_usd,
},
"cache_creation_input_tokens": usage_record.cache_creation_input_tokens,
"cache_read_input_tokens": usage_record.cache_read_input_tokens,
"cache_creation_cost": getattr(usage_record, "cache_creation_cost_usd", 0.0),
"cache_read_cost": getattr(usage_record, "cache_read_cost_usd", 0.0),
"request_cost": getattr(usage_record, "request_cost_usd", 0.0),
"input_price_per_1m": usage_record.input_price_per_1m,
"output_price_per_1m": usage_record.output_price_per_1m,
"cache_creation_price_per_1m": usage_record.cache_creation_price_per_1m,
"cache_read_price_per_1m": usage_record.cache_read_price_per_1m,
"price_per_request": usage_record.price_per_request,
"request_type": usage_record.request_type,
"is_stream": usage_record.is_stream,
"status_code": usage_record.status_code,
"error_message": usage_record.error_message,
"response_time_ms": usage_record.response_time_ms,
"created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
"request_headers": usage_record.request_headers,
"request_body": usage_record.get_request_body(),
"provider_request_headers": usage_record.provider_request_headers,
"response_headers": usage_record.response_headers,
"response_body": usage_record.get_response_body(),
"metadata": usage_record.request_metadata,
"tiered_pricing": tiered_pricing_info,
}
async def _get_tiered_pricing_info(self, db, usage_record) -> dict | None:
"""获取阶梯计费信息"""
from src.services.model.cost import ModelCostService
# 计算总输入上下文(用于阶梯判定):输入 + 缓存创建 + 缓存读取
input_tokens = usage_record.input_tokens or 0
cache_creation_tokens = usage_record.cache_creation_input_tokens or 0
cache_read_tokens = usage_record.cache_read_input_tokens or 0
total_input_context = input_tokens + cache_creation_tokens + cache_read_tokens
# 尝试获取模型的阶梯配置(带来源信息)
cost_service = ModelCostService(db)
pricing_result = await cost_service.get_tiered_pricing_with_source_async(
usage_record.provider, usage_record.model
)
if not pricing_result:
return None
tiered_pricing = pricing_result.get("pricing")
pricing_source = pricing_result.get("source") # 'provider' 或 'global'
if not tiered_pricing or not tiered_pricing.get("tiers"):
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
# 找到命中的阶梯
tier_index = None
matched_tier = None
for i, tier in enumerate(tiers):
up_to = tier.get("up_to")
if up_to is None or total_input_context <= up_to:
tier_index = i
matched_tier = tier
break
# 如果都没匹配,使用最后一个阶梯
if tier_index is None and tiers:
tier_index = len(tiers) - 1
matched_tier = tiers[-1]
return {
"total_input_context": total_input_context,
"tier_index": tier_index,
"tier_count": len(tiers),
"current_tier": matched_tier,
"tiers": tiers,
"source": pricing_source, # 定价来源: 'provider' 或 'global'
}

View File

@@ -0,0 +1,5 @@
"""User admin routes export."""
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,488 @@
"""用户管理 API 端点。"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.models.admin_requests import UpdateUserRequest
from src.models.api import CreateApiKeyRequest, CreateUserRequest
from src.models.database import ApiKey, User, UserRole
from src.services.user.apikey import ApiKeyService
from src.services.user.service import UserService
router = APIRouter(prefix="/api/admin/users", tags=["Admin - Users"])
pipeline = ApiRequestPipeline()
# 管理员端点
@router.post("")
async def create_user_endpoint(request: Request, db: Session = Depends(get_db)):
adapter = AdminCreateUserAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("")
async def list_users(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
role: Optional[str] = None,
is_active: Optional[bool] = None,
db: Session = Depends(get_db),
):
adapter = AdminListUsersAdapter(skip=skip, limit=limit, role=role, is_active=is_active)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{user_id}")
async def get_user(user_id: str, request: Request, db: Session = Depends(get_db)): # UUID
adapter = AdminGetUserAdapter(user_id=user_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{user_id}")
async def update_user(
user_id: str,
request: Request,
db: Session = Depends(get_db),
):
adapter = AdminUpdateUserAdapter(user_id=user_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{user_id}")
async def delete_user(user_id: str, request: Request, db: Session = Depends(get_db)): # UUID
adapter = AdminDeleteUserAdapter(user_id=user_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{user_id}/quota")
async def reset_user_quota(user_id: str, request: Request, db: Session = Depends(get_db)):
"""Reset user quota (set used_usd to 0)"""
adapter = AdminResetUserQuotaAdapter(user_id=user_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{user_id}/api-keys")
async def get_user_api_keys(
user_id: str,
request: Request,
is_active: Optional[bool] = None,
db: Session = Depends(get_db),
):
"""获取用户的所有API Keys不包括独立Keys"""
adapter = AdminGetUserKeysAdapter(user_id=user_id, is_active=is_active)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/{user_id}/api-keys")
async def create_user_api_key(
user_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""为用户创建API Key"""
adapter = AdminCreateUserKeyAdapter(user_id=user_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{user_id}/api-keys/{key_id}")
async def delete_user_api_key(
user_id: str,
key_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""删除用户的API Key"""
adapter = AdminDeleteUserKeyAdapter(user_id=user_id, key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ============== 管理员适配器实现 ==============
class AdminCreateUserAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
request = CreateUserRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
try:
role = (
request.role if hasattr(request.role, "value") else UserRole[request.role.upper()]
)
except (KeyError, AttributeError):
raise InvalidRequestException("角色参数不合法")
try:
user = UserService.create_user(
db=db,
email=request.email,
username=request.username,
password=request.password,
role=role,
quota_usd=request.quota_usd,
)
except ValueError as exc:
raise InvalidRequestException(str(exc))
context.add_audit_metadata(
action="create_user",
target_user_id=user.id,
target_email=user.email,
target_username=user.username,
target_role=user.role.value,
quota_usd=user.quota_usd,
is_active=user.is_active,
)
return {
"id": user.id,
"email": user.email,
"username": user.username,
"role": user.role.value,
"allowed_providers": user.allowed_providers,
"allowed_endpoints": user.allowed_endpoints,
"allowed_models": user.allowed_models,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": getattr(user, "total_usd", 0),
"is_active": user.is_active,
"created_at": user.created_at.isoformat(),
}
class AdminListUsersAdapter(AdminApiAdapter):
def __init__(self, skip: int, limit: int, role: Optional[str], is_active: Optional[bool]):
self.skip = skip
self.limit = limit
self.role = role
self.is_active = is_active
async def handle(self, context): # type: ignore[override]
db = context.db
role_enum = UserRole[self.role.upper()] if self.role else None
users = UserService.list_users(db, self.skip, self.limit, role_enum, self.is_active)
return [
{
"id": u.id,
"email": u.email,
"username": u.username,
"role": u.role.value,
"quota_usd": u.quota_usd,
"used_usd": u.used_usd,
"total_usd": getattr(u, "total_usd", 0),
"is_active": u.is_active,
"created_at": u.created_at.isoformat(),
}
for u in users
]
class AdminGetUserAdapter(AdminApiAdapter):
def __init__(self, user_id: str):
self.user_id = user_id
async def handle(self, context): # type: ignore[override]
db = context.db
user = UserService.get_user(db, self.user_id)
if not user:
raise NotFoundException("用户不存在", "user")
context.add_audit_metadata(
action="get_user_detail",
target_user_id=user.id,
target_role=user.role.value,
include_history=bool(user.last_login_at),
)
return {
"id": user.id,
"email": user.email,
"username": user.username,
"role": user.role.value,
"allowed_providers": user.allowed_providers,
"allowed_endpoints": user.allowed_endpoints,
"allowed_models": user.allowed_models,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": getattr(user, "total_usd", 0),
"is_active": user.is_active,
"created_at": user.created_at.isoformat(),
"updated_at": user.updated_at.isoformat() if user.updated_at else None,
"last_login_at": user.last_login_at.isoformat() if user.last_login_at else None,
}
class AdminUpdateUserAdapter(AdminApiAdapter):
def __init__(self, user_id: str):
self.user_id = user_id
async def handle(self, context): # type: ignore[override]
db = context.db
existing_user = UserService.get_user(db, self.user_id)
if not existing_user:
raise NotFoundException("用户不存在", "user")
payload = context.ensure_json_body()
try:
request = UpdateUserRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
update_data = request.model_dump(exclude_unset=True)
if "role" in update_data and update_data["role"]:
if hasattr(update_data["role"], "value"):
update_data["role"] = update_data["role"]
else:
update_data["role"] = UserRole[update_data["role"].upper()]
user = UserService.update_user(db, self.user_id, **update_data)
if not user:
raise NotFoundException("用户不存在", "user")
changed_fields = list(update_data.keys())
context.add_audit_metadata(
action="update_user",
target_user_id=user.id,
updated_fields=changed_fields,
role_before=existing_user.role.value if existing_user.role else None,
role_after=user.role.value,
quota_usd=user.quota_usd,
is_active=user.is_active,
)
return {
"id": user.id,
"email": user.email,
"username": user.username,
"role": user.role.value,
"allowed_providers": user.allowed_providers,
"allowed_endpoints": user.allowed_endpoints,
"allowed_models": user.allowed_models,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": getattr(user, "total_usd", 0),
"is_active": user.is_active,
"created_at": user.created_at.isoformat(),
"updated_at": user.updated_at.isoformat() if user.updated_at else None,
}
class AdminDeleteUserAdapter(AdminApiAdapter):
def __init__(self, user_id: str):
self.user_id = user_id
async def handle(self, context): # type: ignore[override]
db = context.db
user = UserService.get_user(db, self.user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
if user.role == UserRole.ADMIN:
admin_count = db.query(User).filter(User.role == UserRole.ADMIN).count()
if admin_count <= 1:
raise InvalidRequestException("不能删除最后一个管理员账户")
success = UserService.delete_user(db, self.user_id)
if not success:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
context.add_audit_metadata(
action="delete_user",
target_user_id=user.id,
target_email=user.email,
target_role=user.role.value,
)
return {"message": "用户删除成功"}
class AdminResetUserQuotaAdapter(AdminApiAdapter):
def __init__(self, user_id: str):
self.user_id = user_id
async def handle(self, context): # type: ignore[override]
db = context.db
user = UserService.get_user(db, self.user_id)
if not user:
raise NotFoundException("用户不存在", "user")
user.used_usd = 0.0
user.total_usd = getattr(user, "total_usd", 0)
user.updated_at = datetime.now(timezone.utc)
db.commit()
context.add_audit_metadata(
action="reset_user_quota",
target_user_id=user.id,
quota_usd=user.quota_usd,
used_usd=user.used_usd,
total_usd=user.total_usd,
)
return {
"message": "配额已重置",
"user_id": user.id,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": user.total_usd,
}
class AdminGetUserKeysAdapter(AdminApiAdapter):
"""获取用户的API Keys"""
def __init__(self, user_id: str, is_active: Optional[bool]):
self.user_id = user_id
self.is_active = is_active
async def handle(self, context): # type: ignore[override]
db = context.db
# 验证用户存在
user = db.query(User).filter(User.id == self.user_id).first()
if not user:
raise NotFoundException("用户不存在", "user")
# 获取用户的Keys不包括独立Keys
api_keys = ApiKeyService.list_user_api_keys(
db=db, user_id=self.user_id, is_active=self.is_active
)
context.add_audit_metadata(
action="list_user_api_keys",
target_user_id=self.user_id,
total=len(api_keys),
)
return {
"api_keys": [
{
"id": key.id,
"name": key.name,
"key_display": key.get_display_key(),
"is_active": key.is_active,
"total_requests": key.total_requests,
"total_cost_usd": float(key.total_cost_usd or 0),
"rate_limit": key.rate_limit,
"expires_at": key.expires_at.isoformat() if key.expires_at else None,
"last_used_at": key.last_used_at.isoformat() if key.last_used_at else None,
"created_at": key.created_at.isoformat(),
}
for key in api_keys
],
"total": len(api_keys),
"user_email": user.email,
"username": user.username,
}
class AdminCreateUserKeyAdapter(AdminApiAdapter):
"""为用户创建API Key"""
def __init__(self, user_id: str):
self.user_id = user_id
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
key_data = CreateApiKeyRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
# 验证用户存在
user = db.query(User).filter(User.id == self.user_id).first()
if not user:
raise NotFoundException("用户不存在", "user")
# 为用户创建Key不是独立Key
api_key, plain_key = ApiKeyService.create_api_key(
db=db,
user_id=self.user_id,
name=key_data.name,
allowed_providers=key_data.allowed_providers,
allowed_models=key_data.allowed_models,
rate_limit=key_data.rate_limit or 100,
expire_days=key_data.expire_days,
initial_balance_usd=None, # 普通Key不设置余额限制
is_standalone=False, # 不是独立Key
)
logger.info(f"管理员为用户创建API Key: 用户 {user.email}, Key ID {api_key.id}")
context.add_audit_metadata(
action="create_user_api_key",
target_user_id=self.user_id,
key_id=api_key.id,
)
return {
"id": api_key.id,
"key": plain_key, # 只在创建时返回
"name": api_key.name,
"key_display": api_key.get_display_key(),
"rate_limit": api_key.rate_limit,
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
"created_at": api_key.created_at.isoformat(),
"message": "API Key创建成功请妥善保存完整密钥",
}
class AdminDeleteUserKeyAdapter(AdminApiAdapter):
"""删除用户的API Key"""
def __init__(self, user_id: str, key_id: str):
self.user_id = user_id
self.key_id = key_id
async def handle(self, context): # type: ignore[override]
db = context.db
# 验证Key存在且属于该用户
api_key = (
db.query(ApiKey)
.filter(
ApiKey.id == self.key_id,
ApiKey.user_id == self.user_id,
ApiKey.is_standalone == False, # 只能删除普通Key
)
.first()
)
if not api_key:
raise NotFoundException("API Key不存在或不属于该用户", "api_key")
db.delete(api_key)
db.commit()
logger.info(f"管理员删除用户API Key: 用户ID {self.user_id}, Key ID {self.key_id}")
context.add_audit_metadata(
action="delete_user_api_key",
target_user_id=self.user_id,
key_id=self.key_id,
)
return {"message": "API Key已删除"}

View File

@@ -0,0 +1,10 @@
'"""Announcement system routers."""'
from fastapi import APIRouter
from .routes import router as announcement_router
router = APIRouter()
router.include_router(announcement_router)
__all__ = ["router"]

View File

@@ -0,0 +1,297 @@
"""公告系统 API 端点。"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.models.api import CreateAnnouncementRequest, UpdateAnnouncementRequest
from src.models.database import User
from src.services.auth.service import AuthService
from src.services.system.announcement import AnnouncementService
router = APIRouter(prefix="/api/announcements", tags=["Announcements"])
pipeline = ApiRequestPipeline()
# ============== 公共端点(所有用户可访问) ==============
@router.get("")
async def list_announcements(
request: Request,
active_only: bool = Query(True, description="只返回有效公告"),
limit: int = Query(50, description="返回数量限制"),
offset: int = Query(0, description="偏移量"),
db: Session = Depends(get_db),
):
"""获取公告列表(包含已读状态)"""
adapter = ListAnnouncementsAdapter(active_only=active_only, limit=limit, offset=offset)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/active")
async def get_active_announcements(
request: Request,
db: Session = Depends(get_db),
):
"""获取当前有效的公告(首页展示)"""
adapter = GetActiveAnnouncementsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{announcement_id}")
async def get_announcement(
announcement_id: str, # UUID
request: Request,
db: Session = Depends(get_db),
):
"""获取单个公告详情"""
adapter = GetAnnouncementAdapter(announcement_id=announcement_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{announcement_id}/read-status")
async def mark_announcement_as_read(
announcement_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""Mark announcement as read"""
adapter = MarkAnnouncementReadAdapter(announcement_id=announcement_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ============== 管理员端点 ==============
@router.post("")
async def create_announcement(
request: Request,
db: Session = Depends(get_db),
):
"""创建公告(管理员权限)"""
adapter = CreateAnnouncementAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{announcement_id}")
async def update_announcement(
announcement_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""更新公告(管理员权限)"""
adapter = UpdateAnnouncementAdapter(announcement_id=announcement_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{announcement_id}")
async def delete_announcement(
announcement_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""删除公告(管理员权限)"""
adapter = DeleteAnnouncementAdapter(announcement_id=announcement_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ============== 用户公告端点 ==============
@router.get("/users/me/unread-count")
async def get_my_unread_announcement_count(
request: Request,
db: Session = Depends(get_db),
):
"""获取我的未读公告数量"""
adapter = UnreadAnnouncementCountAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ============== Pipeline 适配器 ==============
class AnnouncementOptionalAuthAdapter(ApiAdapter):
"""允许匿名访问但可选解析Bearer以获取用户上下文。"""
mode = ApiMode.PUBLIC
async def authorize(self, context): # type: ignore[override]
context.extra["optional_user"] = await self._resolve_optional_user(context)
return None
async def _resolve_optional_user(self, context) -> Optional[User]:
if context.user:
return context.user
authorization = context.request.headers.get("authorization")
if not authorization or not authorization.lower().startswith("bearer "):
return None
token = authorization.replace("Bearer ", "").strip()
try:
payload = await AuthService.verify_token(token)
user_id = payload.get("user_id")
if not user_id:
return None
user = (
context.db.query(User).filter(User.id == user_id, User.is_active.is_(True)).first()
)
return user
except Exception:
return None
def get_optional_user(self, context) -> Optional[User]:
return context.extra.get("optional_user")
@dataclass
class ListAnnouncementsAdapter(AnnouncementOptionalAuthAdapter):
active_only: bool
limit: int
offset: int
async def handle(self, context): # type: ignore[override]
optional_user = self.get_optional_user(context)
return AnnouncementService.get_announcements(
db=context.db,
user_id=optional_user.id if optional_user else None,
active_only=self.active_only,
include_read_status=True if optional_user else False,
limit=self.limit,
offset=self.offset,
)
class GetActiveAnnouncementsAdapter(AnnouncementOptionalAuthAdapter):
async def handle(self, context): # type: ignore[override]
optional_user = self.get_optional_user(context)
return AnnouncementService.get_active_announcements(
db=context.db,
user_id=optional_user.id if optional_user else None,
)
@dataclass
class GetAnnouncementAdapter(AnnouncementOptionalAuthAdapter):
announcement_id: str
async def handle(self, context): # type: ignore[override]
announcement = AnnouncementService.get_announcement(context.db, self.announcement_id)
return {
"id": announcement.id,
"title": announcement.title,
"content": announcement.content,
"type": announcement.type,
"priority": announcement.priority,
"is_pinned": announcement.is_pinned,
"author": {"id": announcement.author.id, "username": announcement.author.username},
"start_time": announcement.start_time,
"end_time": announcement.end_time,
"created_at": announcement.created_at,
"updated_at": announcement.updated_at,
}
class AnnouncementUserAdapter(AuthenticatedApiAdapter):
"""需要登录但不要求管理员的公告适配器基类。"""
pass
class MarkAnnouncementReadAdapter(AnnouncementUserAdapter):
def __init__(self, announcement_id: str):
self.announcement_id = announcement_id
async def handle(self, context): # type: ignore[override]
AnnouncementService.mark_as_read(context.db, self.announcement_id, context.user.id)
return {"message": "公告已标记为已读"}
class UnreadAnnouncementCountAdapter(AnnouncementUserAdapter):
async def handle(self, context): # type: ignore[override]
result = AnnouncementService.get_announcements(
db=context.db,
user_id=context.user.id,
active_only=True,
include_read_status=True,
limit=1,
offset=0,
)
return {"unread_count": result.get("unread_count", 0)}
class CreateAnnouncementAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
payload = context.ensure_json_body()
try:
req = CreateAnnouncementRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
announcement = AnnouncementService.create_announcement(
db=context.db,
author_id=context.user.id,
title=req.title,
content=req.content,
type=req.type,
priority=req.priority,
is_pinned=req.is_pinned,
start_time=req.start_time,
end_time=req.end_time,
)
return {"id": announcement.id, "title": announcement.title, "message": "公告创建成功"}
@dataclass
class UpdateAnnouncementAdapter(AdminApiAdapter):
announcement_id: str
async def handle(self, context): # type: ignore[override]
payload = context.ensure_json_body()
try:
req = UpdateAnnouncementRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
AnnouncementService.update_announcement(
db=context.db,
announcement_id=self.announcement_id,
user_id=context.user.id,
title=req.title,
content=req.content,
type=req.type,
priority=req.priority,
is_active=req.is_active,
is_pinned=req.is_pinned,
start_time=req.start_time,
end_time=req.end_time,
)
return {"message": "公告更新成功"}
@dataclass
class DeleteAnnouncementAdapter(AdminApiAdapter):
announcement_id: str
async def handle(self, context): # type: ignore[override]
AnnouncementService.delete_announcement(context.db, self.announcement_id, context.user.id)
return {"message": "公告已删除"}

10
src/api/auth/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""Authentication route group."""
from fastapi import APIRouter
from .routes import router as auth_router
router = APIRouter()
router.include_router(auth_router)
__all__ = ["router"]

353
src/api/auth/routes.py Normal file
View File

@@ -0,0 +1,353 @@
"""
认证相关API端点
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException
from src.core.logger import logger
from src.database import get_db
from src.models.api import (
LoginRequest,
LoginResponse,
LogoutResponse,
RefreshTokenRequest,
RefreshTokenResponse,
RegisterRequest,
RegisterResponse,
)
from src.models.database import AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.rate_limit.ip_limiter import IPRateLimiter
from src.services.system.audit import AuditService
from src.services.user.service import UserService
from src.utils.request_utils import get_client_ip, get_user_agent
router = APIRouter(prefix="/api/auth", tags=["Authentication"])
security = HTTPBearer()
pipeline = ApiRequestPipeline()
# API端点
@router.post("/login", response_model=LoginResponse)
async def login(request: Request, db: Session = Depends(get_db)):
adapter = AuthLoginAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/refresh", response_model=RefreshTokenResponse)
async def refresh_token(request: Request, db: Session = Depends(get_db)):
adapter = AuthRefreshAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/register", response_model=RegisterResponse)
async def register(request: Request, db: Session = Depends(get_db)):
adapter = AuthRegisterAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/me")
async def get_current_user_info(request: Request, db: Session = Depends(get_db)):
adapter = AuthCurrentUserAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/password")
async def change_password(request: Request, db: Session = Depends(get_db)):
"""Change current user's password"""
adapter = AuthChangePasswordAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/logout", response_model=LogoutResponse)
async def logout(request: Request, db: Session = Depends(get_db)):
adapter = AuthLogoutAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ============== 适配器实现 ==============
class AuthPublicAdapter(ApiAdapter):
mode = ApiMode.PUBLIC
def authorize(self, context): # type: ignore[override]
return None
class AuthLoginAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
login_request = LoginRequest.model_validate(payload)
except ValidationError as exc:
errors = []
for error in exc.errors():
field = " -> ".join(str(x) for x in error["loc"])
errors.append(f"{field}: {error['msg']}")
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
client_ip = get_client_ip(context.request)
user_agent = get_user_agent(context.request)
# IP 速率限制检查登录接口5次/分钟)
allowed, remaining, reset_after = await IPRateLimiter.check_limit(client_ip, "login")
if not allowed:
logger.warning(f"登录请求超过速率限制: IP={client_ip}, 剩余={remaining}")
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"登录请求过于频繁,请在 {reset_after} 秒后重试",
)
user = await AuthService.authenticate_user(db, login_request.email, login_request.password)
if not user:
AuditService.log_login_attempt(
db=db,
email=login_request.email,
success=False,
ip_address=client_ip,
user_agent=user_agent,
error_reason="邮箱或密码错误",
)
db.commit()
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="邮箱或密码错误")
AuditService.log_login_attempt(
db=db,
email=login_request.email,
success=True,
ip_address=client_ip,
user_agent=user_agent,
user_id=user.id,
)
db.commit()
access_token = AuthService.create_access_token(
data={
"user_id": user.id,
"email": user.email,
"role": user.role.value,
"created_at": user.created_at.isoformat() if user.created_at else None,
}
)
refresh_token = AuthService.create_refresh_token(
data={"user_id": user.id, "email": user.email}
)
response = LoginResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=86400,
user_id=user.id,
email=user.email,
username=user.username,
role=user.role.value,
)
return response.model_dump()
class AuthRefreshAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
refresh_request = RefreshTokenRequest.model_validate(payload)
client_ip = get_client_ip(context.request)
user_agent = get_user_agent(context.request)
try:
token_payload = await AuthService.verify_token(
refresh_request.refresh_token, token_type="refresh"
)
user_id = token_payload.get("user_id")
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的刷新令牌"
)
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的刷新令牌"
)
if not user.is_active:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已禁用")
new_access_token = AuthService.create_access_token(
data={
"user_id": user.id,
"email": user.email,
"role": user.role.value,
"created_at": user.created_at.isoformat() if user.created_at else None,
}
)
new_refresh_token = AuthService.create_refresh_token(
data={"user_id": user.id, "email": user.email}
)
logger.info(f"令牌刷新成功: {user.email}")
return RefreshTokenResponse(
access_token=new_access_token,
refresh_token=new_refresh_token,
token_type="bearer",
expires_in=86400,
).model_dump()
except HTTPException:
raise
except Exception as exc:
logger.error(f"刷新令牌失败: {exc}")
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="刷新令牌失败")
class AuthRegisterAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override]
from ..models.database import SystemConfig
db = context.db
payload = context.ensure_json_body()
register_request = RegisterRequest.model_validate(payload)
client_ip = get_client_ip(context.request)
user_agent = get_user_agent(context.request)
# IP 速率限制检查注册接口3次/分钟)
allowed, remaining, reset_after = await IPRateLimiter.check_limit(client_ip, "register")
if not allowed:
logger.warning(f"注册请求超过速率限制: IP={client_ip}, 剩余={remaining}")
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试",
)
allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first()
if allow_registration and not allow_registration.value:
AuditService.log_event(
db=db,
event_type=AuditEventType.UNAUTHORIZED_ACCESS,
description=f"Registration attempt rejected - registration disabled: {register_request.email}",
ip_address=client_ip,
user_agent=user_agent,
metadata={"email": register_request.email, "reason": "registration_disabled"},
)
db.commit()
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="系统暂不开放注册")
try:
user = UserService.create_user(
db=db,
email=register_request.email,
username=register_request.username,
password=register_request.password,
role=UserRole.USER,
)
AuditService.log_event(
db=db,
event_type=AuditEventType.USER_CREATED,
description=f"User registered: {user.email}",
user_id=user.id,
ip_address=client_ip,
user_agent=user_agent,
metadata={"email": user.email, "username": user.username, "role": user.role.value},
)
db.commit()
return RegisterResponse(
user_id=user.id,
email=user.email,
username=user.username,
message="注册成功",
).model_dump()
except ValueError as exc:
AuditService.log_event(
db=db,
event_type=AuditEventType.UNAUTHORIZED_ACCESS,
description=f"Registration failed: {register_request.email} - {exc}",
ip_address=client_ip,
user_agent=user_agent,
metadata={"email": register_request.email, "error": str(exc)},
)
db.commit()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
class AuthCurrentUserAdapter(AuthenticatedApiAdapter):
async def handle(self, context): # type: ignore[override]
user = context.user
return {
"id": user.id,
"email": user.email,
"username": user.username,
"role": user.role.value,
"is_active": user.is_active,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": user.total_usd,
"allowed_providers": user.allowed_providers,
"allowed_endpoints": user.allowed_endpoints,
"allowed_models": user.allowed_models,
"created_at": user.created_at.isoformat(),
"last_login_at": user.last_login_at.isoformat() if user.last_login_at else None,
}
class AuthChangePasswordAdapter(AuthenticatedApiAdapter):
async def handle(self, context): # type: ignore[override]
payload = context.ensure_json_body()
old_password = payload.get("old_password")
new_password = payload.get("new_password")
if not old_password or not new_password:
raise HTTPException(status_code=400, detail="必须提供旧密码和新密码")
user = context.user
if not user.verify_password(old_password):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="旧密码错误")
if len(new_password) < 8:
raise InvalidRequestException("密码长度至少8位")
user.set_password(new_password)
context.db.commit()
logger.info(f"用户修改密码: {user.email}")
return {"message": "密码修改成功"}
class AuthLogoutAdapter(AuthenticatedApiAdapter):
async def handle(self, context): # type: ignore[override]
"""用户登出,将 Token 加入黑名单"""
user = context.user
client_ip = get_client_ip(context.request)
# 从 Authorization header 获取 Token
auth_header = context.request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="缺少认证令牌")
token = auth_header.replace("Bearer ", "")
# 将 Token 加入黑名单
success = await AuthService.logout(token)
if success:
# 记录审计日志
AuditService.log_event(
db=context.db,
event_type=AuditEventType.LOGOUT,
description=f"User logged out: {user.email}",
user_id=user.id,
ip_address=client_ip,
user_agent=get_user_agent(context.request),
metadata={"user_id": user.id, "email": user.email},
)
context.db.commit()
logger.info(f"用户登出成功: {user.email}")
return LogoutResponse(message="登出成功", success=True).model_dump()
else:
logger.warning(f"用户登出失败Redis不可用: {user.email}")
return LogoutResponse(message="登出成功(降级模式)", success=False).model_dump()

82
src/api/base/adapter.py Normal file
View File

@@ -0,0 +1,82 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional
from fastapi import Request, Response
from .context import ApiRequestContext
class ApiMode(str, Enum):
STANDARD = "standard"
PROXY = "proxy"
ADMIN = "admin"
USER = "user" # JWT 认证的普通用户(不要求管理员权限)
PUBLIC = "public"
class ApiAdapter(ABC):
"""所有API格式适配器的抽象基类。"""
name: str = "base"
mode: ApiMode = ApiMode.STANDARD
api_format: Optional[str] = None # 对应 Provider API 格式提示
audit_log_enabled: bool = True
audit_success_event = None
audit_failure_event = None
@abstractmethod
async def handle(self, context: ApiRequestContext) -> Response:
"""处理请求并返回 FastAPI Response。"""
def authorize(self, context: ApiRequestContext) -> None:
"""可选的授权钩子,默认允许通过。"""
return None
def extract_api_key(self, request: Request) -> Optional[str]:
"""
从请求中提取客户端 API 密钥。
子类应覆盖此方法以支持各自的认证头格式。
Args:
request: FastAPI Request 对象
Returns:
提取的 API 密钥,如果未找到则返回 None
"""
return None
def get_audit_metadata(
self,
context: ApiRequestContext,
*,
success: bool,
status_code: Optional[int],
error: Optional[str] = None,
) -> Dict[str, Any]:
"""允许适配器在审计日志中追加自定义字段。"""
return {}
def detect_capability_requirements(
self,
headers: Dict[str, str],
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""
检测请求中隐含的能力需求(子类可覆盖)
不同 API 格式有不同的能力声明方式,例如:
- Claude: anthropic-beta: context-1m-xxx 表示需要 1M 上下文
- 其他格式可能有不同的请求头或请求体字段
Args:
headers: 请求头字典
request_body: 请求体字典(可选)
Returns:
检测到的能力需求,如 {"context_1m": True}
"""
return {}

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from fastapi import HTTPException
from src.models.database import UserRole
from .adapter import ApiAdapter, ApiMode
from .context import ApiRequestContext
class AdminApiAdapter(ApiAdapter):
"""管理员端点适配器基类,提供统一的权限校验。"""
mode = ApiMode.ADMIN
required_roles: tuple[UserRole, ...] = (UserRole.ADMIN,)
def authorize(self, context: ApiRequestContext) -> None:
user = context.user
if not user:
raise HTTPException(status_code=401, detail="未登录")
# 检查是否使用独立余额Key访问管理接口
if context.api_key and context.api_key.is_standalone:
raise HTTPException(
status_code=403, detail="独立余额Key不允许访问管理接口仅可用于代理请求"
)
if not any(user.role == role for role in self.required_roles):
raise HTTPException(status_code=403, detail="需要管理员权限")

View File

@@ -0,0 +1,13 @@
from fastapi import HTTPException
from .adapter import ApiAdapter, ApiMode
class AuthenticatedApiAdapter(ApiAdapter):
"""通用需要登录的适配器基类。"""
mode = ApiMode.USER
def authorize(self, context): # type: ignore[override]
if not context.user:
raise HTTPException(status_code=401, detail="未登录")

116
src/api/base/context.py Normal file
View File

@@ -0,0 +1,116 @@
from __future__ import annotations
import json
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import ApiKey, User
@dataclass
class ApiRequestContext:
"""统一的API请求上下文贯穿Pipeline与格式适配器。"""
request: Request
db: Session
user: Optional[User]
api_key: Optional[ApiKey]
request_id: str
start_time: float
client_ip: str
user_agent: str
original_headers: Dict[str, str]
query_params: Dict[str, str]
raw_body: bytes | None = None
json_body: Optional[Dict[str, Any]] = None
quota_remaining: Optional[float] = None
mode: str = "standard" # standard / proxy
api_format_hint: Optional[str] = None
# URL 路径参数(如 Gemini API 的 /v1beta/models/{model}:generateContent
path_params: Dict[str, Any] = field(default_factory=dict)
# 供适配器扩展的状态存储
extra: Dict[str, Any] = field(default_factory=dict)
audit_metadata: Dict[str, Any] = field(default_factory=dict)
def ensure_json_body(self) -> Dict[str, Any]:
"""确保请求体已解析为JSON并返回。"""
if self.json_body is not None:
return self.json_body
if not self.raw_body:
raise HTTPException(status_code=400, detail="请求体不能为空")
try:
self.json_body = json.loads(self.raw_body.decode("utf-8"))
except json.JSONDecodeError as exc:
logger.warning(f"解析JSON失败: {exc}")
raise HTTPException(status_code=400, detail="请求体必须是合法的JSON") from exc
return self.json_body
def add_audit_metadata(self, **values: Any) -> None:
"""向审计日志附加字段(会自动过滤 None"""
for key, value in values.items():
if value is not None:
self.audit_metadata[key] = value
def extend_audit_metadata(self, data: Dict[str, Any]) -> None:
"""批量附加审计字段。"""
for key, value in data.items():
if value is not None:
self.audit_metadata[key] = value
@classmethod
def build(
cls,
request: Request,
db: Session,
user: Optional[User],
api_key: Optional[ApiKey],
raw_body: Optional[bytes] = None,
mode: str = "standard",
api_format_hint: Optional[str] = None,
path_params: Optional[Dict[str, Any]] = None,
) -> "ApiRequestContext":
"""创建上下文实例并提前读取必要的元数据。"""
request_id = getattr(request.state, "request_id", None) or str(uuid.uuid4())[:8]
setattr(request.state, "request_id", request_id)
start_time = time.time()
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
context = cls(
request=request,
db=db,
user=user,
api_key=api_key,
request_id=request_id,
start_time=start_time,
client_ip=client_ip,
user_agent=user_agent,
original_headers=dict(request.headers),
query_params=dict(request.query_params),
raw_body=raw_body,
mode=mode,
api_format_hint=api_format_hint,
path_params=path_params or {},
)
# 便于插件/日志引用
request.state.request_id = request_id
if user:
request.state.user_id = user.id
if api_key:
request.state.api_key_id = api_key.id
return context

View File

@@ -0,0 +1,49 @@
from __future__ import annotations
from dataclasses import asdict, dataclass
from typing import List, Sequence, Tuple, TypeVar
from sqlalchemy.orm import Query
T = TypeVar("T")
@dataclass
class PaginationMeta:
total: int
limit: int
offset: int
count: int
def to_dict(self) -> dict:
return asdict(self)
def paginate_query(query: Query, limit: int, offset: int) -> Tuple[int, List[T]]:
"""
对 SQLAlchemy 查询应用 limit/offset并返回总数与结果列表。
"""
total = query.order_by(None).count()
records = query.offset(offset).limit(limit).all()
return total, records
def paginate_sequence(
items: Sequence[T], limit: int, offset: int
) -> Tuple[List[T], PaginationMeta]:
"""
对内存序列应用分页,返回切片和元数据。
"""
total = len(items)
sliced = list(items[offset : offset + limit])
meta = PaginationMeta(total=total, limit=limit, offset=offset, count=len(sliced))
return sliced, meta
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict:
"""
构建标准分页响应 payload。
"""
payload = {"items": items, "meta": meta.to_dict()}
payload.update(extra)
return payload

387
src/api/base/pipeline.py Normal file
View File

@@ -0,0 +1,387 @@
from __future__ import annotations
import time
from enum import Enum
from typing import Any, Optional, Tuple
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session, sessionmaker
from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.cache.user_cache import UserCacheService
from src.services.system.audit import AuditService
from src.services.usage.service import UsageService
from .adapter import ApiAdapter, ApiMode
from .context import ApiRequestContext
class ApiRequestPipeline:
"""负责统一执行认证、配额校验、上下文构建等通用逻辑的管道。"""
def __init__(
self,
auth_service: AuthService = AuthService,
usage_service: UsageService = UsageService,
audit_service: AuditService = AuditService,
):
self.auth_service = auth_service
self.usage_service = usage_service
self.audit_service = audit_service
async def run(
self,
adapter: ApiAdapter,
http_request: Request,
db: Session,
*,
mode: ApiMode = ApiMode.STANDARD,
api_format_hint: Optional[str] = None,
path_params: Optional[dict[str, Any]] = None,
):
logger.debug(f"[Pipeline] START | path={http_request.url.path}")
logger.debug(f"[Pipeline] Running with mode={mode}, adapter={adapter.__class__.__name__}, "
f"adapter.mode={adapter.mode}, path={http_request.url.path}")
if mode == ApiMode.ADMIN:
user = await self._authenticate_admin(http_request, db)
api_key = None
elif mode == ApiMode.USER:
user = await self._authenticate_user(http_request, db)
api_key = None
elif mode == ApiMode.PUBLIC:
user = None
api_key = None
else:
logger.debug("[Pipeline] 调用 _authenticate_client")
user, api_key = self._authenticate_client(http_request, db, adapter)
logger.debug(f"[Pipeline] 认证完成 | user={user.username if user else None}")
raw_body = None
if http_request.method in {"POST", "PUT", "PATCH"}:
try:
import asyncio
# 添加30秒超时防止卡死
raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes")
except asyncio.TimeoutError:
logger.error("读取请求体超时(30s),可能客户端未发送完整请求体")
raise HTTPException(
status_code=408, detail="Request timeout: body not received within 30 seconds"
)
else:
logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")
context = ApiRequestContext.build(
request=http_request,
db=db,
user=user,
api_key=api_key,
raw_body=raw_body,
mode=mode.value,
api_format_hint=api_format_hint,
path_params=path_params,
)
logger.debug(f"[Pipeline] Context构建完成 | adapter={adapter.name} | request_id={context.request_id}")
if mode != ApiMode.ADMIN and user:
context.quota_remaining = self._calculate_quota_remaining(user)
logger.debug(f"[Pipeline] Adapter={adapter.name} | RequestID={context.request_id}")
logger.debug(f"[Pipeline] Calling authorize on {adapter.__class__.__name__}, user={context.user}")
# authorize 可能是异步的,需要检查并 await
authorize_result = adapter.authorize(context)
if hasattr(authorize_result, "__await__"):
await authorize_result
try:
response = await adapter.handle(context)
status_code = getattr(response, "status_code", None)
self._record_audit_event(context, adapter, success=True, status_code=status_code)
return response
except HTTPException as exc:
err_detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
self._record_audit_event(
context,
adapter,
success=False,
status_code=exc.status_code,
error=err_detail,
)
raise
except Exception as exc:
self._record_audit_event(
context,
adapter,
success=False,
status_code=500,
error=str(exc),
)
raise
# --------------------------------------------------------------------- #
# Internal helpers
# --------------------------------------------------------------------- #
def _authenticate_client(
self, request: Request, db: Session, adapter: ApiAdapter
) -> Tuple[User, ApiKey]:
logger.debug("[Pipeline._authenticate_client] 开始")
# 使用 adapter 的 extract_api_key 方法,支持不同 API 格式的认证头
client_api_key = adapter.extract_api_key(request)
logger.debug(f"[Pipeline._authenticate_client] 提取API密钥完成 | key_prefix={client_api_key[:8] if client_api_key else None}...")
if not client_api_key:
raise HTTPException(status_code=401, detail="请提供API密钥")
logger.debug("[Pipeline._authenticate_client] 调用 auth_service.authenticate_api_key")
auth_result = self.auth_service.authenticate_api_key(db, client_api_key)
logger.debug(f"[Pipeline._authenticate_client] 认证结果 | result={bool(auth_result)}")
if not auth_result:
raise HTTPException(status_code=401, detail="无效的API密钥")
user, api_key = auth_result
if not user or not api_key:
raise HTTPException(status_code=401, detail="无效的API密钥")
request.state.user_id = user.id
request.state.api_key_id = api_key.id
# 检查配额或余额支持独立Key
quota_ok, message = self.usage_service.check_user_quota(db, user, api_key=api_key)
if not quota_ok:
# 根据Key类型计算剩余额度
if api_key.is_standalone:
# 独立Key显示剩余余额
remaining = (
None
if api_key.current_balance_usd is None
else float(api_key.current_balance_usd - (api_key.balance_used_usd or 0))
)
else:
# 普通Key显示用户配额剩余
remaining = (
None
if user.quota_usd is None or user.quota_usd < 0
else float(user.quota_usd - user.used_usd)
)
raise QuotaExceededException(quota_type="USD", remaining=remaining)
return user, api_key
async def _authenticate_admin(self, request: Request, db: Session) -> User:
authorization = request.headers.get("authorization")
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少管理员凭证")
token = authorization.replace("Bearer ", "").strip()
try:
payload = await self.auth_service.verify_token(token)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Admin token 验证失败: {exc}")
raise HTTPException(status_code=401, detail="无效的管理员令牌")
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(status_code=401, detail="无效的管理员令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
request.state.user_id = user.id
return user
async def _authenticate_user(self, request: Request, db: Session) -> User:
"""JWT 认证普通用户(不要求管理员权限)"""
authorization = request.headers.get("authorization")
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少用户凭证")
token = authorization.replace("Bearer ", "").strip()
try:
payload = await self.auth_service.verify_token(token)
except HTTPException:
raise
except Exception as exc:
logger.error(f"User token 验证失败: {exc}")
raise HTTPException(status_code=401, detail="无效的用户令牌")
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(status_code=401, detail="无效的用户令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
request.state.user_id = user.id
return user
def _calculate_quota_remaining(self, user: Optional[User]) -> Optional[float]:
if not user:
return None
if user.quota_usd is None or user.quota_usd < 0:
return None
return max(float(user.quota_usd - user.used_usd), 0.0)
def _record_audit_event(
self,
context: ApiRequestContext,
adapter: ApiAdapter,
*,
success: bool,
status_code: Optional[int] = None,
error: Optional[str] = None,
) -> None:
if not getattr(adapter, "audit_log_enabled", True):
return
bind = context.db.get_bind()
if bind is None:
return
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
if not event_type:
if not success and status_code == 401:
event_type = AuditEventType.UNAUTHORIZED_ACCESS
else:
event_type = (
AuditEventType.REQUEST_SUCCESS if success else AuditEventType.REQUEST_FAILED
)
metadata = self._build_audit_metadata(
context=context,
adapter=adapter,
success=success,
status_code=status_code,
error=error,
)
SessionMaker = sessionmaker(bind=bind)
audit_session = SessionMaker()
try:
self.audit_service.log_event(
db=audit_session,
event_type=event_type,
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
user_id=context.user.id if context.user else None,
api_key_id=context.api_key.id if context.api_key else None,
ip_address=context.client_ip,
user_agent=context.user_agent,
request_id=context.request_id,
status_code=status_code,
error_message=error,
metadata=metadata,
)
audit_session.commit()
except Exception as exc:
audit_session.rollback()
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
finally:
audit_session.close()
def _build_audit_metadata(
self,
context: ApiRequestContext,
adapter: ApiAdapter,
*,
success: bool,
status_code: Optional[int],
error: Optional[str],
) -> dict:
duration_ms = max((time.time() - context.start_time) * 1000, 0.0)
request = context.request
path_params = {}
try:
path_params = dict(getattr(request, "path_params", {}) or {})
except Exception:
path_params = {}
metadata: dict[str, Any] = {
"path": request.url.path,
"path_params": path_params,
"method": request.method,
"adapter": adapter.name,
"adapter_class": adapter.__class__.__name__,
"adapter_mode": getattr(adapter.mode, "value", str(adapter.mode)),
"mode": context.mode,
"api_format_hint": context.api_format_hint,
"query": context.query_params,
"duration_ms": round(duration_ms, 2),
"request_body_bytes": len(context.raw_body or b""),
"has_body": bool(context.raw_body),
"request_content_type": request.headers.get("content-type"),
"quota_remaining": context.quota_remaining,
"success": success,
}
if status_code is not None:
metadata["status_code"] = status_code
if context.user and getattr(context.user, "role", None):
role = context.user.role
metadata["user_role"] = getattr(role, "value", role)
if context.api_key:
if getattr(context.api_key, "name", None):
metadata["api_key_name"] = context.api_key.name
# 使用脱敏后的密钥显示
if hasattr(context.api_key, "get_display_key"):
metadata["api_key_display"] = context.api_key.get_display_key()
extra_details: dict[str, Any] = {}
if context.audit_metadata:
extra_details.update(context.audit_metadata)
try:
adapter_details = adapter.get_audit_metadata(
context,
success=success,
status_code=status_code,
error=error,
)
if adapter_details:
extra_details.update(adapter_details)
except Exception as exc:
logger.warning(f"[Audit] Adapter metadata failed: {adapter.__class__.__name__}: {exc}")
if extra_details:
metadata["details"] = extra_details
if error:
metadata["error"] = error
return self._sanitize_metadata(metadata)
def _sanitize_metadata(self, value: Any, depth: int = 0):
if value is None:
return None
if depth > 5:
return str(value)
if isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, Enum):
return value.value
if isinstance(value, dict):
sanitized = {}
for key, val in value.items():
cleaned = self._sanitize_metadata(val, depth + 1)
if cleaned is not None:
sanitized[str(key)] = cleaned
return sanitized
if isinstance(value, (list, tuple, set)):
return [self._sanitize_metadata(item, depth + 1) for item in value]
if hasattr(value, "isoformat"):
try:
return value.isoformat()
except Exception:
return str(value)
return str(value)

View File

@@ -0,0 +1,10 @@
'"""Dashboard API routers."""'
from fastapi import APIRouter
from .routes import router as dashboard_router
router = APIRouter()
router.include_router(dashboard_router)
__all__ = ["router"]

905
src/api/dashboard/routes.py Normal file
View File

@@ -0,0 +1,905 @@
"""仪表盘统计 API 端点。"""
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import UserRole
from src.database import get_db
from src.models.database import ApiKey, Provider, RequestCandidate, StatsDaily, Usage
from src.models.database import User as DBUser
from src.services.system.stats_aggregator import StatsAggregatorService
from src.utils.cache_decorator import cache_result
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
pipeline = ApiRequestPipeline()
def format_tokens(num: int) -> str:
"""格式化 Token 数量,自动转换 K/M 单位"""
if num < 1000:
return str(num)
if num < 1000000:
thousands = num / 1000
if thousands >= 100:
return f"{round(thousands)}K"
elif thousands >= 10:
return f"{thousands:.1f}K"
else:
return f"{thousands:.2f}K"
millions = num / 1000000
if millions >= 100:
return f"{round(millions)}M"
elif millions >= 10:
return f"{millions:.1f}M"
else:
return f"{millions:.2f}M"
@router.get("/stats")
async def get_dashboard_stats(request: Request, db: Session = Depends(get_db)):
adapter = DashboardStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/recent-requests")
async def get_recent_requests(
request: Request,
limit: int = Query(10, ge=1, le=100),
db: Session = Depends(get_db),
):
adapter = DashboardRecentRequestsAdapter(limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# NOTE: /request-detail/{request_id} has been moved to /api/admin/usage/{id}
# The old route is removed. Use dashboardApi.getRequestDetail() which now calls the new API.
@router.get("/provider-status")
async def get_provider_status(request: Request, db: Session = Depends(get_db)):
adapter = DashboardProviderStatusAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/daily-stats")
async def get_daily_stats(
request: Request,
days: int = Query(7, ge=1, le=30),
db: Session = Depends(get_db),
):
adapter = DashboardDailyStatsAdapter(days=days)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class DashboardAdapter(ApiAdapter):
"""需要登录的仪表盘适配器基类。"""
mode = ApiMode.ADMIN
def authorize(self, context): # type: ignore[override]
if not context.user:
raise HTTPException(status_code=401, detail="未登录")
class DashboardStatsAdapter(DashboardAdapter):
async def handle(self, context): # type: ignore[override]
user = context.user
if not user:
raise HTTPException(status_code=401, detail="未登录")
adapter = (
AdminDashboardStatsAdapter()
if user.role == UserRole.ADMIN
else UserDashboardStatsAdapter()
)
return await adapter.handle(context)
class AdminDashboardStatsAdapter(AdminApiAdapter):
@cache_result(key_prefix="dashboard:admin:stats", ttl=60, user_specific=False)
async def handle(self, context): # type: ignore[override]
"""管理员仪表盘统计 - 使用预聚合数据优化性能"""
db = context.db
now = datetime.now(timezone.utc)
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
yesterday = today - timedelta(days=1)
last_month = today - timedelta(days=30)
# ==================== 使用预聚合数据 ====================
# 从 stats_summary + 今日实时数据获取全局统计
combined_stats = StatsAggregatorService.get_combined_stats(db)
all_time_requests = combined_stats["total_requests"]
all_time_success_requests = combined_stats["success_requests"]
all_time_error_requests = combined_stats["error_requests"]
all_time_input_tokens = combined_stats["input_tokens"]
all_time_output_tokens = combined_stats["output_tokens"]
all_time_cache_creation = combined_stats["cache_creation_tokens"]
all_time_cache_read = combined_stats["cache_read_tokens"]
all_time_cost = combined_stats["total_cost"]
all_time_actual_cost = combined_stats["actual_total_cost"]
# 用户/API Key 统计
total_users = combined_stats.get("total_users") or db.query(func.count(DBUser.id)).scalar()
active_users = combined_stats.get("active_users") or (
db.query(func.count(DBUser.id)).filter(DBUser.is_active.is_(True)).scalar()
)
total_api_keys = combined_stats.get("total_api_keys") or db.query(func.count(ApiKey.id)).scalar()
active_api_keys = combined_stats.get("active_api_keys") or (
db.query(func.count(ApiKey.id)).filter(ApiKey.is_active.is_(True)).scalar()
)
# ==================== 今日实时统计 ====================
today_stats = StatsAggregatorService.get_today_realtime_stats(db)
requests_today = today_stats["total_requests"]
cost_today = today_stats["total_cost"]
actual_cost_today = today_stats["actual_total_cost"]
input_tokens_today = today_stats["input_tokens"]
output_tokens_today = today_stats["output_tokens"]
cache_creation_today = today_stats["cache_creation_tokens"]
cache_read_today = today_stats["cache_read_tokens"]
tokens_today = (
input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today
)
# ==================== 昨日统计(从预聚合表获取)====================
yesterday_stats = db.query(StatsDaily).filter(StatsDaily.date == yesterday).first()
if yesterday_stats:
requests_yesterday = yesterday_stats.total_requests
cost_yesterday = yesterday_stats.total_cost
input_tokens_yesterday = yesterday_stats.input_tokens
output_tokens_yesterday = yesterday_stats.output_tokens
cache_creation_yesterday = yesterday_stats.cache_creation_tokens
cache_read_yesterday = yesterday_stats.cache_read_tokens
else:
# 如果没有预聚合数据,回退到实时查询
requests_yesterday = (
db.query(func.count(Usage.id))
.filter(Usage.created_at >= yesterday, Usage.created_at < today)
.scalar() or 0
)
cost_yesterday = (
db.query(func.sum(Usage.total_cost_usd))
.filter(Usage.created_at >= yesterday, Usage.created_at < today)
.scalar() or 0
)
yesterday_token_stats = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
func.sum(Usage.output_tokens).label("output_tokens"),
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
)
.filter(Usage.created_at >= yesterday, Usage.created_at < today)
.first()
)
input_tokens_yesterday = int(yesterday_token_stats.input_tokens or 0) if yesterday_token_stats else 0
output_tokens_yesterday = int(yesterday_token_stats.output_tokens or 0) if yesterday_token_stats else 0
cache_creation_yesterday = int(yesterday_token_stats.cache_creation_tokens or 0) if yesterday_token_stats else 0
cache_read_yesterday = int(yesterday_token_stats.cache_read_tokens or 0) if yesterday_token_stats else 0
# ==================== 本月统计(从预聚合表聚合)====================
monthly_stats = (
db.query(
func.sum(StatsDaily.total_requests).label("total_requests"),
func.sum(StatsDaily.error_requests).label("error_requests"),
func.sum(StatsDaily.total_cost).label("total_cost"),
func.sum(StatsDaily.actual_total_cost).label("actual_total_cost"),
func.sum(StatsDaily.input_tokens + StatsDaily.output_tokens +
StatsDaily.cache_creation_tokens + StatsDaily.cache_read_tokens).label("total_tokens"),
func.sum(StatsDaily.cache_creation_tokens).label("cache_creation_tokens"),
func.sum(StatsDaily.cache_read_tokens).label("cache_read_tokens"),
func.sum(StatsDaily.cache_creation_cost).label("cache_creation_cost"),
func.sum(StatsDaily.cache_read_cost).label("cache_read_cost"),
func.sum(StatsDaily.fallback_count).label("fallback_count"),
)
.filter(StatsDaily.date >= last_month, StatsDaily.date < today)
.first()
)
# 本月数据 = 预聚合月数据 + 今日实时数据
if monthly_stats and monthly_stats.total_requests:
total_requests = int(monthly_stats.total_requests or 0) + requests_today
error_requests = int(monthly_stats.error_requests or 0) + today_stats["error_requests"]
total_cost = float(monthly_stats.total_cost or 0) + cost_today
total_actual_cost = float(monthly_stats.actual_total_cost or 0) + actual_cost_today
total_tokens = int(monthly_stats.total_tokens or 0) + tokens_today
cache_creation_tokens = int(monthly_stats.cache_creation_tokens or 0) + cache_creation_today
cache_read_tokens = int(monthly_stats.cache_read_tokens or 0) + cache_read_today
cache_creation_cost = float(monthly_stats.cache_creation_cost or 0)
cache_read_cost = float(monthly_stats.cache_read_cost or 0)
fallback_count = int(monthly_stats.fallback_count or 0)
else:
# 回退到实时查询(没有预聚合数据时)
total_requests = (
db.query(func.count(Usage.id)).filter(Usage.created_at >= last_month).scalar() or 0
)
total_cost = (
db.query(func.sum(Usage.total_cost_usd)).filter(Usage.created_at >= last_month).scalar() or 0
)
total_actual_cost = (
db.query(func.sum(Usage.actual_total_cost_usd))
.filter(Usage.created_at >= last_month).scalar() or 0
)
error_requests = (
db.query(func.count(Usage.id))
.filter(
Usage.created_at >= last_month,
(Usage.status_code >= 400) | (Usage.error_message.isnot(None)),
).scalar() or 0
)
total_tokens = (
db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= last_month).scalar() or 0
)
cache_stats = (
db.query(
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
)
.filter(Usage.created_at >= last_month)
.first()
)
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
cache_read_tokens = int(cache_stats.cache_read_tokens or 0) if cache_stats else 0
cache_creation_cost = float(cache_stats.cache_creation_cost or 0) if cache_stats else 0
cache_read_cost = float(cache_stats.cache_read_cost or 0) if cache_stats else 0
# Fallback 统计
fallback_subquery = (
db.query(
RequestCandidate.request_id, func.count(RequestCandidate.id).label("executed_count")
)
.filter(
RequestCandidate.created_at >= last_month,
RequestCandidate.status.in_(["success", "failed"]),
)
.group_by(RequestCandidate.request_id)
.subquery()
)
fallback_count = (
db.query(func.count())
.select_from(fallback_subquery)
.filter(fallback_subquery.c.executed_count > 1)
.scalar() or 0
)
# ==================== 系统健康指标 ====================
error_rate = round((error_requests / total_requests) * 100, 2) if total_requests > 0 else 0
# 平均响应时间(仅查询今日数据,降低查询成本)
avg_response_time = (
db.query(func.avg(Usage.response_time_ms))
.filter(
Usage.created_at >= today,
Usage.status_code == 200,
Usage.response_time_ms.isnot(None),
)
.scalar() or 0
)
avg_response_time_seconds = float(avg_response_time) / 1000.0
# 缓存命中率
total_input_with_cache = all_time_input_tokens + all_time_cache_read
cache_hit_rate = (
round((all_time_cache_read / total_input_with_cache) * 100, 1)
if total_input_with_cache > 0
else 0
)
return {
"stats": [
{
"name": "总请求",
"value": f"{all_time_requests:,}",
"subValue": f"有效 {all_time_success_requests:,} / 异常 {all_time_error_requests:,}",
"change": (
f"+{requests_today}"
if requests_today > requests_yesterday
else str(requests_today)
),
"changeType": (
"increase"
if requests_today > requests_yesterday
else ("decrease" if requests_today < requests_yesterday else "neutral")
),
"icon": "Activity",
},
{
"name": "总费用",
"value": f"${all_time_cost:.2f}",
"subValue": f"倍率后 ${all_time_actual_cost:.2f}",
"change": (
f"+${cost_today:.2f}"
if cost_today > cost_yesterday
else f"${cost_today:.2f}"
),
"changeType": (
"increase"
if cost_today > cost_yesterday
else ("decrease" if cost_today < cost_yesterday else "neutral")
),
"icon": "DollarSign",
},
{
"name": "总Token",
"value": format_tokens(
all_time_input_tokens
+ all_time_output_tokens
+ all_time_cache_creation
+ all_time_cache_read
),
"subValue": f"输入 {format_tokens(all_time_input_tokens)} / 输出 {format_tokens(all_time_output_tokens)}",
"change": (
f"+{format_tokens(input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)}"
if (input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)
> (input_tokens_yesterday + output_tokens_yesterday + cache_creation_yesterday + cache_read_yesterday)
else format_tokens(input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)
),
"changeType": (
"increase"
if (input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)
> (input_tokens_yesterday + output_tokens_yesterday + cache_creation_yesterday + cache_read_yesterday)
else (
"decrease"
if (input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)
< (input_tokens_yesterday + output_tokens_yesterday + cache_creation_yesterday + cache_read_yesterday)
else "neutral"
)
),
"icon": "Hash",
},
{
"name": "总缓存",
"value": format_tokens(all_time_cache_creation + all_time_cache_read),
"subValue": f"创建 {format_tokens(all_time_cache_creation)} / 读取 {format_tokens(all_time_cache_read)}",
"change": (
f"+{format_tokens(cache_creation_today + cache_read_today)}"
if (cache_creation_today + cache_read_today)
> (cache_creation_yesterday + cache_read_yesterday)
else format_tokens(cache_creation_today + cache_read_today)
),
"changeType": (
"increase"
if (cache_creation_today + cache_read_today)
> (cache_creation_yesterday + cache_read_yesterday)
else (
"decrease"
if (cache_creation_today + cache_read_today)
< (cache_creation_yesterday + cache_read_yesterday)
else "neutral"
)
),
"extraBadge": f"命中率 {cache_hit_rate}%",
"icon": "Database",
},
],
"today": {
"requests": requests_today,
"cost": cost_today,
"actual_cost": actual_cost_today,
"tokens": tokens_today,
"cache_creation_tokens": cache_creation_today,
"cache_read_tokens": cache_read_today,
},
"api_keys": {"total": total_api_keys, "active": active_api_keys},
"tokens": {"month": total_tokens},
"token_breakdown": {
"input": all_time_input_tokens,
"output": all_time_output_tokens,
"cache_creation": all_time_cache_creation,
"cache_read": all_time_cache_read,
},
"system_health": {
"avg_response_time": round(avg_response_time_seconds, 2),
"error_rate": error_rate,
"error_requests": error_requests,
"fallback_count": fallback_count,
"total_requests": total_requests,
},
"cost_stats": {
"total_cost": round(total_cost, 4),
"total_actual_cost": round(total_actual_cost, 4),
"cost_savings": round(total_cost - total_actual_cost, 4),
},
"cache_stats": {
"cache_creation_tokens": cache_creation_tokens,
"cache_read_tokens": cache_read_tokens,
"cache_creation_cost": round(cache_creation_cost, 4),
"cache_read_cost": round(cache_read_cost, 4),
"total_cache_tokens": cache_creation_tokens + cache_read_tokens,
},
"users": {
"total": total_users,
"active": active_users,
},
}
class UserDashboardStatsAdapter(DashboardAdapter):
@cache_result(key_prefix="dashboard:user:stats", ttl=30, user_specific=True)
async def handle(self, context): # type: ignore[override]
db = context.db
user = context.user
now = datetime.now(timezone.utc)
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
last_month = today - timedelta(days=30)
yesterday = today - timedelta(days=1)
user_api_keys = db.query(func.count(ApiKey.id)).filter(ApiKey.user_id == user.id).scalar()
active_keys = (
db.query(func.count(ApiKey.id))
.filter(and_(ApiKey.user_id == user.id, ApiKey.is_active.is_(True)))
.scalar()
)
# 全局 Token 统计
all_time_token_stats = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
func.sum(Usage.output_tokens).label("output_tokens"),
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
)
.filter(Usage.user_id == user.id)
.first()
)
all_time_input_tokens = (
int(all_time_token_stats.input_tokens or 0) if all_time_token_stats else 0
)
all_time_output_tokens = (
int(all_time_token_stats.output_tokens or 0) if all_time_token_stats else 0
)
all_time_cache_creation = (
int(all_time_token_stats.cache_creation_tokens or 0) if all_time_token_stats else 0
)
all_time_cache_read = (
int(all_time_token_stats.cache_read_tokens or 0) if all_time_token_stats else 0
)
# 本月请求统计
user_requests = (
db.query(func.count(Usage.id))
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
.scalar()
)
user_cost = (
db.query(func.sum(Usage.total_cost_usd))
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
.scalar()
or 0
)
# 今日统计
requests_today = (
db.query(func.count(Usage.id))
.filter(and_(Usage.user_id == user.id, Usage.created_at >= today))
.scalar()
)
cost_today = (
db.query(func.sum(Usage.total_cost_usd))
.filter(and_(Usage.user_id == user.id, Usage.created_at >= today))
.scalar()
or 0
)
tokens_today = (
db.query(func.sum(Usage.total_tokens))
.filter(and_(Usage.user_id == user.id, Usage.created_at >= today))
.scalar()
or 0
)
# 昨日统计(用于计算变化)
requests_yesterday = (
db.query(func.count(Usage.id))
.filter(
and_(
Usage.user_id == user.id,
Usage.created_at >= yesterday,
Usage.created_at < today,
)
)
.scalar()
)
# 缓存统计(本月)
cache_stats = (
db.query(
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
func.sum(Usage.input_tokens).label("total_input_tokens"),
)
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
.first()
)
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
cache_read_tokens = int(cache_stats.cache_read_tokens or 0) if cache_stats else 0
# 计算缓存命中率cache_read / (input_tokens + cache_read)
# input_tokens 是实际发送给模型的输入不含缓存读取cache_read 是从缓存读取的
# 总输入 = input_tokens + cache_read缓存命中率 = cache_read / 总输入
total_input_with_cache = all_time_input_tokens + all_time_cache_read
cache_hit_rate = (
round((all_time_cache_read / total_input_with_cache) * 100, 1)
if total_input_with_cache > 0
else 0
)
# 今日缓存统计
cache_stats_today = (
db.query(
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
)
.filter(and_(Usage.user_id == user.id, Usage.created_at >= today))
.first()
)
cache_creation_tokens_today = (
int(cache_stats_today.cache_creation_tokens or 0) if cache_stats_today else 0
)
cache_read_tokens_today = (
int(cache_stats_today.cache_read_tokens or 0) if cache_stats_today else 0
)
# 配额状态
if user.quota_usd is None:
quota_value = "无限制"
quota_change = f"已用 ${user.used_usd:.2f}"
quota_high = False
elif user.quota_usd and user.quota_usd > 0:
percent = min(100, int((user.used_usd / user.quota_usd) * 100))
quota_value = "无限制"
quota_change = f"已用 ${user.used_usd:.2f}"
quota_high = percent > 80
else:
quota_value = "0%"
quota_change = f"已用 ${user.used_usd:.2f}"
quota_high = False
return {
"stats": [
{
"name": "API 密钥",
"value": f"{active_keys}/{user_api_keys}",
"icon": "Key",
},
{
"name": "本月请求",
"value": f"{user_requests:,}",
"change": f"今日 {requests_today}",
"changeType": (
"increase"
if requests_today > requests_yesterday
else ("decrease" if requests_today < requests_yesterday else "neutral")
),
"icon": "Activity",
},
{
"name": "配额使用",
"value": quota_value,
"change": quota_change,
"changeType": "increase" if quota_high else "neutral",
"icon": "TrendingUp",
},
{
"name": "本月费用",
"value": f"${user_cost:.2f}",
"icon": "DollarSign",
},
],
"today": {
"requests": requests_today,
"cost": cost_today,
"tokens": tokens_today,
"cache_creation_tokens": cache_creation_tokens_today,
"cache_read_tokens": cache_read_tokens_today,
},
# 全局 Token 详细分类(与管理员端对齐)
"token_breakdown": {
"input": all_time_input_tokens,
"output": all_time_output_tokens,
"cache_creation": all_time_cache_creation,
"cache_read": all_time_cache_read,
},
# 用户视角:缓存使用情况
"cache_stats": {
"cache_creation_tokens": cache_creation_tokens,
"cache_read_tokens": cache_read_tokens,
"cache_hit_rate": cache_hit_rate,
"total_cache_tokens": cache_creation_tokens + cache_read_tokens,
},
}
@dataclass
class DashboardRecentRequestsAdapter(DashboardAdapter):
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
user = context.user
query = db.query(Usage)
if user.role != UserRole.ADMIN:
query = query.filter(Usage.user_id == user.id)
recent_requests = query.order_by(Usage.created_at.desc()).limit(self.limit).all()
results = []
for req in recent_requests:
owner = db.query(DBUser).filter(DBUser.id == req.user_id).first()
results.append(
{
"id": req.id,
"user": owner.username if owner else "Unknown",
"model": req.model or "N/A",
"tokens": req.total_tokens,
"time": req.created_at.strftime("%H:%M") if req.created_at else None,
"is_stream": req.is_stream,
}
)
return {"requests": results}
# NOTE: DashboardRequestDetailAdapter has been moved to AdminUsageDetailAdapter
# in src/api/admin/usage/routes.py
class DashboardProviderStatusAdapter(DashboardAdapter):
@cache_result(key_prefix="dashboard:provider:status", ttl=60, user_specific=False)
async def handle(self, context): # type: ignore[override]
db = context.db
user = context.user
providers = db.query(Provider).filter(Provider.is_active.is_(True)).all()
since = datetime.now(timezone.utc) - timedelta(days=1)
entries = []
for provider in providers:
count = (
db.query(func.count(Usage.id))
.filter(and_(Usage.provider == provider.name, Usage.created_at >= since))
.scalar()
)
entries.append(
{
"name": provider.name,
"status": "active" if provider.is_active else "inactive",
"requests": count,
}
)
entries.sort(key=lambda x: x["requests"], reverse=True)
limit = 10 if user.role == UserRole.ADMIN else 5
return {"providers": entries[:limit]}
@dataclass
class DashboardDailyStatsAdapter(DashboardAdapter):
days: int
@cache_result(key_prefix="dashboard:daily:stats", ttl=300, user_specific=True)
async def handle(self, context): # type: ignore[override]
db = context.db
user = context.user
is_admin = user.role == UserRole.ADMIN
now = datetime.now(timezone.utc)
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = now.replace(hour=23, minute=59, second=59, microsecond=999999)
start_date = (end_date - timedelta(days=self.days - 1)).replace(
hour=0, minute=0, second=0, microsecond=0
)
# ==================== 使用预聚合数据优化 ====================
if is_admin:
# 管理员:从 stats_daily 获取历史数据
daily_stats = (
db.query(StatsDaily)
.filter(and_(StatsDaily.date >= start_date, StatsDaily.date < today))
.order_by(StatsDaily.date.asc())
.all()
)
stats_map = {
stat.date.replace(tzinfo=timezone.utc).date().isoformat(): {
"requests": stat.total_requests,
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
"cost": stat.total_cost,
"avg_response_time": stat.avg_response_time_ms / 1000.0 if stat.avg_response_time_ms else 0,
"unique_models": getattr(stat, 'unique_models', 0) or 0,
"unique_providers": getattr(stat, 'unique_providers', 0) or 0,
"fallback_count": stat.fallback_count or 0,
}
for stat in daily_stats
}
# 今日实时数据
today_stats = StatsAggregatorService.get_today_realtime_stats(db)
today_str = today.date().isoformat()
if today_stats["total_requests"] > 0:
# 今日平均响应时间需要单独查询
today_avg_rt = (
db.query(func.avg(Usage.response_time_ms))
.filter(Usage.created_at >= today, Usage.response_time_ms.isnot(None))
.scalar() or 0
)
# 今日 unique_models 和 unique_providers
today_unique_models = (
db.query(func.count(func.distinct(Usage.model)))
.filter(Usage.created_at >= today)
.scalar() or 0
)
today_unique_providers = (
db.query(func.count(func.distinct(Usage.provider)))
.filter(Usage.created_at >= today)
.scalar() or 0
)
# 今日 fallback_count
today_fallback_count = (
db.query(func.count())
.select_from(
db.query(RequestCandidate.request_id)
.filter(
RequestCandidate.created_at >= today,
RequestCandidate.status.in_(["success", "failed"]),
)
.group_by(RequestCandidate.request_id)
.having(func.count(RequestCandidate.id) > 1)
.subquery()
)
.scalar() or 0
)
stats_map[today_str] = {
"requests": today_stats["total_requests"],
"tokens": (today_stats["input_tokens"] + today_stats["output_tokens"] +
today_stats["cache_creation_tokens"] + today_stats["cache_read_tokens"]),
"cost": today_stats["total_cost"],
"avg_response_time": float(today_avg_rt) / 1000.0 if today_avg_rt else 0,
"unique_models": today_unique_models,
"unique_providers": today_unique_providers,
"fallback_count": today_fallback_count,
}
else:
# 普通用户:仍需实时查询(用户级预聚合可选)
query = db.query(Usage).filter(
and_(
Usage.user_id == user.id,
Usage.created_at >= start_date,
Usage.created_at <= end_date,
)
)
user_daily_stats = (
query.with_entities(
func.date(Usage.created_at).label("date"),
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("tokens"),
func.sum(Usage.total_cost_usd).label("cost"),
func.avg(Usage.response_time_ms).label("avg_response_time"),
)
.group_by(func.date(Usage.created_at))
.order_by(func.date(Usage.created_at).asc())
.all()
)
stats_map = {
stat.date.isoformat(): {
"requests": stat.requests or 0,
"tokens": int(stat.tokens or 0),
"cost": float(stat.cost or 0),
"avg_response_time": float(stat.avg_response_time or 0) / 1000.0 if stat.avg_response_time else 0,
}
for stat in user_daily_stats
}
# 构建完整日期序列
current_date = start_date.date()
formatted: List[dict] = []
while current_date <= end_date.date():
date_str = current_date.isoformat()
stat = stats_map.get(date_str)
if stat:
formatted.append({
"date": date_str,
"requests": stat["requests"],
"tokens": stat["tokens"],
"cost": stat["cost"],
"avg_response_time": stat["avg_response_time"],
"unique_models": stat.get("unique_models", 0),
"unique_providers": stat.get("unique_providers", 0),
"fallback_count": stat.get("fallback_count", 0),
})
else:
formatted.append({
"date": date_str,
"requests": 0,
"tokens": 0,
"cost": 0.0,
"avg_response_time": 0.0,
"unique_models": 0,
"unique_providers": 0,
"fallback_count": 0,
})
current_date += timedelta(days=1)
# ==================== 模型统计(仍需实时查询)====================
model_query = db.query(Usage)
if not is_admin:
model_query = model_query.filter(Usage.user_id == user.id)
model_query = model_query.filter(
and_(Usage.created_at >= start_date, Usage.created_at <= end_date)
)
model_stats = (
model_query.with_entities(
Usage.model,
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("tokens"),
func.sum(Usage.total_cost_usd).label("cost"),
func.avg(Usage.response_time_ms).label("avg_response_time"),
)
.group_by(Usage.model)
.order_by(func.sum(Usage.total_cost_usd).desc())
.all()
)
model_summary = [
{
"model": stat.model,
"requests": stat.requests or 0,
"tokens": int(stat.tokens or 0),
"cost": float(stat.cost or 0),
"avg_response_time": (
float(stat.avg_response_time or 0) / 1000.0 if stat.avg_response_time else 0
),
"cost_per_request": float(stat.cost or 0) / max(stat.requests or 1, 1),
"tokens_per_request": int(stat.tokens or 0) / max(stat.requests or 1, 1),
}
for stat in model_stats
]
daily_model_stats = (
model_query.with_entities(
func.date(Usage.created_at).label("date"),
Usage.model,
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("tokens"),
func.sum(Usage.total_cost_usd).label("cost"),
)
.group_by(func.date(Usage.created_at), Usage.model)
.order_by(func.date(Usage.created_at).desc(), func.sum(Usage.total_cost_usd).desc())
.all()
)
breakdown = {}
for stat in daily_model_stats:
date_str = stat.date.isoformat()
breakdown.setdefault(date_str, []).append(
{
"model": stat.model,
"requests": stat.requests or 0,
"tokens": int(stat.tokens or 0),
"cost": float(stat.cost or 0),
}
)
for item in formatted:
item["model_breakdown"] = breakdown.get(item["date"], [])
return {
"daily_stats": formatted,
"model_summary": model_summary,
"period": {
"start_date": start_date.date().isoformat(),
"end_date": end_date.date().isoformat(),
"days": self.days,
},
}

View File

@@ -0,0 +1,99 @@
"""
API Handlers - 请求处理器
按 API 格式组织的 Adapter 和 Handler
- Adapter: 请求验证、格式转换、错误处理
- Handler: 业务逻辑、调用 Provider、记录用量
支持的格式:
- claude: Claude Chat API (/v1/messages)
- claude_cli: Claude CLI 透传模式
- openai: OpenAI Chat API (/v1/chat/completions)
- openai_cli: OpenAI CLI 透传模式
注意Handler 基类和具体 Handler 使用延迟导入以避免循环依赖。
"""
# Adapter 基类(不会引起循环导入,可以直接导入)
from src.api.handlers.base import (
ChatAdapterBase,
CliAdapterBase,
)
__all__ = [
# Adapter 基类
"ChatAdapterBase",
"CliAdapterBase",
# Handler 基类(延迟导入)
"ChatHandlerBase",
"CliMessageHandlerBase",
"BaseMessageHandler",
"MessageHandlerProtocol",
"MessageTelemetry",
"StreamContext",
# Claude
"ClaudeChatAdapter",
"ClaudeTokenCountAdapter",
"build_claude_adapter",
"ClaudeChatHandler",
# Claude CLI
"ClaudeCliAdapter",
"ClaudeCliMessageHandler",
# OpenAI
"OpenAIChatAdapter",
"OpenAIChatHandler",
# OpenAI CLI
"OpenAICliAdapter",
"OpenAICliMessageHandler",
]
# 延迟导入映射表
_LAZY_IMPORTS = {
# Handler 基类
"ChatHandlerBase": ("src.api.handlers.base.chat_handler_base", "ChatHandlerBase"),
"CliMessageHandlerBase": (
"src.api.handlers.base.cli_handler_base",
"CliMessageHandlerBase",
),
"StreamContext": ("src.api.handlers.base.cli_handler_base", "StreamContext"),
"BaseMessageHandler": ("src.api.handlers.base.base_handler", "BaseMessageHandler"),
"MessageHandlerProtocol": (
"src.api.handlers.base.base_handler",
"MessageHandlerProtocol",
),
"MessageTelemetry": ("src.api.handlers.base.base_handler", "MessageTelemetry"),
# Claude
"ClaudeChatAdapter": ("src.api.handlers.claude.adapter", "ClaudeChatAdapter"),
"ClaudeTokenCountAdapter": (
"src.api.handlers.claude.adapter",
"ClaudeTokenCountAdapter",
),
"build_claude_adapter": ("src.api.handlers.claude.adapter", "build_claude_adapter"),
"ClaudeChatHandler": ("src.api.handlers.claude.handler", "ClaudeChatHandler"),
# Claude CLI
"ClaudeCliAdapter": ("src.api.handlers.claude_cli.adapter", "ClaudeCliAdapter"),
"ClaudeCliMessageHandler": (
"src.api.handlers.claude_cli.handler",
"ClaudeCliMessageHandler",
),
# OpenAI
"OpenAIChatAdapter": ("src.api.handlers.openai.adapter", "OpenAIChatAdapter"),
"OpenAIChatHandler": ("src.api.handlers.openai.handler", "OpenAIChatHandler"),
# OpenAI CLI
"OpenAICliAdapter": ("src.api.handlers.openai_cli.adapter", "OpenAICliAdapter"),
"OpenAICliMessageHandler": (
"src.api.handlers.openai_cli.handler",
"OpenAICliMessageHandler",
),
}
def __getattr__(name: str):
"""延迟导入以避免循环依赖"""
if name in _LAZY_IMPORTS:
module_path, attr_name = _LAZY_IMPORTS[name]
import importlib
module = importlib.import_module(module_path)
return getattr(module, attr_name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -0,0 +1,68 @@
"""
Handler 基类模块
提供 Adapter、Handler 的抽象基类,以及请求构建器和响应解析器。
注意Handler 基类ChatHandlerBase, CliMessageHandlerBase 等)不在这里导出,
因为它们依赖 services.usage.stream而后者又需要导入 response_parser
会形成循环导入。请直接从具体模块导入 Handler 基类。
"""
# Chat Adapter 基类(不会引起循环导入)
from src.api.handlers.base.chat_adapter_base import (
ChatAdapterBase,
get_adapter_class,
get_adapter_instance,
list_registered_formats,
register_adapter,
)
# CLI Adapter 基类
from src.api.handlers.base.cli_adapter_base import (
CliAdapterBase,
get_cli_adapter_class,
get_cli_adapter_instance,
list_registered_cli_formats,
register_cli_adapter,
)
# 请求构建器
from src.api.handlers.base.request_builder import (
SENSITIVE_HEADERS,
PassthroughRequestBuilder,
RequestBuilder,
build_passthrough_request,
)
# 响应解析器
from src.api.handlers.base.response_parser import (
ParsedChunk,
ParsedResponse,
ResponseParser,
StreamStats,
)
__all__ = [
# Chat Adapter
"ChatAdapterBase",
"register_adapter",
"get_adapter_class",
"get_adapter_instance",
"list_registered_formats",
# CLI Adapter
"CliAdapterBase",
"register_cli_adapter",
"get_cli_adapter_class",
"get_cli_adapter_instance",
"list_registered_cli_formats",
# 请求构建器
"RequestBuilder",
"PassthroughRequestBuilder",
"build_passthrough_request",
"SENSITIVE_HEADERS",
# 响应解析器
"ResponseParser",
"ParsedChunk",
"ParsedResponse",
"StreamStats",
]

View File

@@ -0,0 +1,363 @@
"""
基础消息处理器,封装通用的编排、转换、遥测逻辑。
接口约定:
- process_stream: 处理流式请求,返回 StreamingResponse
- process_sync: 处理非流式请求,返回 JSONResponse
签名规范(推荐):
async def process_stream(
self,
request: Any, # 解析后的请求模型
http_request: Request, # FastAPI Request 对象
original_headers: Dict[str, str], # 原始请求头
original_request_body: Dict[str, Any], # 原始请求体
query_params: Optional[Dict[str, str]] = None, # 查询参数
) -> StreamingResponse: ...
async def process_sync(
self,
request: Any,
http_request: Request,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> JSONResponse: ...
"""
from __future__ import annotations
import time
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from src.clients.redis_client import get_redis_client_sync
from src.core.api_format_metadata import resolve_api_format
from src.core.enums import APIFormat
from src.core.logger import logger
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
from src.services.provider.format import normalize_api_format
from src.services.system.audit import audit_service
from src.services.usage.service import UsageService
class MessageTelemetry:
"""
负责记录 Usage/Audit避免处理器里重复代码。
"""
def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
self.db = db
self.user = user
self.api_key = api_key
self.request_id = request_id
self.client_ip = client_ip
async def calculate_cost(
self,
provider: str,
model: str,
*,
input_tokens: int,
output_tokens: int,
cache_creation_tokens: int = 0,
cache_read_tokens: int = 0,
) -> float:
input_price, output_price = await UsageService.get_model_price_async(
self.db, provider, model
)
_, _, _, _, _, _, total_cost = UsageService.calculate_cost(
input_tokens,
output_tokens,
input_price,
output_price,
cache_creation_tokens,
cache_read_tokens,
*await UsageService.get_cache_prices_async(self.db, provider, model, input_price),
)
return total_cost
async def record_success(
self,
*,
provider: str,
model: str,
input_tokens: int,
output_tokens: int,
response_time_ms: int,
status_code: int,
request_body: Dict[str, Any],
request_headers: Dict[str, Any],
response_body: Any,
response_headers: Dict[str, Any],
cache_creation_tokens: int = 0,
cache_read_tokens: int = 0,
is_stream: bool = False,
provider_request_headers: Optional[Dict[str, Any]] = None,
# Provider 侧追踪信息(用于记录真实成本)
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
provider_api_key_id: Optional[str] = None,
api_format: Optional[str] = None,
# 模型映射信息
target_model: Optional[str] = None,
# Provider 响应元数据(如 Gemini 的 modelVersion
response_metadata: Optional[Dict[str, Any]] = None,
) -> float:
total_cost = await self.calculate_cost(
provider,
model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_tokens=cache_creation_tokens,
cache_read_tokens=cache_read_tokens,
)
await UsageService.record_usage(
db=self.db,
user=self.user,
api_key=self.api_key,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_tokens,
cache_read_input_tokens=cache_read_tokens,
request_type="chat",
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
status_code=status_code,
request_headers=request_headers,
request_body=request_body,
provider_request_headers=provider_request_headers or {},
response_headers=response_headers,
response_body=response_body,
request_id=self.request_id,
# Provider 侧追踪信息(用于记录真实成本)
provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id,
# 模型映射信息
target_model=target_model,
# Provider 响应元数据
metadata=response_metadata,
)
if self.user and self.api_key:
audit_service.log_api_request(
db=self.db,
user_id=self.user.id,
api_key_id=self.api_key.id,
request_id=self.request_id,
model=model,
provider=provider,
success=True,
ip_address=self.client_ip,
status_code=status_code,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost_usd=total_cost,
)
return total_cost
async def record_failure(
self,
*,
provider: str,
model: str,
response_time_ms: int,
status_code: int,
error_message: str,
request_body: Dict[str, Any],
request_headers: Dict[str, Any],
is_stream: bool,
api_format: Optional[str] = None,
provider_request_headers: Optional[Dict[str, Any]] = None,
# 预估 token 信息(来自 message_start 事件,用于中断请求的成本估算)
input_tokens: int = 0,
output_tokens: int = 0,
cache_creation_tokens: int = 0,
cache_read_tokens: int = 0,
response_body: Optional[Dict[str, Any]] = None,
# 模型映射信息
target_model: Optional[str] = None,
):
"""
记录失败请求
注意Provider 链路信息provider_id, endpoint_id, key_id不在此处记录
因为 RequestCandidate 表已经记录了完整的请求链路追踪信息。
Args:
input_tokens: 预估输入 tokens来自 message_start用于中断请求的成本估算
output_tokens: 预估输出 tokens来自已收到的内容
cache_creation_tokens: 缓存创建 tokens
cache_read_tokens: 缓存读取 tokens
response_body: 响应体(如果有部分响应)
target_model: 映射后的目标模型名(如果发生了映射)
"""
provider_name = provider or "unknown"
if provider_name == "unknown":
logger.warning(f"[Telemetry] Recording failure with unknown provider (request_id={self.request_id})")
await UsageService.record_usage(
db=self.db,
user=self.user,
api_key=self.api_key,
provider=provider_name,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_tokens,
cache_read_input_tokens=cache_read_tokens,
request_type="chat",
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=error_message,
request_headers=request_headers,
request_body=request_body,
provider_request_headers=provider_request_headers or {},
response_headers={},
response_body=response_body or {"error": error_message},
request_id=self.request_id,
# 模型映射信息
target_model=target_model,
)
@runtime_checkable
class MessageHandlerProtocol(Protocol):
"""
消息处理器协议 - 定义标准接口
ChatHandlerBase 使用完整签名(含 request, http_request
CliMessageHandlerBase 使用简化签名(仅 original_request_body, original_headers
"""
async def process_stream(
self,
request: Any,
http_request: Request,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> StreamingResponse:
"""处理流式请求"""
...
async def process_sync(
self,
request: Any,
http_request: Request,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> JSONResponse:
"""处理非流式请求"""
...
class BaseMessageHandler:
"""
消息处理器基类,所有具体格式的 handler 可以继承它。
子类需要实现:
- process_stream: 处理流式请求
- process_sync: 处理非流式请求
推荐使用 MessageHandlerProtocol 中定义的签名。
"""
# Adapter 检测器类型
AdapterDetectorType = Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
def __init__(
self,
*,
db: Session,
user,
api_key,
request_id: str,
client_ip: str,
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list[str]] = None,
adapter_detector: Optional[AdapterDetectorType] = None,
):
self.db = db
self.user = user
self.api_key = api_key
self.request_id = request_id
self.client_ip = client_ip
self.user_agent = user_agent
self.start_time = start_time
self.allowed_api_formats = allowed_api_formats or [APIFormat.CLAUDE.value]
self.primary_api_format = normalize_api_format(self.allowed_api_formats[0])
self.adapter_detector = adapter_detector
redis_client = get_redis_client_sync()
self.orchestrator = FallbackOrchestrator(db, redis_client)
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
def elapsed_ms(self) -> int:
return int((time.time() - self.start_time) * 1000)
def _resolve_capability_requirements(
self,
model_name: str,
request_headers: Optional[Dict[str, str]] = None,
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""
解析请求的能力需求
来源:
1. 用户模型级配置 (User.model_capability_settings)
2. 用户 API Key 强制配置 (ApiKey.force_capabilities)
3. 请求头 X-Require-Capability
4. Adapter 的 detect_capability_requirements如 Claude 的 anthropic-beta
Args:
model_name: 模型名称
request_headers: 请求头
request_body: 请求体(可选)
Returns:
能力需求字典
"""
from src.services.capability.resolver import CapabilityResolver
return CapabilityResolver.resolve_requirements(
user=self.user,
user_api_key=self.api_key,
model_name=model_name,
request_headers=request_headers,
request_body=request_body,
adapter_detector=self.adapter_detector,
)
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
if provider_type:
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
return self.primary_api_format
def build_provider_payload(
self,
original_body: Dict[str, Any],
*,
mapped_model: Optional[str] = None,
) -> Dict[str, Any]:
"""构建发送给 Provider 的请求体,替换 model 名称"""
payload = dict(original_body)
if mapped_model:
payload["model"] = mapped_model
return payload

View File

@@ -0,0 +1,724 @@
"""
Chat Adapter 通用基类
提供 Chat 格式(进行请求验证和标准化)的通用适配器逻辑:
- 请求解析和验证
- 审计日志记录
- 错误处理和响应格式化
- Handler 创建和调用
- 计费策略(支持不同 API 格式的差异化计费)
子类只需提供:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: 对应的 ChatHandlerBase 子类
- _validate_request_body(): 可选覆盖请求验证逻辑
- _build_audit_metadata(): 可选覆盖审计元数据构建
- compute_total_input_context(): 可选覆盖总输入上下文计算(用于阶梯计费判定)
"""
import time
import traceback
from abc import abstractmethod
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.exceptions import (
InvalidRequestException,
ModelNotSupportedException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
ProxyException,
QuotaExceededException,
UpstreamClientException,
)
from src.core.logger import logger
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
class ChatAdapterBase(ApiAdapter):
"""
Chat Adapter 通用基类
提供 Chat 格式的通用适配器逻辑,子类只需配置:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: ChatHandlerBase 子类
- name: 适配器名称
"""
# 子类必须覆盖
FORMAT_ID: str = "UNKNOWN"
HANDLER_CLASS: Type[ChatHandlerBase]
# 适配器配置
name: str = "chat.base"
mode = ApiMode.STANDARD
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
self.response_normalizer = None
# 可选启用响应规范化
self._init_response_normalizer()
def _init_response_normalizer(self):
"""初始化响应规范化器 - 子类可覆盖"""
try:
from src.services.provider.response_normalizer import ResponseNormalizer
self.response_normalizer = ResponseNormalizer()
except ImportError:
pass
async def handle(self, context: ApiRequestContext):
"""处理 Chat API 请求"""
http_request = context.request
user = context.user
api_key = context.api_key
db = context.db
request_id = context.request_id
quota_remaining_value = context.quota_remaining
start_time = context.start_time
client_ip = context.client_ip
user_agent = context.user_agent
original_headers = context.original_headers
query_params = context.query_params
original_request_body = context.ensure_json_body()
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
if context.path_params:
original_request_body = self._merge_path_params(
original_request_body, context.path_params
)
# 验证和解析请求
request_obj = self._validate_request_body(original_request_body, context.path_params)
if isinstance(request_obj, JSONResponse):
return request_obj
stream = getattr(request_obj, "stream", False)
model = getattr(request_obj, "model", "unknown")
# 添加审计元数据
audit_metadata = self._build_audit_metadata(original_request_body, request_obj)
context.add_audit_metadata(**audit_metadata)
# 格式化额度显示
quota_display = (
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
)
# 请求开始日志
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
try:
# 检查客户端连接
if await http_request.is_disconnected():
logger.warning("客户端连接断开")
raise HTTPException(status_code=499, detail="Client disconnected")
# 创建 Handler
handler = self._create_handler(
db=db,
user=user,
api_key=api_key,
request_id=request_id,
client_ip=client_ip,
user_agent=user_agent,
start_time=start_time,
)
# 处理请求
if stream:
return await handler.process_stream(
request=request_obj,
http_request=http_request,
original_headers=original_headers,
original_request_body=original_request_body,
query_params=query_params,
)
return await handler.process_sync(
request=request_obj,
http_request=http_request,
original_headers=original_headers,
original_request_body=original_request_body,
query_params=query_params,
)
except HTTPException:
raise
except (
ModelNotSupportedException,
QuotaExceededException,
InvalidRequestException,
) as e:
logger.info(f"客户端请求错误: {e.error_type}")
return self._error_response(
status_code=e.status_code,
error_type=(
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
),
message=e.message,
)
except (
ProviderAuthException,
ProviderRateLimitException,
ProviderNotAvailableException,
ProviderTimeoutException,
UpstreamClientException,
) as e:
return await self._handle_provider_exception(
e,
db=db,
user=user,
api_key=api_key,
model=model,
stream=stream,
start_time=start_time,
original_headers=original_headers,
original_request_body=original_request_body,
client_ip=client_ip,
request_id=request_id,
)
except Exception as e:
return await self._handle_unexpected_exception(
e,
db=db,
user=user,
api_key=api_key,
model=model,
stream=stream,
start_time=start_time,
original_headers=original_headers,
original_request_body=original_request_body,
client_ip=client_ip,
request_id=request_id,
)
def _create_handler(
self,
*,
db,
user,
api_key,
request_id: str,
client_ip: str,
user_agent: str,
start_time: float,
):
"""创建 Handler 实例 - 子类可覆盖"""
return self.HANDLER_CLASS(
db=db,
user=user,
api_key=api_key,
request_id=request_id,
client_ip=client_ip,
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=self.allowed_api_formats,
response_normalizer=self.response_normalizer,
enable_response_normalization=self.response_normalizer is not None,
adapter_detector=self.detect_capability_requirements,
)
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
) -> Dict[str, Any]:
"""
合并 URL 路径参数到请求体 - 子类可覆盖
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
Args:
original_request_body: 原始请求体字典
path_params: URL 路径参数字典
Returns:
合并后的请求体字典
"""
merged = original_request_body.copy()
for key, value in path_params.items():
if key not in merged:
merged[key] = value
return merged
@abstractmethod
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
"""
验证请求体 - 子类必须实现
Args:
original_request_body: 原始请求体字典
path_params: URL 路径参数(如 Gemini 的 stream 通过 URL 端点传入)
Returns:
验证后的请求对象,或 JSONResponse 错误响应
"""
pass
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
"""
提取消息数量 - 子类可覆盖
默认实现:从 messages 字段提取
"""
messages = payload.get("messages", [])
if hasattr(request_obj, "messages"):
messages = request_obj.messages
return len(messages) if isinstance(messages, list) else 0
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
"""
构建审计日志元数据 - 子类可覆盖
"""
model = getattr(request_obj, "model", payload.get("model", "unknown"))
stream = getattr(request_obj, "stream", payload.get("stream", False))
messages_count = self._extract_message_count(payload, request_obj)
return {
"action": f"{self.FORMAT_ID.lower()}_request",
"model": model,
"stream": bool(stream),
"max_tokens": getattr(request_obj, "max_tokens", payload.get("max_tokens")),
"messages_count": messages_count,
"temperature": getattr(request_obj, "temperature", payload.get("temperature")),
"top_p": getattr(request_obj, "top_p", payload.get("top_p")),
}
async def _handle_provider_exception(
self,
e: Exception,
*,
db,
user,
api_key,
model: str,
stream: bool,
start_time: float,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
client_ip: str,
request_id: str,
) -> JSONResponse:
"""处理 Provider 相关异常"""
logger.debug(f"Caught provider exception: {type(e).__name__}")
response_time = int((time.time() - start_time) * 1000)
# 使用 RequestResult.from_exception 创建统一的失败结果
# 关键api_format 从 FORMAT_ID 获取,确保始终有值
result = RequestResult.from_exception(
exception=e,
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
model=model,
response_time_ms=response_time,
is_stream=stream,
)
result.request_headers = original_headers
result.request_body = original_request_body
# 确定错误消息
if isinstance(e, ProviderAuthException):
error_message = (
f"提供商认证失败: {str(e)}"
if result.metadata.provider != "unknown"
else "服务端错误: 无可用提供商"
)
result.error_message = error_message
# 处理上游客户端错误(如图片处理失败)
if isinstance(e, UpstreamClientException):
# 返回 400 状态码和清晰的错误消息
result.status_code = e.status_code
result.error_message = e.message
# 使用 UsageRecorder 记录失败
recorder = UsageRecorder(
db=db,
user=user,
api_key=api_key,
client_ip=client_ip,
request_id=request_id,
)
await recorder.record_failure(result, original_headers, original_request_body)
# 根据异常类型确定错误类型
if isinstance(e, UpstreamClientException):
error_type = "invalid_request_error"
elif result.status_code == 503:
error_type = "internal_server_error"
else:
error_type = "rate_limit_exceeded"
return self._error_response(
status_code=result.status_code,
error_type=error_type,
message=result.error_message or str(e),
)
async def _handle_unexpected_exception(
self,
e: Exception,
*,
db,
user,
api_key,
model: str,
stream: bool,
start_time: float,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
client_ip: str,
request_id: str,
) -> JSONResponse:
"""处理未预期的异常"""
if isinstance(e, ProxyException):
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
else:
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
exception=e,
extra_data={
"exception_class": e.__class__.__name__,
"processing_stage": "request_processing",
"model": model,
"stream": stream,
"traceback_preview": str(traceback.format_exc())[:500],
},
)
response_time = int((time.time() - start_time) * 1000)
# 使用 RequestResult.from_exception 创建统一的失败结果
# 关键api_format 从 FORMAT_ID 获取,确保始终有值
result = RequestResult.from_exception(
exception=e,
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
model=model,
response_time_ms=response_time,
is_stream=stream,
)
# 对于未预期的异常,强制设置状态码为 500
result.status_code = 500
result.error_type = "internal_error"
result.request_headers = original_headers
result.request_body = original_request_body
try:
# 使用 UsageRecorder 记录失败
recorder = UsageRecorder(
db=db,
user=user,
api_key=api_key,
client_ip=client_ip,
request_id=request_id,
)
await recorder.record_failure(result, original_headers, original_request_body)
except Exception as record_error:
logger.error(f"记录失败请求时出错: {record_error}")
return self._error_response(
status_code=500,
error_type="internal_server_error",
message="处理请求时发生内部错误")
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
"""生成错误响应 - 子类可覆盖以自定义格式"""
return JSONResponse(
status_code=status_code,
content={
"error": {
"type": error_type,
"message": message,
}
},
)
# =========================================================================
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
# =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
计算总输入上下文(用于阶梯计费判定)
默认实现input_tokens + cache_read_input_tokens
子类可覆盖此方法实现不同的计算逻辑
Args:
input_tokens: 输入 token 数
cache_read_input_tokens: 缓存读取 token 数
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
Returns:
总输入上下文 token 数
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# 超过所有配置的 TTL使用最后一个
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int,
cache_read_input_tokens: int,
input_price_per_1m: float,
output_price_per_1m: float,
cache_creation_price_per_1m: Optional[float],
cache_read_price_per_1m: Optional[float],
price_per_request: Optional[float],
tiered_pricing: Optional[dict] = None,
cache_ttl_minutes: Optional[int] = None,
) -> Dict[str, Any]:
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
Args:
input_tokens: 输入 token 数
output_tokens: 输出 token 数
cache_creation_input_tokens: 缓存创建 token 数
cache_read_input_tokens: 缓存读取 token 数
input_price_per_1m: 输入价格(每 1M tokens
output_price_per_1m: 输出价格(每 1M tokens
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens
cache_read_price_per_1m: 缓存读取价格(每 1M tokens
price_per_request: 按次计费价格
tiered_pricing: 阶梯计费配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
包含各项成本的字典:
{
"input_cost": float,
"output_cost": float,
"cache_creation_cost": float,
"cache_read_cost": float,
"cache_cost": float,
"request_cost": float,
"total_cost": float,
"tier_index": Optional[int], # 命中的阶梯索引
}
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# 检查阶梯计费
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# 计算各项成本
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""
根据总输入 token 数确定价格阶梯
Args:
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
total_input_tokens: 总输入 token 数
Returns:
匹配的阶梯配置
"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
return tiers[-1] if tiers else None
# =========================================================================
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
# =========================================================================
_ADAPTER_REGISTRY: Dict[str, Type["ChatAdapterBase"]] = {}
_ADAPTERS_LOADED = False
def register_adapter(adapter_class: Type["ChatAdapterBase"]) -> Type["ChatAdapterBase"]:
"""
注册 Adapter 类到注册表
用法:
@register_adapter
class ClaudeChatAdapter(ChatAdapterBase):
FORMAT_ID = "CLAUDE"
...
Args:
adapter_class: Adapter 类
Returns:
注册的 Adapter 类(支持作为装饰器使用)
"""
format_id = adapter_class.FORMAT_ID
if format_id and format_id != "UNKNOWN":
_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
return adapter_class
def _ensure_adapters_loaded():
"""确保所有 Adapter 已被加载(触发注册)"""
global _ADAPTERS_LOADED
if _ADAPTERS_LOADED:
return
# 导入各个 Adapter 模块以触发 @register_adapter 装饰器
try:
from src.api.handlers.claude import adapter as _ # noqa: F401
except ImportError:
pass
try:
from src.api.handlers.openai import adapter as _ # noqa: F401
except ImportError:
pass
try:
from src.api.handlers.gemini import adapter as _ # noqa: F401
except ImportError:
pass
_ADAPTERS_LOADED = True
def get_adapter_class(api_format: str) -> Optional[Type["ChatAdapterBase"]]:
"""
根据 API format 获取 Adapter 类
Args:
api_format: API 格式标识(如 "CLAUDE", "OPENAI", "GEMINI"
Returns:
对应的 Adapter 类,如果未找到返回 None
"""
_ensure_adapters_loaded()
return _ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
def get_adapter_instance(api_format: str) -> Optional["ChatAdapterBase"]:
"""
根据 API format 获取 Adapter 实例
Args:
api_format: API 格式标识
Returns:
Adapter 实例,如果未找到返回 None
"""
adapter_class = get_adapter_class(api_format)
if adapter_class:
return adapter_class()
return None
def list_registered_formats() -> list[str]:
"""返回所有已注册的 API 格式"""
_ensure_adapters_loaded()
return list(_ADAPTER_REGISTRY.keys())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,648 @@
"""
CLI Adapter 通用基类
提供 CLI 格式(直接透传请求)的通用适配器逻辑:
- 请求解析和验证
- 审计日志记录
- 错误处理和响应格式化
- Handler 创建和调用
- 计费策略(支持不同 API 格式的差异化计费)
子类只需提供:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: 对应的 MessageHandler 类
- 可选覆盖 _extract_message_count() 自定义消息计数逻辑
- 可选覆盖 compute_total_input_context() 自定义总输入上下文计算
"""
import time
import traceback
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
from src.core.exceptions import (
InvalidRequestException,
ModelNotSupportedException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
ProxyException,
QuotaExceededException,
UpstreamClientException,
)
from src.core.logger import logger
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
class CliAdapterBase(ApiAdapter):
"""
CLI Adapter 通用基类
提供 CLI 格式的通用适配器逻辑,子类只需配置:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: MessageHandler 类
- name: 适配器名称
"""
# 子类必须覆盖
FORMAT_ID: str = "UNKNOWN"
HANDLER_CLASS: Type[CliMessageHandlerBase]
# 适配器配置
name: str = "cli.base"
mode = ApiMode.PROXY
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
async def handle(self, context: ApiRequestContext):
"""处理 CLI API 请求"""
http_request = context.request
user = context.user
api_key = context.api_key
db = context.db
request_id = context.request_id
quota_remaining_value = context.quota_remaining
start_time = context.start_time
client_ip = context.client_ip
user_agent = context.user_agent
original_headers = context.original_headers
query_params = context.query_params # 获取查询参数
original_request_body = context.ensure_json_body()
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
if context.path_params:
original_request_body = self._merge_path_params(
original_request_body, context.path_params
)
# 获取 stream优先从请求体其次从 path_params如 Gemini 通过 URL 端点区分)
stream = original_request_body.get("stream")
if stream is None and context.path_params:
stream = context.path_params.get("stream", False)
stream = bool(stream)
# 获取 model优先从请求体其次从 path_params如 Gemini 的 model 在 URL 路径中)
model = original_request_body.get("model")
if model is None and context.path_params:
model = context.path_params.get("model", "unknown")
model = model or "unknown"
# 提取请求元数据
audit_metadata = self._build_audit_metadata(original_request_body, context.path_params)
context.add_audit_metadata(**audit_metadata)
# 格式化额度显示
quota_display = (
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
)
# 请求开始日志
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
try:
# 检查客户端连接
if await http_request.is_disconnected():
logger.warning("客户端连接断开")
raise HTTPException(status_code=499, detail="Client disconnected")
# 创建 Handler
handler = self.HANDLER_CLASS(
db=db,
user=user,
api_key=api_key,
request_id=request_id,
client_ip=client_ip,
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=self.allowed_api_formats,
adapter_detector=self.detect_capability_requirements,
)
# 处理请求
if stream:
return await handler.process_stream(
original_request_body=original_request_body,
original_headers=original_headers,
query_params=query_params,
path_params=context.path_params,
)
return await handler.process_sync(
original_request_body=original_request_body,
original_headers=original_headers,
query_params=query_params,
path_params=context.path_params,
)
except HTTPException:
raise
except (
ModelNotSupportedException,
QuotaExceededException,
InvalidRequestException,
) as e:
logger.debug(f"客户端请求错误: {e.error_type}")
return self._error_response(
status_code=e.status_code,
error_type=(
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
),
message=e.message,
)
except (
ProviderAuthException,
ProviderRateLimitException,
ProviderNotAvailableException,
ProviderTimeoutException,
UpstreamClientException,
) as e:
return await self._handle_provider_exception(
e,
db=db,
user=user,
api_key=api_key,
model=model,
stream=stream,
start_time=start_time,
original_headers=original_headers,
original_request_body=original_request_body,
client_ip=client_ip,
request_id=request_id,
)
except Exception as e:
return await self._handle_unexpected_exception(
e,
db=db,
user=user,
api_key=api_key,
model=model,
stream=stream,
start_time=start_time,
original_headers=original_headers,
original_request_body=original_request_body,
client_ip=client_ip,
request_id=request_id,
)
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
) -> Dict[str, Any]:
"""
合并 URL 路径参数到请求体 - 子类可覆盖
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
Args:
original_request_body: 原始请求体字典
path_params: URL 路径参数字典
Returns:
合并后的请求体字典
"""
merged = original_request_body.copy()
for key, value in path_params.items():
if key not in merged:
merged[key] = value
return merged
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
"""
提取消息数量 - 子类可覆盖
默认实现:从 input 字段提取
"""
if "input" not in payload:
return 0
input_data = payload["input"]
if isinstance(input_data, list):
return len(input_data)
if isinstance(input_data, dict) and "messages" in input_data:
return len(input_data.get("messages", []))
return 0
def _build_audit_metadata(
self,
payload: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
构建审计日志元数据 - 子类可覆盖
Args:
payload: 请求体
path_params: URL 路径参数(用于获取 model 等)
"""
# 优先从请求体获取 model其次从 path_params
model = payload.get("model")
if model is None and path_params:
model = path_params.get("model", "unknown")
model = model or "unknown"
stream = payload.get("stream", False)
messages_count = self._extract_message_count(payload)
return {
"action": f"{self.FORMAT_ID.lower()}_request",
"model": model,
"stream": bool(stream),
"max_tokens": payload.get("max_tokens"),
"messages_count": messages_count,
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"tool_count": len(payload.get("tools") or []),
"instructions_present": bool(payload.get("instructions")),
}
async def _handle_provider_exception(
self,
e: Exception,
*,
db,
user,
api_key,
model: str,
stream: bool,
start_time: float,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
client_ip: str,
request_id: str,
) -> JSONResponse:
"""处理 Provider 相关异常"""
logger.debug(f"Caught provider exception: {type(e).__name__}")
response_time = int((time.time() - start_time) * 1000)
# 使用 RequestResult.from_exception 创建统一的失败结果
# 关键api_format 从 FORMAT_ID 获取,确保始终有值
result = RequestResult.from_exception(
exception=e,
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
model=model,
response_time_ms=response_time,
is_stream=stream,
)
result.request_headers = original_headers
result.request_body = original_request_body
# 确定错误消息
if isinstance(e, ProviderAuthException):
error_message = (
f"提供商认证失败: {str(e)}"
if result.metadata.provider != "unknown"
else "服务端错误: 无可用提供商"
)
result.error_message = error_message
# 处理上游客户端错误(如图片处理失败)
if isinstance(e, UpstreamClientException):
# 返回 400 状态码和清晰的错误消息
result.status_code = e.status_code
result.error_message = e.message
# 使用 UsageRecorder 记录失败
recorder = UsageRecorder(
db=db,
user=user,
api_key=api_key,
client_ip=client_ip,
request_id=request_id,
)
await recorder.record_failure(result, original_headers, original_request_body)
# 根据异常类型确定错误类型
if isinstance(e, UpstreamClientException):
error_type = "invalid_request_error"
elif result.status_code == 503:
error_type = "internal_server_error"
else:
error_type = "rate_limit_exceeded"
return self._error_response(
status_code=result.status_code,
error_type=error_type,
message=result.error_message or str(e),
)
async def _handle_unexpected_exception(
self,
e: Exception,
*,
db,
user,
api_key,
model: str,
stream: bool,
start_time: float,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
client_ip: str,
request_id: str,
) -> JSONResponse:
"""处理未预期的异常"""
if isinstance(e, ProxyException):
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
else:
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
exception=e,
extra_data={
"exception_class": e.__class__.__name__,
"processing_stage": "request_processing",
"model": model,
"stream": stream,
"traceback_preview": str(traceback.format_exc())[:500],
},
)
response_time = int((time.time() - start_time) * 1000)
# 使用 RequestResult.from_exception 创建统一的失败结果
# 关键api_format 从 FORMAT_ID 获取,确保始终有值
result = RequestResult.from_exception(
exception=e,
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
model=model,
response_time_ms=response_time,
is_stream=stream,
)
# 对于未预期的异常,强制设置状态码为 500
result.status_code = 500
result.error_type = "internal_error"
result.request_headers = original_headers
result.request_body = original_request_body
# 使用 UsageRecorder 记录失败
recorder = UsageRecorder(
db=db,
user=user,
api_key=api_key,
client_ip=client_ip,
request_id=request_id,
)
await recorder.record_failure(result, original_headers, original_request_body)
return self._error_response(
status_code=500,
error_type="internal_server_error",
message="处理请求时发生内部错误")
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
"""生成错误响应"""
return JSONResponse(
status_code=status_code,
content={
"error": {
"type": error_type,
"message": message,
}
},
)
# =========================================================================
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
# =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
计算总输入上下文(用于阶梯计费判定)
默认实现input_tokens + cache_read_input_tokens
子类可覆盖此方法实现不同的计算逻辑
Args:
input_tokens: 输入 token 数
cache_read_input_tokens: 缓存读取 token 数
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
Returns:
总输入上下文 token 数
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# 超过所有配置的 TTL使用最后一个
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int,
cache_read_input_tokens: int,
input_price_per_1m: float,
output_price_per_1m: float,
cache_creation_price_per_1m: Optional[float],
cache_read_price_per_1m: Optional[float],
price_per_request: Optional[float],
tiered_pricing: Optional[dict] = None,
cache_ttl_minutes: Optional[int] = None,
) -> Dict[str, Any]:
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
Args:
input_tokens: 输入 token 数
output_tokens: 输出 token 数
cache_creation_input_tokens: 缓存创建 token 数
cache_read_input_tokens: 缓存读取 token 数
input_price_per_1m: 输入价格(每 1M tokens
output_price_per_1m: 输出价格(每 1M tokens
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens
cache_read_price_per_1m: 缓存读取价格(每 1M tokens
price_per_request: 按次计费价格
tiered_pricing: 阶梯计费配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
包含各项成本的字典
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# 检查阶梯计费
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# 计算各项成本
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""根据总输入 token 数确定价格阶梯"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
return tiers[-1] if tiers else None
# =========================================================================
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
# =========================================================================
_CLI_ADAPTER_REGISTRY: Dict[str, Type["CliAdapterBase"]] = {}
_CLI_ADAPTERS_LOADED = False
def register_cli_adapter(adapter_class: Type["CliAdapterBase"]) -> Type["CliAdapterBase"]:
"""
注册 CLI Adapter 类到注册表
用法:
@register_cli_adapter
class ClaudeCliAdapter(CliAdapterBase):
FORMAT_ID = "CLAUDE_CLI"
...
"""
format_id = adapter_class.FORMAT_ID
if format_id and format_id != "UNKNOWN":
_CLI_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
return adapter_class
def _ensure_cli_adapters_loaded():
"""确保所有 CLI Adapter 已被加载(触发注册)"""
global _CLI_ADAPTERS_LOADED
if _CLI_ADAPTERS_LOADED:
return
# 导入各个 CLI Adapter 模块以触发 @register_cli_adapter 装饰器
try:
from src.api.handlers.claude_cli import adapter as _ # noqa: F401
except ImportError:
pass
try:
from src.api.handlers.openai_cli import adapter as _ # noqa: F401
except ImportError:
pass
try:
from src.api.handlers.gemini_cli import adapter as _ # noqa: F401
except ImportError:
pass
_CLI_ADAPTERS_LOADED = True
def get_cli_adapter_class(api_format: str) -> Optional[Type["CliAdapterBase"]]:
"""根据 API format 获取 CLI Adapter 类"""
_ensure_cli_adapters_loaded()
return _CLI_ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
def get_cli_adapter_instance(api_format: str) -> Optional["CliAdapterBase"]:
"""根据 API format 获取 CLI Adapter 实例"""
adapter_class = get_cli_adapter_class(api_format)
if adapter_class:
return adapter_class()
return None
def list_registered_cli_formats() -> list[str]:
"""返回所有已注册的 CLI API 格式"""
_ensure_cli_adapters_loaded()
return list(_CLI_ADAPTER_REGISTRY.keys())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,279 @@
"""
格式转换器注册表
自动管理不同 API 格式之间的转换器,支持:
- 请求转换:客户端格式 → Provider 格式
- 响应转换Provider 格式 → 客户端格式
使用方法:
1. 实现 Converter 类(需要有 convert_request 和/或 convert_response 方法)
2. 调用 registry.register() 注册转换器
3. 在 Handler 中调用 registry.convert_request/convert_response
示例:
from src.api.handlers.base.format_converter_registry import converter_registry
# 注册转换器
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
# 使用转换器
gemini_request = converter_registry.convert_request(claude_request, "CLAUDE", "GEMINI")
claude_response = converter_registry.convert_response(gemini_response, "GEMINI", "CLAUDE")
"""
from typing import Any, Dict, Optional, Protocol, Tuple
from src.core.logger import logger
class RequestConverter(Protocol):
"""请求转换器协议"""
def convert_request(self, request: Dict[str, Any]) -> Dict[str, Any]: ...
class ResponseConverter(Protocol):
"""响应转换器协议"""
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]: ...
class StreamChunkConverter(Protocol):
"""流式响应块转换器协议"""
def convert_stream_chunk(self, chunk: Dict[str, Any]) -> Dict[str, Any]: ...
class FormatConverterRegistry:
"""
格式转换器注册表
管理不同 API 格式之间的双向转换器
"""
def __init__(self):
# key: (source_format, target_format), value: converter instance
self._converters: Dict[Tuple[str, str], Any] = {}
def register(
self,
source_format: str,
target_format: str,
converter: Any,
) -> None:
"""
注册格式转换器
Args:
source_format: 源格式(如 "CLAUDE", "OPENAI", "GEMINI"
target_format: 目标格式
converter: 转换器实例(需要有 convert_request/convert_response 方法)
"""
key = (source_format.upper(), target_format.upper())
self._converters[key] = converter
logger.info(f"[ConverterRegistry] 注册转换器: {source_format} -> {target_format}")
def get_converter(
self,
source_format: str,
target_format: str,
) -> Optional[Any]:
"""
获取转换器
Args:
source_format: 源格式
target_format: 目标格式
Returns:
转换器实例,如果不存在返回 None
"""
key = (source_format.upper(), target_format.upper())
return self._converters.get(key)
def has_converter(
self,
source_format: str,
target_format: str,
) -> bool:
"""检查是否存在转换器"""
key = (source_format.upper(), target_format.upper())
return key in self._converters
def convert_request(
self,
request: Dict[str, Any],
source_format: str,
target_format: str,
) -> Dict[str, Any]:
"""
转换请求
Args:
request: 原始请求字典
source_format: 源格式(客户端格式)
target_format: 目标格式Provider 格式)
Returns:
转换后的请求字典,如果无需转换或没有转换器则返回原始请求
"""
# 同格式无需转换
if source_format.upper() == target_format.upper():
return request
converter = self.get_converter(source_format, target_format)
if converter is None:
logger.warning(f"[ConverterRegistry] 未找到请求转换器: {source_format} -> {target_format},返回原始请求")
return request
if not hasattr(converter, "convert_request"):
logger.warning(f"[ConverterRegistry] 转换器缺少 convert_request 方法: {source_format} -> {target_format}")
return request
try:
converted = converter.convert_request(request)
logger.debug(f"[ConverterRegistry] 请求转换成功: {source_format} -> {target_format}")
return converted
except Exception as e:
logger.error(f"[ConverterRegistry] 请求转换失败: {source_format} -> {target_format}: {e}")
return request
def convert_response(
self,
response: Dict[str, Any],
source_format: str,
target_format: str,
) -> Dict[str, Any]:
"""
转换响应
Args:
response: 原始响应字典
source_format: 源格式Provider 格式)
target_format: 目标格式(客户端格式)
Returns:
转换后的响应字典,如果无需转换或没有转换器则返回原始响应
"""
# 同格式无需转换
if source_format.upper() == target_format.upper():
return response
converter = self.get_converter(source_format, target_format)
if converter is None:
logger.warning(f"[ConverterRegistry] 未找到响应转换器: {source_format} -> {target_format},返回原始响应")
return response
if not hasattr(converter, "convert_response"):
logger.warning(f"[ConverterRegistry] 转换器缺少 convert_response 方法: {source_format} -> {target_format}")
return response
try:
converted = converter.convert_response(response)
logger.debug(f"[ConverterRegistry] 响应转换成功: {source_format} -> {target_format}")
return converted
except Exception as e:
logger.error(f"[ConverterRegistry] 响应转换失败: {source_format} -> {target_format}: {e}")
return response
def convert_stream_chunk(
self,
chunk: Dict[str, Any],
source_format: str,
target_format: str,
) -> Dict[str, Any]:
"""
转换流式响应块
Args:
chunk: 原始流式响应块
source_format: 源格式Provider 格式)
target_format: 目标格式(客户端格式)
Returns:
转换后的流式响应块
"""
# 同格式无需转换
if source_format.upper() == target_format.upper():
return chunk
converter = self.get_converter(source_format, target_format)
if converter is None:
return chunk
# 优先使用专门的流式转换方法
if hasattr(converter, "convert_stream_chunk"):
try:
return converter.convert_stream_chunk(chunk)
except Exception as e:
logger.error(f"[ConverterRegistry] 流式块转换失败: {source_format} -> {target_format}: {e}")
return chunk
# 降级到普通响应转换
if hasattr(converter, "convert_response"):
try:
return converter.convert_response(chunk)
except Exception:
return chunk
return chunk
def list_converters(self) -> list[Tuple[str, str]]:
"""列出所有已注册的转换器"""
return list(self._converters.keys())
# 全局单例
converter_registry = FormatConverterRegistry()
def register_all_converters():
"""
注册所有内置的格式转换器
在应用启动时调用此函数
"""
# Claude <-> OpenAI
try:
from src.api.handlers.claude.converter import OpenAIToClaudeConverter
from src.api.handlers.openai.converter import ClaudeToOpenAIConverter
converter_registry.register("OPENAI", "CLAUDE", OpenAIToClaudeConverter())
converter_registry.register("CLAUDE", "OPENAI", ClaudeToOpenAIConverter())
except ImportError as e:
logger.warning(f"[ConverterRegistry] 无法加载 Claude/OpenAI 转换器: {e}")
# Claude <-> Gemini
try:
from src.api.handlers.gemini.converter import (
ClaudeToGeminiConverter,
GeminiToClaudeConverter,
)
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
except ImportError as e:
logger.warning(f"[ConverterRegistry] 无法加载 Claude/Gemini 转换器: {e}")
# OpenAI <-> Gemini
try:
from src.api.handlers.gemini.converter import (
GeminiToOpenAIConverter,
OpenAIToGeminiConverter,
)
converter_registry.register("OPENAI", "GEMINI", OpenAIToGeminiConverter())
converter_registry.register("GEMINI", "OPENAI", GeminiToOpenAIConverter())
except ImportError as e:
logger.warning(f"[ConverterRegistry] 无法加载 OpenAI/Gemini 转换器: {e}")
logger.info(f"[ConverterRegistry] 已注册 {len(converter_registry.list_converters())} 个格式转换器")
__all__ = [
"FormatConverterRegistry",
"converter_registry",
"register_all_converters",
]

View File

@@ -0,0 +1,465 @@
"""
响应解析器工厂
直接根据格式 ID 创建对应的 ResponseParser 实现,
不再经过 Protocol 抽象层。
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.response_parser import (
ParsedChunk,
ParsedResponse,
ResponseParser,
StreamStats,
)
class OpenAIResponseParser(ResponseParser):
"""OpenAI 格式响应解析器"""
def __init__(self):
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
self._parser = OpenAIStreamParser()
self.name = "OPENAI"
self.api_format = "OPENAI"
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
if not line or not line.strip():
return None
if line.startswith("data: "):
data_str = line[6:]
else:
data_str = line
parsed = self._parser.parse_line(data_str)
if parsed is None:
return None
chunk = ParsedChunk(
raw_line=line,
event_type=None,
data=parsed,
)
# 提取文本增量
text_delta = self._parser.extract_text_delta(parsed)
if text_delta:
chunk.text_delta = text_delta
stats.collected_text += text_delta
# 检查是否结束
if self._parser.is_done_chunk(parsed):
chunk.is_done = True
stats.has_completion = True
stats.chunk_count += 1
stats.data_count += 1
return chunk
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
result = ParsedResponse(
raw_response=response,
status_code=status_code,
)
# 提取文本内容
choices = response.get("choices", [])
if choices:
message = choices[0].get("message", {})
content = message.get("content")
if content:
result.text_content = content
result.response_id = response.get("id")
# 提取 usage
usage = response.get("usage", {})
result.input_tokens = usage.get("prompt_tokens", 0)
result.output_tokens = usage.get("completion_tokens", 0)
# 检查错误
if "error" in response:
result.is_error = True
error = response.get("error", {})
if isinstance(error, dict):
result.error_type = error.get("type")
result.error_message = error.get("message")
else:
result.error_message = str(error)
return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
usage = response.get("usage", {})
return {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
}
def extract_text_content(self, response: Dict[str, Any]) -> str:
choices = response.get("choices", [])
if choices:
message = choices[0].get("message", {})
content = message.get("content")
if content:
return content
return ""
def is_error_response(self, response: Dict[str, Any]) -> bool:
return "error" in response
class OpenAICliResponseParser(OpenAIResponseParser):
"""OpenAI CLI 格式响应解析器"""
def __init__(self):
super().__init__()
self.name = "OPENAI_CLI"
self.api_format = "OPENAI_CLI"
class ClaudeResponseParser(ResponseParser):
"""Claude 格式响应解析器"""
def __init__(self):
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
self._parser = ClaudeStreamParser()
self.name = "CLAUDE"
self.api_format = "CLAUDE"
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
if not line or not line.strip():
return None
if line.startswith("data: "):
data_str = line[6:]
else:
data_str = line
parsed = self._parser.parse_line(data_str)
if parsed is None:
return None
chunk = ParsedChunk(
raw_line=line,
event_type=self._parser.get_event_type(parsed),
data=parsed,
)
# 提取文本增量
text_delta = self._parser.extract_text_delta(parsed)
if text_delta:
chunk.text_delta = text_delta
stats.collected_text += text_delta
# 检查是否结束
if self._parser.is_done_event(parsed):
chunk.is_done = True
stats.has_completion = True
# 提取 usage
usage = self._parser.extract_usage(parsed)
if usage:
chunk.input_tokens = usage.get("input_tokens", 0)
chunk.output_tokens = usage.get("output_tokens", 0)
chunk.cache_creation_tokens = usage.get("cache_creation_tokens", 0)
chunk.cache_read_tokens = usage.get("cache_read_tokens", 0)
stats.input_tokens = chunk.input_tokens
stats.output_tokens = chunk.output_tokens
stats.cache_creation_tokens = chunk.cache_creation_tokens
stats.cache_read_tokens = chunk.cache_read_tokens
# 检查错误
if self._parser.is_error_event(parsed):
chunk.is_error = True
error = parsed.get("error", {})
if isinstance(error, dict):
chunk.error_message = error.get("message", str(error))
else:
chunk.error_message = str(error)
stats.chunk_count += 1
stats.data_count += 1
return chunk
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
result = ParsedResponse(
raw_response=response,
status_code=status_code,
)
# 提取文本内容
content = response.get("content", [])
if isinstance(content, list):
text_parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
result.text_content = "".join(text_parts)
result.response_id = response.get("id")
# 提取 usage
usage = response.get("usage", {})
result.input_tokens = usage.get("input_tokens", 0)
result.output_tokens = usage.get("output_tokens", 0)
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
# 检查错误
if "error" in response or response.get("type") == "error":
result.is_error = True
error = response.get("error", {})
if isinstance(error, dict):
result.error_type = error.get("type")
result.error_message = error.get("message")
else:
result.error_message = str(error)
return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
usage = response.get("usage", {})
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}
def extract_text_content(self, response: Dict[str, Any]) -> str:
content = response.get("content", [])
if isinstance(content, list):
text_parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
return "".join(text_parts)
return ""
def is_error_response(self, response: Dict[str, Any]) -> bool:
return "error" in response or response.get("type") == "error"
class ClaudeCliResponseParser(ClaudeResponseParser):
"""Claude CLI 格式响应解析器"""
def __init__(self):
super().__init__()
self.name = "CLAUDE_CLI"
self.api_format = "CLAUDE_CLI"
class GeminiResponseParser(ResponseParser):
"""Gemini 格式响应解析器"""
def __init__(self):
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
self._parser = GeminiStreamParser()
self.name = "GEMINI"
self.api_format = "GEMINI"
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
"""
解析 Gemini SSE 行
Gemini 的流式响应使用 SSE 格式 (data: {...})
"""
if not line or not line.strip():
return None
# Gemini SSE 格式: data: {...}
if line.startswith("data: "):
data_str = line[6:]
else:
data_str = line
parsed = self._parser.parse_line(data_str)
if parsed is None:
return None
chunk = ParsedChunk(
raw_line=line,
event_type="content",
data=parsed,
)
# 提取文本增量
text_delta = self._parser.extract_text_delta(parsed)
if text_delta:
chunk.text_delta = text_delta
stats.collected_text += text_delta
# 检查是否结束
if self._parser.is_done_event(parsed):
chunk.is_done = True
stats.has_completion = True
# 提取 usage
usage = self._parser.extract_usage(parsed)
if usage:
chunk.input_tokens = usage.get("input_tokens", 0)
chunk.output_tokens = usage.get("output_tokens", 0)
chunk.cache_read_tokens = usage.get("cached_tokens", 0)
stats.input_tokens = chunk.input_tokens
stats.output_tokens = chunk.output_tokens
stats.cache_read_tokens = chunk.cache_read_tokens
# 检查错误
if self._parser.is_error_event(parsed):
chunk.is_error = True
error = parsed.get("error", {})
if isinstance(error, dict):
chunk.error_message = error.get("message", str(error))
else:
chunk.error_message = str(error)
stats.chunk_count += 1
stats.data_count += 1
return chunk
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
result = ParsedResponse(
raw_response=response,
status_code=status_code,
)
# 提取文本内容
candidates = response.get("candidates", [])
if candidates:
content = candidates[0].get("content", {})
parts = content.get("parts", [])
text_parts = []
for part in parts:
if "text" in part:
text_parts.append(part["text"])
result.text_content = "".join(text_parts)
result.response_id = response.get("modelVersion")
# 提取 usage调用 GeminiStreamParser.extract_usage 作为单一实现源)
usage = self._parser.extract_usage(response)
if usage:
result.input_tokens = usage.get("input_tokens", 0)
result.output_tokens = usage.get("output_tokens", 0)
result.cache_read_tokens = usage.get("cached_tokens", 0)
# 检查错误(使用增强的错误检测)
error_info = self._parser.extract_error_info(response)
if error_info:
result.is_error = True
result.error_type = error_info.get("status")
result.error_message = error_info.get("message")
return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
"""
从 Gemini 响应中提取 token 使用量
调用 GeminiStreamParser.extract_usage 作为单一实现源
"""
usage = self._parser.extract_usage(response)
if not usage:
return {
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
}
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": 0,
"cache_read_tokens": usage.get("cached_tokens", 0),
}
def extract_text_content(self, response: Dict[str, Any]) -> str:
candidates = response.get("candidates", [])
if candidates:
content = candidates[0].get("content", {})
parts = content.get("parts", [])
text_parts = []
for part in parts:
if "text" in part:
text_parts.append(part["text"])
return "".join(text_parts)
return ""
def is_error_response(self, response: Dict[str, Any]) -> bool:
"""
判断响应是否为错误响应
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
"""
return self._parser.is_error_event(response)
class GeminiCliResponseParser(GeminiResponseParser):
"""Gemini CLI 格式响应解析器"""
def __init__(self):
super().__init__()
self.name = "GEMINI_CLI"
self.api_format = "GEMINI_CLI"
# 解析器注册表
_PARSERS = {
"CLAUDE": ClaudeResponseParser,
"CLAUDE_CLI": ClaudeCliResponseParser,
"OPENAI": OpenAIResponseParser,
"OPENAI_CLI": OpenAICliResponseParser,
"GEMINI": GeminiResponseParser,
"GEMINI_CLI": GeminiCliResponseParser,
}
def get_parser_for_format(format_id: str) -> ResponseParser:
"""
根据格式 ID 获取 ResponseParser
Args:
format_id: 格式 ID"CLAUDE", "OPENAI", "CLAUDE_CLI", "OPENAI_CLI"
Returns:
ResponseParser 实例
Raises:
KeyError: 格式不存在
"""
format_id = format_id.upper()
if format_id not in _PARSERS:
raise KeyError(f"Unknown format: {format_id}")
return _PARSERS[format_id]()
def is_cli_format(format_id: str) -> bool:
"""判断是否为 CLI 格式"""
return format_id.upper().endswith("_CLI")
__all__ = [
"OpenAIResponseParser",
"OpenAICliResponseParser",
"ClaudeResponseParser",
"ClaudeCliResponseParser",
"GeminiResponseParser",
"GeminiCliResponseParser",
"get_parser_for_format",
"get_parser_from_protocol",
"is_cli_format",
]

View File

@@ -0,0 +1,207 @@
"""
请求构建器 - 透传模式
透传模式 (Passthrough): CLI 和 Chat 等场景,原样转发请求体和头部
- 清理敏感头部authorization, x-api-key, host, content-length 等
- 保留所有其他头部和请求体字段
- 适用于Claude CLI、OpenAI CLI、Chat API 等场景
使用方式:
builder = PassthroughRequestBuilder()
payload, headers = builder.build(original_body, original_headers, endpoint, key)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, FrozenSet, Optional, Tuple
from src.core.crypto import crypto_service
# ==============================================================================
# 统一的头部配置常量
# ==============================================================================
# 敏感头部 - 透传时需要清理(黑名单)
# 这些头部要么包含认证信息,要么由代理层重新生成
SENSITIVE_HEADERS: FrozenSet[str] = frozenset(
{
"authorization",
"x-api-key",
"x-goog-api-key", # Gemini API 认证头
"host",
"content-length",
"transfer-encoding",
"connection",
# 不透传 accept-encoding让 httpx 自己协商压缩格式
# 避免客户端请求 brotli/zstd 但 httpx 不支持解压的问题
"accept-encoding",
}
)
# ==============================================================================
# 请求构建器
# ==============================================================================
class RequestBuilder(ABC):
"""请求构建器抽象基类"""
@abstractmethod
def build_payload(
self,
original_body: Dict[str, Any],
*,
mapped_model: Optional[str] = None,
is_stream: bool = False,
) -> Dict[str, Any]:
"""构建请求体"""
pass
@abstractmethod
def build_headers(
self,
original_headers: Dict[str, str],
endpoint: Any,
key: Any,
*,
extra_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
"""构建请求头"""
pass
def build(
self,
original_body: Dict[str, Any],
original_headers: Dict[str, str],
endpoint: Any,
key: Any,
*,
mapped_model: Optional[str] = None,
is_stream: bool = False,
extra_headers: Optional[Dict[str, str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, str]]:
"""
构建完整的请求(请求体 + 请求头)
Returns:
Tuple[payload, headers]
"""
payload = self.build_payload(
original_body,
mapped_model=mapped_model,
is_stream=is_stream,
)
headers = self.build_headers(
original_headers,
endpoint,
key,
extra_headers=extra_headers,
)
return payload, headers
class PassthroughRequestBuilder(RequestBuilder):
"""
透传模式请求构建器
适用于 CLI 等场景,尽量保持请求原样:
- 请求体直接复制只修改必要字段model, stream
- 请求头:清理敏感头部(黑名单),透传其他所有头部
"""
def build_payload(
self,
original_body: Dict[str, Any],
*,
mapped_model: Optional[str] = None, # noqa: ARG002 - 由 apply_mapped_model 处理
is_stream: bool = False, # noqa: ARG002 - 保留原始值,不自动添加
) -> Dict[str, Any]:
"""
透传请求体 - 原样复制,不做任何修改
透传模式下:
- model: 由各 handler 的 apply_mapped_model 方法处理
- stream: 保留客户端原始值(不同 API 处理方式不同)
"""
return dict(original_body)
def build_headers(
self,
original_headers: Dict[str, str],
endpoint: Any,
key: Any,
*,
extra_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
"""
透传请求头 - 清理敏感头部(黑名单),透传其他所有头部
"""
from src.core.api_format_metadata import get_auth_config, resolve_api_format
headers: Dict[str, str] = {}
# 1. 根据 API 格式自动设置认证头
decrypted_key = crypto_service.decrypt(key.api_key)
api_format = getattr(endpoint, "api_format", None)
resolved_format = resolve_api_format(api_format)
auth_header, auth_type = (
get_auth_config(resolved_format) if resolved_format else ("Authorization", "bearer")
)
if auth_type == "bearer":
headers[auth_header] = f"Bearer {decrypted_key}"
else:
headers[auth_header] = decrypted_key
# 2. 添加 endpoint 配置的额外头部
if endpoint.headers:
headers.update(endpoint.headers)
# 3. 透传原始头部(排除敏感头部 - 黑名单模式)
if original_headers:
for name, value in original_headers.items():
lower_name = name.lower()
# 跳过敏感头部
if lower_name in SENSITIVE_HEADERS:
continue
headers[name] = value
# 4. 添加额外头部
if extra_headers:
headers.update(extra_headers)
# 5. 确保有 Content-Type
if "Content-Type" not in headers and "content-type" not in headers:
headers["Content-Type"] = "application/json"
return headers
# ==============================================================================
# 便捷函数
# ==============================================================================
def build_passthrough_request(
original_body: Dict[str, Any],
original_headers: Dict[str, str],
endpoint: Any,
key: Any,
) -> Tuple[Dict[str, Any], Dict[str, str]]:
"""
构建透传模式的请求
纯透传:原样复制请求体,只处理请求头(认证等)。
model mapping 和 stream 由调用方自行处理(不同 API 格式处理方式不同)。
"""
builder = PassthroughRequestBuilder()
return builder.build(
original_body,
original_headers,
endpoint,
key,
)

View File

@@ -0,0 +1,174 @@
"""
响应解析器基类 - 定义统一的响应解析接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class ParsedChunk:
"""解析后的流式数据块"""
# 原始数据
raw_line: str
event_type: Optional[str] = None
data: Optional[Dict[str, Any]] = None
# 提取的内容
text_delta: str = ""
is_done: bool = False
is_error: bool = False
error_message: Optional[str] = None
# 使用量信息(通常在最后一个 chunk 中)
input_tokens: int = 0
output_tokens: int = 0
cache_creation_tokens: int = 0
cache_read_tokens: int = 0
# 响应 ID
response_id: Optional[str] = None
@dataclass
class StreamStats:
"""流式响应统计信息"""
# 计数
chunk_count: int = 0
data_count: int = 0
# Token 使用量
input_tokens: int = 0
output_tokens: int = 0
cache_creation_tokens: int = 0
cache_read_tokens: int = 0
# 内容
collected_text: str = ""
response_id: Optional[str] = None
# 状态
has_completion: bool = False
status_code: int = 200
error_message: Optional[str] = None
# Provider 信息
provider_name: Optional[str] = None
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
# 响应头和完整响应
response_headers: Dict[str, str] = field(default_factory=dict)
final_response: Optional[Dict[str, Any]] = None
@dataclass
class ParsedResponse:
"""解析后的非流式响应"""
# 原始响应
raw_response: Dict[str, Any]
status_code: int
# 提取的内容
text_content: str = ""
response_id: Optional[str] = None
# 使用量
input_tokens: int = 0
output_tokens: int = 0
cache_creation_tokens: int = 0
cache_read_tokens: int = 0
# 错误信息
is_error: bool = False
error_type: Optional[str] = None
error_message: Optional[str] = None
class ResponseParser(ABC):
"""
响应解析器基类
定义统一的接口来解析不同 API 格式的响应。
子类需要实现具体的解析逻辑。
"""
# 解析器名称(用于日志)
name: str = "base"
# 支持的 API 格式
api_format: str = "UNKNOWN"
@abstractmethod
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
"""
解析单行 SSE 数据
Args:
line: SSE 行数据
stats: 流统计对象(会被更新)
Returns:
解析后的数据块,如果行不包含有效数据则返回 None
"""
pass
@abstractmethod
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
"""
解析非流式响应
Args:
response: 响应 JSON
status_code: HTTP 状态码
Returns:
解析后的响应对象
"""
pass
@abstractmethod
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
"""
从响应中提取 token 使用量
Args:
response: 响应 JSON
Returns:
包含 input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens 的字典
"""
pass
@abstractmethod
def extract_text_content(self, response: Dict[str, Any]) -> str:
"""
从响应中提取文本内容
Args:
response: 响应 JSON
Returns:
提取的文本内容
"""
pass
def is_error_response(self, response: Dict[str, Any]) -> bool:
"""
判断响应是否为错误响应
Args:
response: 响应 JSON
Returns:
是否为错误响应
"""
return "error" in response
def create_stats(self) -> StreamStats:
"""创建新的流统计对象"""
return StreamStats()

View File

@@ -0,0 +1,17 @@
"""
Claude Chat API 处理器
"""
from src.api.handlers.claude.adapter import (
ClaudeChatAdapter,
ClaudeTokenCountAdapter,
build_claude_adapter,
)
from src.api.handlers.claude.handler import ClaudeChatHandler
__all__ = [
"ClaudeChatAdapter",
"ClaudeTokenCountAdapter",
"build_claude_adapter",
"ClaudeChatHandler",
]

View File

@@ -0,0 +1,228 @@
"""
Claude Chat Adapter - 基于 ChatAdapterBase 的 Claude Chat API 适配器
处理 /v1/messages 端点的 Claude Chat 格式请求。
"""
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.core.optimization_utils import TokenCounter
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest
class ClaudeCapabilityDetector:
"""Claude API 能力检测器"""
@staticmethod
def detect_from_headers(
headers: Dict[str, str],
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""
从 Claude 请求头检测能力需求
检测规则:
- anthropic-beta: context-1m-xxx -> context_1m: True
Args:
headers: 请求头字典
request_body: 请求体Claude 不使用,保留用于接口统一)
"""
requirements: Dict[str, bool] = {}
# 检查 anthropic-beta 请求头(大小写不敏感)
beta_header = None
for key, value in headers.items():
if key.lower() == "anthropic-beta":
beta_header = value
break
if beta_header:
# 检查是否包含 context-1m 标识
if "context-1m" in beta_header.lower():
requirements["context_1m"] = True
return requirements
@register_adapter
class ClaudeChatAdapter(ChatAdapterBase):
"""
Claude Chat API 适配器
处理 Claude Chat 格式的请求(/v1/messages 端点,进行格式验证)。
"""
FORMAT_ID = "CLAUDE"
name = "claude.chat"
@property
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.claude.handler import ClaudeChatHandler
return ClaudeChatHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["CLAUDE"])
logger.info(f"[{self.name}] 初始化Chat模式适配器 | API格式: {self.allowed_api_formats}")
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (x-api-key)"""
return request.headers.get("x-api-key")
def detect_capability_requirements(
self,
headers: Dict[str, str],
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""检测 Claude 请求中隐含的能力需求"""
return ClaudeCapabilityDetector.detect_from_headers(headers)
# =========================================================================
# Claude 特定的计费逻辑
# =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
计算 Claude 的总输入上下文(用于阶梯计费判定)
Claude 的总输入 = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
"""
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
"""验证请求体"""
try:
if not isinstance(original_request_body, dict):
raise ValueError("Request body must be a JSON object")
required_fields = ["model", "messages", "max_tokens"]
missing_fields = [f for f in required_fields if f not in original_request_body]
if missing_fields:
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
request = ClaudeMessagesRequest.model_validate(
original_request_body,
strict=False,
)
except ValueError as e:
logger.error(f"请求体基本验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
request = ClaudeMessagesRequest.model_construct(
model=original_request_body.get("model"),
max_tokens=original_request_body.get("max_tokens"),
messages=original_request_body.get("messages", []),
stream=original_request_body.get("stream", False),
)
return request
def _build_audit_metadata(self, _payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
"""构建 Claude Chat 特定的审计元数据"""
role_counts: dict[str, int] = {}
for message in request_obj.messages:
role_counts[message.role] = role_counts.get(message.role, 0) + 1
return {
"action": "claude_messages",
"model": request_obj.model,
"stream": bool(request_obj.stream),
"max_tokens": request_obj.max_tokens,
"temperature": getattr(request_obj, "temperature", None),
"top_p": getattr(request_obj, "top_p", None),
"top_k": getattr(request_obj, "top_k", None),
"messages_count": len(request_obj.messages),
"message_roles": role_counts,
"stop_sequences": len(request_obj.stop_sequences or []),
"tools_count": len(request_obj.tools or []),
"system_present": bool(request_obj.system),
"metadata_present": bool(request_obj.metadata),
"thinking_enabled": bool(request_obj.thinking),
}
def build_claude_adapter(x_app_header: Optional[str]):
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
if x_app_header and x_app_header.lower() == "cli":
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
return ClaudeCliAdapter()
return ClaudeChatAdapter()
class ClaudeTokenCountAdapter(ApiAdapter):
"""计算 Claude 请求 Token 数的轻量适配器。"""
name = "claude.token_count"
mode = ApiMode.STANDARD
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (x-api-key 或 Authorization: Bearer)"""
# 优先检查 x-api-key
api_key = request.headers.get("x-api-key")
if api_key:
return api_key
# 降级到 Authorization: Bearer
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
return authorization.replace("Bearer ", "")
return None
async def handle(self, context: ApiRequestContext):
payload = context.ensure_json_body()
try:
request = ClaudeTokenCountRequest.model_validate(payload, strict=False)
except Exception as e:
logger.error(f"Token count payload invalid: {e}")
raise HTTPException(status_code=400, detail="Invalid token count payload") from e
token_counter = TokenCounter()
total_tokens = 0
if request.system:
if isinstance(request.system, str):
total_tokens += token_counter.count_tokens(request.system, request.model)
elif isinstance(request.system, list):
for block in request.system:
if hasattr(block, "text"):
total_tokens += token_counter.count_tokens(block.text, request.model)
messages_dict = [
msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in request.messages
]
total_tokens += token_counter.count_messages_tokens(messages_dict, request.model)
context.add_audit_metadata(
action="claude_token_count",
model=request.model,
messages_count=len(request.messages),
system_present=bool(request.system),
tools_count=len(request.tools or []),
thinking_enabled=bool(request.thinking),
input_tokens=total_tokens,
)
return JSONResponse({"input_tokens": total_tokens})
__all__ = [
"ClaudeChatAdapter",
"ClaudeTokenCountAdapter",
"build_claude_adapter",
]

View File

@@ -0,0 +1,490 @@
"""
OpenAI -> Claude 格式转换器
将 OpenAI Chat Completions API 格式转换为 Claude Messages API 格式。
"""
from __future__ import annotations
import json
import time
import uuid
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
class OpenAIToClaudeConverter:
"""
OpenAI -> Claude 格式转换器
支持:
- 请求转换OpenAI Chat Request -> Claude Request
- 响应转换OpenAI Chat Response -> Claude Response
- 流式转换OpenAI SSE -> Claude SSE
"""
# 内容类型常量
CONTENT_TYPE_TEXT = "text"
CONTENT_TYPE_IMAGE = "image"
CONTENT_TYPE_TOOL_USE = "tool_use"
CONTENT_TYPE_TOOL_RESULT = "tool_result"
# 停止原因映射OpenAI -> Claude
FINISH_REASON_MAP = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"function_call": "tool_use",
"content_filter": "end_turn",
}
def __init__(self, model_mapping: Optional[Dict[str, str]] = None):
"""
Args:
model_mapping: OpenAI 模型到 Claude 模型的映射
"""
self._model_mapping = model_mapping or {}
# ==================== 请求转换 ====================
def convert_request(self, request: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
"""
将 OpenAI 请求转换为 Claude 格式
Args:
request: OpenAI 请求Dict 或 Pydantic 模型)
Returns:
Claude 格式的请求字典
"""
if hasattr(request, "model_dump"):
data = request.model_dump(exclude_none=True)
else:
data = dict(request)
# 模型映射
model = data.get("model", "")
claude_model = self._model_mapping.get(model, model)
# 处理消息
system_content: Optional[str] = None
claude_messages: List[Dict[str, Any]] = []
for message in data.get("messages", []):
role = message.get("role")
# 提取 system 消息
if role == "system":
system_content = self._collapse_content(message.get("content"))
continue
# 转换其他消息
converted = self._convert_message(message)
if converted:
claude_messages.append(converted)
# 构建 Claude 请求
result: Dict[str, Any] = {
"model": claude_model,
"messages": claude_messages,
"max_tokens": data.get("max_tokens") or 4096,
}
# 可选参数
if data.get("temperature") is not None:
result["temperature"] = data["temperature"]
if data.get("top_p") is not None:
result["top_p"] = data["top_p"]
if data.get("stream"):
result["stream"] = data["stream"]
if data.get("stop"):
result["stop_sequences"] = self._convert_stop(data["stop"])
if system_content:
result["system"] = system_content
# 工具转换
tools = self._convert_tools(data.get("tools"))
if tools:
result["tools"] = tools
tool_choice = self._convert_tool_choice(data.get("tool_choice"))
if tool_choice:
result["tool_choice"] = tool_choice
return result
def _convert_message(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""转换单条消息"""
role = message.get("role")
if role == "user":
return self._convert_user_message(message)
if role == "assistant":
return self._convert_assistant_message(message)
if role == "tool":
return self._convert_tool_message(message)
return None
def _convert_user_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换用户消息"""
content = message.get("content")
if isinstance(content, str) or content is None:
return {"role": "user", "content": content or ""}
# 转换内容数组
claude_content: List[Dict[str, Any]] = []
for item in content:
item_type = item.get("type")
if item_type == "text":
claude_content.append(
{"type": self.CONTENT_TYPE_TEXT, "text": item.get("text", "")}
)
elif item_type == "image_url":
image_url = (item.get("image_url") or {}).get("url", "")
claude_content.append(self._convert_image_url(image_url))
return {"role": "user", "content": claude_content}
def _convert_assistant_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换助手消息"""
content_blocks: List[Dict[str, Any]] = []
# 处理文本内容
content = message.get("content")
if isinstance(content, str):
content_blocks.append({"type": self.CONTENT_TYPE_TEXT, "text": content})
elif isinstance(content, list):
for part in content:
if part.get("type") == "text":
content_blocks.append(
{"type": self.CONTENT_TYPE_TEXT, "text": part.get("text", "")}
)
# 处理工具调用
for tool_call in message.get("tool_calls") or []:
if tool_call.get("type") == "function":
function = tool_call.get("function", {})
arguments = function.get("arguments", "{}")
try:
input_data = json.loads(arguments)
except json.JSONDecodeError:
input_data = {"raw": arguments}
content_blocks.append(
{
"type": self.CONTENT_TYPE_TOOL_USE,
"id": tool_call.get("id", ""),
"name": function.get("name", ""),
"input": input_data,
}
)
# 简化单文本内容
if not content_blocks:
return {"role": "assistant", "content": ""}
if len(content_blocks) == 1 and content_blocks[0]["type"] == self.CONTENT_TYPE_TEXT:
return {"role": "assistant", "content": content_blocks[0]["text"]}
return {"role": "assistant", "content": content_blocks}
def _convert_tool_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换工具结果消息"""
tool_content = message.get("content", "")
# 尝试解析 JSON
parsed_content = tool_content
if isinstance(tool_content, str):
try:
parsed_content = json.loads(tool_content)
except json.JSONDecodeError:
pass
tool_block = {
"type": self.CONTENT_TYPE_TOOL_RESULT,
"tool_use_id": message.get("tool_call_id", ""),
"content": parsed_content,
}
return {"role": "user", "content": [tool_block]}
def _convert_tools(
self, tools: Optional[List[Dict[str, Any]]]
) -> Optional[List[Dict[str, Any]]]:
"""转换工具定义"""
if not tools:
return None
result: List[Dict[str, Any]] = []
for tool in tools:
if tool.get("type") != "function":
continue
function = tool.get("function", {})
result.append(
{
"name": function.get("name", ""),
"description": function.get("description"),
"input_schema": function.get("parameters") or {},
}
)
return result if result else None
def _convert_tool_choice(
self, tool_choice: Optional[Union[str, Dict[str, Any]]]
) -> Optional[Dict[str, Any]]:
"""转换工具选择"""
if tool_choice is None:
return None
if tool_choice == "none":
return {"type": "none"}
if tool_choice == "auto":
return {"type": "auto"}
if tool_choice == "required":
return {"type": "any"}
if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
function = tool_choice.get("function", {})
return {"type": "tool_use", "name": function.get("name", "")}
return {"type": "auto"}
def _convert_image_url(self, image_url: str) -> Dict[str, Any]:
"""转换图片 URL"""
if image_url.startswith("data:"):
header, _, data = image_url.partition(",")
media_type = "image/jpeg"
if ";" in header:
media_type = header.split(";")[0].split(":")[-1]
return {
"type": self.CONTENT_TYPE_IMAGE,
"source": {
"type": "base64",
"media_type": media_type,
"data": data,
},
}
return {"type": self.CONTENT_TYPE_TEXT, "text": f"[Image: {image_url}]"}
def _convert_stop(self, stop: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
"""转换停止序列"""
if stop is None:
return None
if isinstance(stop, str):
return [stop]
return stop
# ==================== 响应转换 ====================
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""
将 OpenAI 响应转换为 Claude 格式
Args:
response: OpenAI 响应字典
Returns:
Claude 格式的响应字典
"""
choices = response.get("choices", [])
if not choices:
return self._empty_claude_response(response)
choice = choices[0]
message = choice.get("message", {})
# 构建 content 数组
content: List[Dict[str, Any]] = []
# 处理文本
text_content = message.get("content")
if text_content:
content.append(
{
"type": self.CONTENT_TYPE_TEXT,
"text": text_content,
}
)
# 处理工具调用
for tool_call in message.get("tool_calls") or []:
if tool_call.get("type") == "function":
function = tool_call.get("function", {})
arguments = function.get("arguments", "{}")
try:
input_data = json.loads(arguments)
except json.JSONDecodeError:
input_data = {"raw": arguments}
content.append(
{
"type": self.CONTENT_TYPE_TOOL_USE,
"id": tool_call.get("id", ""),
"name": function.get("name", ""),
"input": input_data,
}
)
# 转换 finish_reason
finish_reason = choice.get("finish_reason")
stop_reason = self.FINISH_REASON_MAP.get(finish_reason, "end_turn")
# 转换 usage
usage = response.get("usage", {})
claude_usage = {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
}
return {
"id": f"msg_{response.get('id', uuid.uuid4().hex[:8])}",
"type": "message",
"role": "assistant",
"model": response.get("model", ""),
"content": content,
"stop_reason": stop_reason,
"stop_sequence": None,
"usage": claude_usage,
}
def _empty_claude_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""构建空的 Claude 响应"""
return {
"id": f"msg_{response.get('id', uuid.uuid4().hex[:8])}",
"type": "message",
"role": "assistant",
"model": response.get("model", ""),
"content": [],
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {"input_tokens": 0, "output_tokens": 0},
}
# ==================== 流式转换 ====================
def convert_stream_chunk(
self,
chunk: Dict[str, Any],
model: str = "",
message_id: Optional[str] = None,
message_started: bool = False,
) -> List[Dict[str, Any]]:
"""
将 OpenAI SSE chunk 转换为 Claude SSE 事件
Args:
chunk: OpenAI SSE chunk
model: 模型名称
message_id: 消息 ID
message_started: 是否已发送 message_start
Returns:
Claude SSE 事件列表
"""
events: List[Dict[str, Any]] = []
choices = chunk.get("choices", [])
if not choices:
return events
choice = choices[0]
delta = choice.get("delta", {})
finish_reason = choice.get("finish_reason")
# 处理角色(第一个 chunk
role = delta.get("role")
if role and not message_started:
msg_id = message_id or f"msg_{uuid.uuid4().hex[:8]}"
events.append(
{
"type": "message_start",
"message": {
"id": msg_id,
"type": "message",
"role": role,
"model": model,
"content": [],
"stop_reason": None,
"stop_sequence": None,
},
}
)
# 处理文本内容
content_delta = delta.get("content")
if isinstance(content_delta, str):
events.append(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": content_delta},
}
)
# 处理工具调用
tool_calls = delta.get("tool_calls", [])
for tool_call in tool_calls:
index = tool_call.get("index", 0)
# 工具调用开始
if "id" in tool_call:
function = tool_call.get("function", {})
events.append(
{
"type": "content_block_start",
"index": index,
"content_block": {
"type": self.CONTENT_TYPE_TOOL_USE,
"id": tool_call["id"],
"name": function.get("name", ""),
},
}
)
# 工具调用参数增量
function = tool_call.get("function", {})
if "arguments" in function:
events.append(
{
"type": "content_block_delta",
"index": index,
"delta": {
"type": "input_json_delta",
"partial_json": function.get("arguments", ""),
},
}
)
# 处理结束
if finish_reason:
stop_reason = self.FINISH_REASON_MAP.get(finish_reason, "end_turn")
events.append(
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason},
}
)
return events
# ==================== 工具方法 ====================
def _collapse_content(
self, content: Optional[Union[str, List[Dict[str, Any]]]]
) -> Optional[str]:
"""折叠内容为字符串"""
if isinstance(content, str):
return content
if not content:
return None
text_parts = [part.get("text", "") for part in content if part.get("type") == "text"]
return "\n\n".join(filter(None, text_parts)) or None
__all__ = ["OpenAIToClaudeConverter"]

View File

@@ -0,0 +1,150 @@
"""
Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
继承 ChatHandlerBase只需覆盖格式特定的方法。
代码量从原来的 ~1470 行减少到 ~120 行。
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
class ClaudeChatHandler(ChatHandlerBase):
"""
Claude Chat Handler - 处理 Claude Chat/CLI API 格式的请求
格式特点:
- 使用 input_tokens/output_tokens
- 支持 cache_creation_input_tokens/cache_read_input_tokens
- 请求格式ClaudeMessagesRequest
"""
FORMAT_ID = "CLAUDE"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> str:
"""
从请求中提取模型名 - Claude 格式实现
Claude API 的 model 在请求体顶级字段。
Args:
request_body: 请求体
path_params: URL 路径参数Claude 不使用)
Returns:
模型名
"""
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,
) -> Dict[str, Any]:
"""
将映射后的模型名应用到请求体
Claude API 的 model 在请求体顶级字段。
Args:
request_body: 原始请求体
mapped_model: 映射后的模型名
Returns:
更新了 model 字段的请求体
"""
result = dict(request_body)
result["model"] = mapped_model
return result
async def _convert_request(self, request):
"""
将请求转换为 Claude 格式
Args:
request: 原始请求对象
Returns:
ClaudeMessagesRequest 对象
"""
from src.api.handlers.claude.converter import OpenAIToClaudeConverter
from src.models.claude import ClaudeMessagesRequest
from src.models.openai import OpenAIRequest
# 如果已经是 Claude 格式,直接返回
if isinstance(request, ClaudeMessagesRequest):
return request
# 如果是 OpenAI 格式,转换为 Claude 格式
if isinstance(request, OpenAIRequest):
converter = OpenAIToClaudeConverter()
claude_dict = converter.convert_request(request.dict())
return ClaudeMessagesRequest(**claude_dict)
# 如果是字典,根据内容判断格式
if isinstance(request, dict):
if "messages" in request and len(request["messages"]) > 0:
first_msg = request["messages"][0]
if "role" in first_msg and "content" in first_msg:
# 可能是 OpenAI 格式
converter = OpenAIToClaudeConverter()
claude_dict = converter.convert_request(request)
return ClaudeMessagesRequest(**claude_dict)
# 否则假设已经是 Claude 格式
return ClaudeMessagesRequest(**request)
return request
def _extract_usage(self, response: Dict) -> Dict[str, int]:
"""
从 Claude 响应中提取 token 使用情况
Claude 格式使用:
- input_tokens / output_tokens
- cache_creation_input_tokens / cache_read_input_tokens
"""
usage = response.get("usage", {})
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
# 处理新的 cache_creation 格式
if "cache_creation" in usage:
cache_creation_data = usage.get("cache_creation", {})
if not cache_creation_input_tokens:
cache_creation_input_tokens = cache_creation_data.get(
"ephemeral_5m_input_tokens", 0
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cache_creation_input_tokens": cache_creation_input_tokens,
"cache_read_input_tokens": cache_read_input_tokens,
}
def _normalize_response(self, response: Dict) -> Dict:
"""
规范化 Claude 响应
Args:
response: 原始响应
Returns:
规范化后的响应
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_claude_response(
response_data=response,
request_id=self.request_id,
)
return response

View File

@@ -0,0 +1,241 @@
"""
Claude SSE 流解析器
解析 Claude Messages API 的 Server-Sent Events 流。
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
class ClaudeStreamParser:
"""
Claude SSE 流解析器
解析 Claude Messages API 的 SSE 事件流。
事件类型:
- message_start: 消息开始,包含初始 message 对象
- content_block_start: 内容块开始
- content_block_delta: 内容块增量(文本、工具输入等)
- content_block_stop: 内容块结束
- message_delta: 消息增量,包含 stop_reason 和最终 usage
- message_stop: 消息结束
- ping: 心跳事件
- error: 错误事件
"""
# Claude SSE 事件类型
EVENT_MESSAGE_START = "message_start"
EVENT_MESSAGE_STOP = "message_stop"
EVENT_MESSAGE_DELTA = "message_delta"
EVENT_CONTENT_BLOCK_START = "content_block_start"
EVENT_CONTENT_BLOCK_STOP = "content_block_stop"
EVENT_CONTENT_BLOCK_DELTA = "content_block_delta"
EVENT_PING = "ping"
EVENT_ERROR = "error"
# Delta 类型
DELTA_TEXT = "text_delta"
DELTA_INPUT_JSON = "input_json_delta"
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
"""
解析 SSE 数据块
Args:
chunk: 原始 SSE 数据bytes 或 str
Returns:
解析后的事件列表
"""
if isinstance(chunk, bytes):
text = chunk.decode("utf-8")
else:
text = chunk
events: List[Dict[str, Any]] = []
lines = text.strip().split("\n")
current_event_type: Optional[str] = None
for line in lines:
line = line.strip()
if not line:
continue
# 解析事件类型行
if line.startswith("event: "):
current_event_type = line[7:]
continue
# 解析数据行
if line.startswith("data: "):
data_str = line[6:]
# 处理 [DONE] 标记
if data_str == "[DONE]":
events.append({"type": "__done__", "raw": "[DONE]"})
continue
try:
data = json.loads(data_str)
# 如果数据中没有 type使用事件行的类型
if "type" not in data and current_event_type:
data["type"] = current_event_type
events.append(data)
except json.JSONDecodeError:
# 无法解析的数据,跳过
pass
current_event_type = None
return events
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
"""
解析单行 SSE 数据
Args:
line: SSE 数据行(已去除 "data: " 前缀)
Returns:
解析后的事件字典,如果无法解析返回 None
"""
if not line or line == "[DONE]":
return None
try:
return json.loads(line)
except json.JSONDecodeError:
return None
def is_done_event(self, event: Dict[str, Any]) -> bool:
"""
判断是否为结束事件
Args:
event: 事件字典
Returns:
True 如果是结束事件
"""
event_type = event.get("type")
return event_type in (self.EVENT_MESSAGE_STOP, "__done__")
def is_error_event(self, event: Dict[str, Any]) -> bool:
"""
判断是否为错误事件
Args:
event: 事件字典
Returns:
True 如果是错误事件
"""
return event.get("type") == self.EVENT_ERROR
def get_event_type(self, event: Dict[str, Any]) -> Optional[str]:
"""
获取事件类型
Args:
event: 事件字典
Returns:
事件类型字符串
"""
return event.get("type")
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
"""
从 content_block_delta 事件中提取文本增量
Args:
event: 事件字典
Returns:
文本增量,如果不是文本 delta 返回 None
"""
if event.get("type") != self.EVENT_CONTENT_BLOCK_DELTA:
return None
delta = event.get("delta", {})
if delta.get("type") == self.DELTA_TEXT:
return delta.get("text")
return None
def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
"""
从事件中提取 token 使用量
Args:
event: 事件字典
Returns:
使用量字典,如果没有使用量信息返回 None
"""
event_type = event.get("type")
# message_start 事件包含初始 usage
if event_type == self.EVENT_MESSAGE_START:
message = event.get("message", {})
usage = message.get("usage", {})
if usage:
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}
# message_delta 事件包含最终 usage
if event_type == self.EVENT_MESSAGE_DELTA:
usage = event.get("usage", {})
if usage:
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}
return None
def extract_message_id(self, event: Dict[str, Any]) -> Optional[str]:
"""
从 message_start 事件中提取消息 ID
Args:
event: 事件字典
Returns:
消息 ID如果不是 message_start 返回 None
"""
if event.get("type") != self.EVENT_MESSAGE_START:
return None
message = event.get("message", {})
return message.get("id")
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
"""
从 message_delta 事件中提取停止原因
Args:
event: 事件字典
Returns:
停止原因,如果没有返回 None
"""
if event.get("type") != self.EVENT_MESSAGE_DELTA:
return None
delta = event.get("delta", {})
return delta.get("stop_reason")
__all__ = ["ClaudeStreamParser"]

View File

@@ -0,0 +1,11 @@
"""
Claude CLI 透传处理器
"""
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
from src.api.handlers.claude_cli.handler import ClaudeCliMessageHandler
__all__ = [
"ClaudeCliAdapter",
"ClaudeCliMessageHandler",
]

View File

@@ -0,0 +1,103 @@
"""
Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。
"""
from typing import Any, Dict, Optional, Type
from fastapi import Request
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
from src.api.handlers.claude.adapter import ClaudeCapabilityDetector
@register_cli_adapter
class ClaudeCliAdapter(CliAdapterBase):
"""
Claude CLI API 适配器
处理 Claude CLI 格式的请求(/v1/messages 端点,使用 Bearer 认证)。
"""
FORMAT_ID = "CLAUDE_CLI"
name = "claude.cli"
@property
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.claude_cli.handler import ClaudeCliMessageHandler
return ClaudeCliMessageHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["CLAUDE_CLI"])
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
return authorization.replace("Bearer ", "")
return None
def detect_capability_requirements(
self,
headers: Dict[str, str],
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""检测 Claude CLI 请求中隐含的能力需求"""
return ClaudeCapabilityDetector.detect_from_headers(headers)
# =========================================================================
# Claude CLI 特定的计费逻辑
# =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
计算 Claude CLI 的总输入上下文(用于阶梯计费判定)
Claude 的总输入 = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
"""
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
"""Claude CLI 使用 messages 字段"""
messages = payload.get("messages", [])
return len(messages) if isinstance(messages, list) else 0
def _build_audit_metadata(
self,
payload: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> Dict[str, Any]:
"""Claude CLI 特定的审计元数据"""
model = payload.get("model", "unknown")
stream = payload.get("stream", False)
messages = payload.get("messages", [])
role_counts = {}
for msg in messages:
role = msg.get("role", "unknown")
role_counts[role] = role_counts.get(role, 0) + 1
return {
"action": "claude_cli_request",
"model": model,
"stream": bool(stream),
"max_tokens": payload.get("max_tokens"),
"messages_count": len(messages),
"message_roles": role_counts,
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"tool_count": len(payload.get("tools") or []),
"system_present": bool(payload.get("system")),
}
__all__ = ["ClaudeCliAdapter"]

View File

@@ -0,0 +1,195 @@
"""
Claude CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
继承 CliMessageHandlerBase只需覆盖格式特定的配置和事件处理逻辑。
验证新架构的有效性:代码量从数百行减少到 ~80 行。
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase,
StreamContext,
)
class ClaudeCliMessageHandler(CliMessageHandlerBase):
"""
Claude CLI Message Handler - 处理 Claude CLI API 格式
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
响应格式特点:
- 使用 content[] 数组
- 使用 text 类型
- 流式事件message_start, content_block_delta, message_delta, message_stop
- 支持 cache_creation_input_tokens 和 cache_read_input_tokens
模型字段:请求体顶级 model 字段
"""
FORMAT_ID = "CLAUDE_CLI"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> str:
"""
从请求中提取模型名 - Claude 格式实现
Claude API 的 model 在请求体顶级字段。
Args:
request_body: 请求体
path_params: URL 路径参数Claude 不使用)
Returns:
模型名
"""
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,
) -> Dict[str, Any]:
"""
Claude API 的 model 在请求体顶级
Args:
request_body: 原始请求体
mapped_model: 映射后的模型名
Returns:
更新了 model 字段的请求体
"""
result = dict(request_body)
result["model"] = mapped_model
return result
def _process_event_data(
self,
ctx: StreamContext,
event_type: str,
data: Dict[str, Any],
) -> None:
"""
处理 Claude CLI 格式的 SSE 事件
事件类型:
- message_start: 消息开始,包含初始 usage含缓存 tokens
- content_block_delta: 文本增量
- message_delta: 消息增量,包含最终 usage
- message_stop: 消息结束
"""
# 处理 message_start 事件
if event_type == "message_start":
message = data.get("message", {})
if message.get("id"):
ctx.response_id = message["id"]
# 提取初始 usage包含缓存 tokens
usage = message.get("usage", {})
if usage:
ctx.input_tokens = usage.get("input_tokens", 0)
# Claude 的缓存 tokens 使用不同的字段名
cache_read = usage.get("cache_read_input_tokens", 0)
if cache_read:
ctx.cached_tokens = cache_read
cache_creation = usage.get("cache_creation_input_tokens", 0)
if cache_creation:
ctx.cache_creation_tokens = cache_creation
# 处理文本增量
elif event_type == "content_block_delta":
delta = data.get("delta", {})
if delta.get("type") == "text_delta":
text = delta.get("text", "")
if text:
ctx.collected_text += text
# 处理消息增量(包含最终 usage
elif event_type == "message_delta":
usage = data.get("usage", {})
if usage:
if "input_tokens" in usage:
ctx.input_tokens = usage["input_tokens"]
if "output_tokens" in usage:
ctx.output_tokens = usage["output_tokens"]
# 更新缓存 tokens
if "cache_read_input_tokens" in usage:
ctx.cached_tokens = usage["cache_read_input_tokens"]
if "cache_creation_input_tokens" in usage:
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
# 检查是否结束
delta = data.get("delta", {})
if delta.get("stop_reason"):
ctx.has_completion = True
ctx.final_response = data
# 处理消息结束
elif event_type == "message_stop":
ctx.has_completion = True
def _extract_response_metadata(
self,
response: Dict[str, Any],
) -> Dict[str, Any]:
"""
从 Claude 响应中提取元数据
提取 model、stop_reason 等字段作为元数据。
Args:
response: Claude API 响应
Returns:
提取的元数据字典
"""
metadata: Dict[str, Any] = {}
# 提取模型名称(实际使用的模型)
if "model" in response:
metadata["model"] = response["model"]
# 提取停止原因
if "stop_reason" in response:
metadata["stop_reason"] = response["stop_reason"]
# 提取消息 ID
if "id" in response:
metadata["message_id"] = response["id"]
# 提取消息类型
if "type" in response:
metadata["type"] = response["type"]
return metadata
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
"""
从流上下文中提取最终元数据
在流传输完成后调用,从收集的事件中提取元数据。
Args:
ctx: 流上下文
"""
# 从 response_id 提取消息 ID
if ctx.response_id:
ctx.response_metadata["message_id"] = ctx.response_id
# 从 final_response 提取停止原因message_delta 事件中的 delta.stop_reason
if ctx.final_response:
delta = ctx.final_response.get("delta", {})
if "stop_reason" in delta:
ctx.response_metadata["stop_reason"] = delta["stop_reason"]
# 记录模型名称
if ctx.model:
ctx.response_metadata["model"] = ctx.model

View File

@@ -0,0 +1,26 @@
"""
Gemini API Handler 模块
提供 Gemini API 格式的请求处理
"""
from src.api.handlers.gemini.adapter import GeminiChatAdapter, build_gemini_adapter
from src.api.handlers.gemini.converter import (
ClaudeToGeminiConverter,
GeminiToClaudeConverter,
GeminiToOpenAIConverter,
OpenAIToGeminiConverter,
)
from src.api.handlers.gemini.handler import GeminiChatHandler
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
__all__ = [
"GeminiChatAdapter",
"GeminiChatHandler",
"GeminiStreamParser",
"ClaudeToGeminiConverter",
"GeminiToClaudeConverter",
"OpenAIToGeminiConverter",
"GeminiToOpenAIConverter",
"build_gemini_adapter",
]

View File

@@ -0,0 +1,170 @@
"""
Gemini Chat Adapter
处理 Gemini API 格式的请求适配
"""
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.models.gemini import GeminiRequest
@register_adapter
class GeminiChatAdapter(ChatAdapterBase):
"""
Gemini Chat API 适配器
处理 Gemini Chat 格式的请求
端点: /v1beta/models/{model}:generateContent
"""
FORMAT_ID = "GEMINI"
name = "gemini.chat"
@property
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.gemini.handler import GeminiChatHandler
return GeminiChatHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["GEMINI"])
logger.info(f"[{self.name}] 初始化 Gemini Chat 适配器 | API格式: {self.allowed_api_formats}")
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (x-goog-api-key)"""
return request.headers.get("x-goog-api-key")
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any] # noqa: ARG002
) -> Dict[str, Any]:
"""
合并 URL 路径参数到请求体 - Gemini 特化版本
Gemini API 特点:
- model 不合并到请求体(通过 extract_model_from_request 从 path_params 获取)
- stream 不合并到请求体Gemini API 通过 URL 端点区分流式/非流式)
Handler 层的 extract_model_from_request 会从 path_params 获取 model
prepare_provider_request_body 会确保发送给 Gemini API 的请求体不含 model。
Args:
original_request_body: 原始请求体字典
path_params: URL 路径参数字典(不使用)
Returns:
原始请求体(不合并任何 path_params
"""
return original_request_body.copy()
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
"""验证请求体"""
path_params = path_params or {}
is_stream = path_params.get("stream", False)
model = path_params.get("model", "unknown")
try:
if not isinstance(original_request_body, dict):
raise ValueError("Request body must be a JSON object")
# Gemini 必需字段: contents
if "contents" not in original_request_body:
raise ValueError("Missing required field: contents")
request = GeminiRequest.model_validate(
original_request_body,
strict=False,
)
except ValueError as e:
logger.error(f"请求体基本验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
request = GeminiRequest.model_construct(
contents=original_request_body.get("contents", []),
)
# 设置 model从 path_params 获取,用于日志和审计)
request.model = model
# 设置 stream 属性(用于 ChatAdapterBase 判断流式模式)
request.stream = is_stream
return request
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
"""提取消息数量"""
contents = payload.get("contents", [])
if hasattr(request_obj, "contents"):
contents = request_obj.contents
return len(contents) if isinstance(contents, list) else 0
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
"""构建 Gemini Chat 特定的审计元数据"""
role_counts: dict[str, int] = {}
contents = getattr(request_obj, "contents", []) or []
for content in contents:
role = getattr(content, "role", None) or content.get("role", "unknown")
role_counts[role] = role_counts.get(role, 0) + 1
generation_config = getattr(request_obj, "generation_config", None) or {}
if hasattr(generation_config, "dict"):
generation_config = generation_config.dict()
elif not isinstance(generation_config, dict):
generation_config = {}
# 判断流式模式
stream = getattr(request_obj, "stream", False)
return {
"action": "gemini_generate_content",
"model": getattr(request_obj, "model", payload.get("model", "unknown")),
"stream": bool(stream),
"max_output_tokens": generation_config.get("max_output_tokens"),
"temperature": generation_config.get("temperature"),
"top_p": generation_config.get("top_p"),
"top_k": generation_config.get("top_k"),
"contents_count": len(contents),
"content_roles": role_counts,
"tools_count": len(getattr(request_obj, "tools", None) or []),
"system_instruction_present": bool(getattr(request_obj, "system_instruction", None)),
"safety_settings_count": len(getattr(request_obj, "safety_settings", None) or []),
}
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
"""生成 Gemini 格式的错误响应"""
# Gemini 错误响应格式
return JSONResponse(
status_code=status_code,
content={
"error": {
"code": status_code,
"message": message,
"status": error_type.upper(),
}
},
)
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
"""
根据请求头构建适当的 Gemini 适配器
Args:
x_app_header: X-App 请求头值
Returns:
GeminiChatAdapter 实例
"""
# 目前只有一种 Gemini 适配器
# 未来可以根据 x_app_header 返回不同的适配器(如 CLI 模式)
return GeminiChatAdapter()
__all__ = ["GeminiChatAdapter", "build_gemini_adapter"]

View File

@@ -0,0 +1,544 @@
"""
Gemini 格式转换器
提供 Gemini 与其他 API 格式Claude、OpenAI之间的转换
"""
from typing import Any, Dict, List, Optional
class ClaudeToGeminiConverter:
"""
Claude -> Gemini 请求转换器
将 Claude Messages API 格式转换为 Gemini generateContent 格式
"""
def convert_request(self, claude_request: Dict[str, Any]) -> Dict[str, Any]:
"""
将 Claude 请求转换为 Gemini 请求
Args:
claude_request: Claude 格式的请求字典
Returns:
Gemini 格式的请求字典
"""
gemini_request: Dict[str, Any] = {
"contents": self._convert_messages(claude_request.get("messages", [])),
}
# 转换 system prompt
system = claude_request.get("system")
if system:
gemini_request["system_instruction"] = self._convert_system(system)
# 转换生成配置
generation_config = self._build_generation_config(claude_request)
if generation_config:
gemini_request["generation_config"] = generation_config
# 转换工具
tools = claude_request.get("tools")
if tools:
gemini_request["tools"] = self._convert_tools(tools)
return gemini_request
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""转换消息列表"""
contents = []
for msg in messages:
role = msg.get("role", "user")
# Gemini 使用 "model" 而不是 "assistant"
gemini_role = "model" if role == "assistant" else "user"
content = msg.get("content", "")
parts = self._convert_content_to_parts(content)
contents.append(
{
"role": gemini_role,
"parts": parts,
}
)
return contents
def _convert_content_to_parts(self, content: Any) -> List[Dict[str, Any]]:
"""将 Claude 内容转换为 Gemini parts"""
if isinstance(content, str):
return [{"text": content}]
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, str):
parts.append({"text": block})
elif isinstance(block, dict):
block_type = block.get("type")
if block_type == "text":
parts.append({"text": block.get("text", "")})
elif block_type == "image":
# 转换图片
source = block.get("source", {})
if source.get("type") == "base64":
parts.append(
{
"inline_data": {
"mime_type": source.get("media_type", "image/png"),
"data": source.get("data", ""),
}
}
)
elif block_type == "tool_use":
# 转换工具调用
parts.append(
{
"function_call": {
"name": block.get("name", ""),
"args": block.get("input", {}),
}
}
)
elif block_type == "tool_result":
# 转换工具结果
parts.append(
{
"function_response": {
"name": block.get("tool_use_id", ""),
"response": {"result": block.get("content", "")},
}
}
)
return parts
return [{"text": str(content)}]
def _convert_system(self, system: Any) -> Dict[str, Any]:
"""转换 system prompt"""
if isinstance(system, str):
return {"parts": [{"text": system}]}
if isinstance(system, list):
parts = []
for item in system:
if isinstance(item, str):
parts.append({"text": item})
elif isinstance(item, dict) and item.get("type") == "text":
parts.append({"text": item.get("text", "")})
return {"parts": parts}
return {"parts": [{"text": str(system)}]}
def _build_generation_config(self, claude_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建生成配置"""
config: Dict[str, Any] = {}
if "max_tokens" in claude_request:
config["max_output_tokens"] = claude_request["max_tokens"]
if "temperature" in claude_request:
config["temperature"] = claude_request["temperature"]
if "top_p" in claude_request:
config["top_p"] = claude_request["top_p"]
if "top_k" in claude_request:
config["top_k"] = claude_request["top_k"]
if "stop_sequences" in claude_request:
config["stop_sequences"] = claude_request["stop_sequences"]
return config if config else None
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""转换工具定义"""
function_declarations = []
for tool in tools:
func_decl = {
"name": tool.get("name", ""),
}
if "description" in tool:
func_decl["description"] = tool["description"]
if "input_schema" in tool:
func_decl["parameters"] = tool["input_schema"]
function_declarations.append(func_decl)
return [{"function_declarations": function_declarations}]
class GeminiToClaudeConverter:
"""
Gemini -> Claude 响应转换器
将 Gemini generateContent 响应转换为 Claude Messages API 格式
"""
def convert_response(self, gemini_response: Dict[str, Any]) -> Dict[str, Any]:
"""
将 Gemini 响应转换为 Claude 响应
Args:
gemini_response: Gemini 格式的响应字典
Returns:
Claude 格式的响应字典
"""
candidates = gemini_response.get("candidates", [])
if not candidates:
return self._create_empty_response()
candidate = candidates[0]
content = candidate.get("content", {})
parts = content.get("parts", [])
# 转换内容块
claude_content = self._convert_parts_to_content(parts)
# 转换使用量
usage = self._convert_usage(gemini_response.get("usageMetadata", {}))
# 转换停止原因
stop_reason = self._convert_finish_reason(candidate.get("finishReason"))
return {
"id": f"msg_{gemini_response.get('modelVersion', 'gemini')}",
"type": "message",
"role": "assistant",
"content": claude_content,
"model": gemini_response.get("modelVersion", "gemini"),
"stop_reason": stop_reason,
"stop_sequence": None,
"usage": usage,
}
def _convert_parts_to_content(self, parts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将 Gemini parts 转换为 Claude content blocks"""
content = []
for part in parts:
if "text" in part:
content.append(
{
"type": "text",
"text": part["text"],
}
)
elif "functionCall" in part:
func_call = part["functionCall"]
content.append(
{
"type": "tool_use",
"id": f"toolu_{func_call.get('name', '')}",
"name": func_call.get("name", ""),
"input": func_call.get("args", {}),
}
)
return content
def _convert_usage(self, usage_metadata: Dict[str, Any]) -> Dict[str, int]:
"""转换使用量信息"""
return {
"input_tokens": usage_metadata.get("promptTokenCount", 0),
"output_tokens": usage_metadata.get("candidatesTokenCount", 0),
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": usage_metadata.get("cachedContentTokenCount", 0),
}
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
"""转换停止原因"""
mapping = {
"STOP": "end_turn",
"MAX_TOKENS": "max_tokens",
"SAFETY": "content_filtered",
"RECITATION": "content_filtered",
"OTHER": "stop_sequence",
}
return mapping.get(finish_reason, "end_turn")
def _create_empty_response(self) -> Dict[str, Any]:
"""创建空响应"""
return {
"id": "msg_empty",
"type": "message",
"role": "assistant",
"content": [],
"model": "gemini",
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {
"input_tokens": 0,
"output_tokens": 0,
},
}
class OpenAIToGeminiConverter:
"""
OpenAI -> Gemini 请求转换器
将 OpenAI Chat Completions API 格式转换为 Gemini generateContent 格式
"""
def convert_request(self, openai_request: Dict[str, Any]) -> Dict[str, Any]:
"""
将 OpenAI 请求转换为 Gemini 请求
Args:
openai_request: OpenAI 格式的请求字典
Returns:
Gemini 格式的请求字典
"""
messages = openai_request.get("messages", [])
# 分离 system 消息和其他消息
system_messages = []
other_messages = []
for msg in messages:
if msg.get("role") == "system":
system_messages.append(msg)
else:
other_messages.append(msg)
gemini_request: Dict[str, Any] = {
"contents": self._convert_messages(other_messages),
}
# 转换 system messages
if system_messages:
system_text = "\n".join(msg.get("content", "") for msg in system_messages)
gemini_request["system_instruction"] = {"parts": [{"text": system_text}]}
# 转换生成配置
generation_config = self._build_generation_config(openai_request)
if generation_config:
gemini_request["generation_config"] = generation_config
# 转换工具
tools = openai_request.get("tools")
if tools:
gemini_request["tools"] = self._convert_tools(tools)
return gemini_request
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""转换消息列表"""
contents = []
for msg in messages:
role = msg.get("role", "user")
gemini_role = "model" if role == "assistant" else "user"
content = msg.get("content", "")
parts = self._convert_content_to_parts(content)
# 处理工具调用
tool_calls = msg.get("tool_calls", [])
for tc in tool_calls:
if tc.get("type") == "function":
func = tc.get("function", {})
import json
try:
args = json.loads(func.get("arguments", "{}"))
except json.JSONDecodeError:
args = {}
parts.append(
{
"function_call": {
"name": func.get("name", ""),
"args": args,
}
}
)
if parts:
contents.append(
{
"role": gemini_role,
"parts": parts,
}
)
return contents
def _convert_content_to_parts(self, content: Any) -> List[Dict[str, Any]]:
"""将 OpenAI 内容转换为 Gemini parts"""
if content is None:
return []
if isinstance(content, str):
return [{"text": content}]
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, str):
parts.append({"text": item})
elif isinstance(item, dict):
item_type = item.get("type")
if item_type == "text":
parts.append({"text": item.get("text", "")})
elif item_type == "image_url":
# OpenAI 图片 URL 格式
image_url = item.get("image_url", {})
url = image_url.get("url", "")
if url.startswith("data:"):
# base64 数据 URL
# 格式: data:image/png;base64,xxxxx
try:
header, data = url.split(",", 1)
mime_type = header.split(":")[1].split(";")[0]
parts.append(
{
"inline_data": {
"mime_type": mime_type,
"data": data,
}
}
)
except (ValueError, IndexError):
pass
return parts
return [{"text": str(content)}]
def _build_generation_config(self, openai_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建生成配置"""
config: Dict[str, Any] = {}
if "max_tokens" in openai_request:
config["max_output_tokens"] = openai_request["max_tokens"]
if "temperature" in openai_request:
config["temperature"] = openai_request["temperature"]
if "top_p" in openai_request:
config["top_p"] = openai_request["top_p"]
if "stop" in openai_request:
stop = openai_request["stop"]
if isinstance(stop, str):
config["stop_sequences"] = [stop]
elif isinstance(stop, list):
config["stop_sequences"] = stop
if "n" in openai_request:
config["candidate_count"] = openai_request["n"]
return config if config else None
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""转换工具定义"""
function_declarations = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
func_decl = {
"name": func.get("name", ""),
}
if "description" in func:
func_decl["description"] = func["description"]
if "parameters" in func:
func_decl["parameters"] = func["parameters"]
function_declarations.append(func_decl)
return [{"function_declarations": function_declarations}]
class GeminiToOpenAIConverter:
"""
Gemini -> OpenAI 响应转换器
将 Gemini generateContent 响应转换为 OpenAI Chat Completions API 格式
"""
def convert_response(self, gemini_response: Dict[str, Any]) -> Dict[str, Any]:
"""
将 Gemini 响应转换为 OpenAI 响应
Args:
gemini_response: Gemini 格式的响应字典
Returns:
OpenAI 格式的响应字典
"""
import time
candidates = gemini_response.get("candidates", [])
choices = []
for i, candidate in enumerate(candidates):
content = candidate.get("content", {})
parts = content.get("parts", [])
# 提取文本内容
text_parts = []
tool_calls = []
for part in parts:
if "text" in part:
text_parts.append(part["text"])
elif "functionCall" in part:
func_call = part["functionCall"]
import json
tool_calls.append(
{
"id": f"call_{func_call.get('name', '')}_{i}",
"type": "function",
"function": {
"name": func_call.get("name", ""),
"arguments": json.dumps(func_call.get("args", {})),
},
}
)
message: Dict[str, Any] = {
"role": "assistant",
"content": "".join(text_parts) if text_parts else None,
}
if tool_calls:
message["tool_calls"] = tool_calls
finish_reason = self._convert_finish_reason(candidate.get("finishReason"))
choices.append(
{
"index": i,
"message": message,
"finish_reason": finish_reason,
}
)
# 转换使用量
usage = self._convert_usage(gemini_response.get("usageMetadata", {}))
return {
"id": f"chatcmpl-{gemini_response.get('modelVersion', 'gemini')}",
"object": "chat.completion",
"created": int(time.time()),
"model": gemini_response.get("modelVersion", "gemini"),
"choices": choices,
"usage": usage,
}
def _convert_usage(self, usage_metadata: Dict[str, Any]) -> Dict[str, int]:
"""转换使用量信息"""
prompt_tokens = usage_metadata.get("promptTokenCount", 0)
completion_tokens = usage_metadata.get("candidatesTokenCount", 0)
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
"""转换停止原因"""
mapping = {
"STOP": "stop",
"MAX_TOKENS": "length",
"SAFETY": "content_filter",
"RECITATION": "content_filter",
"OTHER": "stop",
}
return mapping.get(finish_reason, "stop")
__all__ = [
"ClaudeToGeminiConverter",
"GeminiToClaudeConverter",
"OpenAIToGeminiConverter",
"GeminiToOpenAIConverter",
]

View File

@@ -0,0 +1,164 @@
"""
Gemini Chat Handler
处理 Gemini API 格式的请求
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
class GeminiChatHandler(ChatHandlerBase):
"""
Gemini Chat Handler - 处理 Google Gemini API 格式的请求
格式特点:
- 使用 promptTokenCount / candidatesTokenCount
- 支持 cachedContentTokenCount
- 请求格式: GeminiRequest
- 响应格式: JSON 数组流(非 SSE
"""
FORMAT_ID = "GEMINI"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None,
) -> str:
"""
从请求中提取模型名 - Gemini Chat 格式实现
Gemini Chat 模式下model 在请求体中(经过转换后的 GeminiRequest
与 Gemini CLI 不同CLI 模式的 model 在 URL 路径中。
Args:
request_body: 请求体
path_params: URL 路径参数Chat 模式通常不使用)
Returns:
模型名
"""
# 优先从请求体获取,其次从 path_params
model = request_body.get("model")
if model:
return str(model)
if path_params and "model" in path_params:
return str(path_params["model"])
return "unknown"
async def _convert_request(self, request):
"""
将请求转换为 Gemini 格式
支持自动转换:
- Claude 格式 → Gemini 格式
- OpenAI 格式 → Gemini 格式
Args:
request: 原始请求对象(可能是 Gemini/Claude/OpenAI 格式)
Returns:
GeminiRequest 对象
"""
from src.api.handlers.gemini.converter import (
ClaudeToGeminiConverter,
OpenAIToGeminiConverter,
)
from src.models.claude import ClaudeMessagesRequest
from src.models.gemini import GeminiRequest
from src.models.openai import OpenAIRequest
# 如果已经是 Gemini 格式,直接返回
if isinstance(request, GeminiRequest):
return request
# 如果是 Claude 格式,转换为 Gemini 格式
if isinstance(request, ClaudeMessagesRequest):
converter = ClaudeToGeminiConverter()
gemini_dict = converter.convert_request(request.model_dump())
return GeminiRequest(**gemini_dict)
# 如果是 OpenAI 格式,转换为 Gemini 格式
if isinstance(request, OpenAIRequest):
converter = OpenAIToGeminiConverter()
gemini_dict = converter.convert_request(request.model_dump())
return GeminiRequest(**gemini_dict)
# 如果是字典,根据内容判断格式并转换
if isinstance(request, dict):
# 检测 Gemini 格式特征: contents 字段
if "contents" in request:
return GeminiRequest(**request)
# 检测 Claude 格式特征: messages + 没有 choices
if "messages" in request and "choices" not in request:
# 进一步区分 Claude 和 OpenAI
# Claude 使用 max_tokensOpenAI 也可能有
# Claude 的 messages[].content 可以是数组OpenAI 通常是字符串
messages = request.get("messages", [])
if messages and isinstance(messages[0].get("content"), list):
# 可能是 Claude 格式
converter = ClaudeToGeminiConverter()
gemini_dict = converter.convert_request(request)
return GeminiRequest(**gemini_dict)
else:
# 可能是 OpenAI 格式
converter = OpenAIToGeminiConverter()
gemini_dict = converter.convert_request(request)
return GeminiRequest(**gemini_dict)
# 默认尝试作为 Gemini 格式
return GeminiRequest(**request)
return request
def _extract_usage(self, response: Dict) -> Dict[str, int]:
"""
从 Gemini 响应中提取 token 使用情况
调用 GeminiStreamParser.extract_usage 作为单一实现源
"""
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
usage = GeminiStreamParser().extract_usage(response)
if not usage:
return {
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
}
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_input_tokens": 0, # Gemini 不区分缓存创建
"cache_read_input_tokens": usage.get("cached_tokens", 0),
}
def _normalize_response(self, response: Dict) -> Dict:
"""
规范化 Gemini 响应
Args:
response: 原始响应
Returns:
规范化后的响应
TODO: 如果需要,实现响应规范化逻辑
"""
# 可选:使用 response_normalizer 进行规范化
# if (
# self.response_normalizer
# and self.response_normalizer.should_normalize(response)
# ):
# return self.response_normalizer.normalize_gemini_response(
# response_data=response,
# request_id=self.request_id,
# strict=False,
# )
return response

View File

@@ -0,0 +1,307 @@
"""
Gemini SSE/JSON 流解析器
Gemini API 的流式响应格式与 Claude/OpenAI 不同:
- 使用 JSON 数组格式 (不是 SSE)
- 每个块是一个完整的 JSON 对象
- 响应以 [ 开始,以 ] 结束,块之间用 , 分隔
参考: https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
"""
import json
from typing import Any, Dict, List, Optional
class GeminiStreamParser:
"""
Gemini 流解析器
解析 Gemini streamGenerateContent API 的响应流。
Gemini 流式响应特点:
- 返回 JSON 数组格式: [{chunk1}, {chunk2}, ...]
- 每个 chunk 包含 candidates、usageMetadata 等字段
- finish_reason 可能值: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
"""
# 停止原因
FINISH_REASON_STOP = "STOP"
FINISH_REASON_MAX_TOKENS = "MAX_TOKENS"
FINISH_REASON_SAFETY = "SAFETY"
FINISH_REASON_RECITATION = "RECITATION"
FINISH_REASON_OTHER = "OTHER"
def __init__(self):
self._buffer = ""
self._in_array = False
self._brace_depth = 0
def reset(self):
"""重置解析器状态"""
self._buffer = ""
self._in_array = False
self._brace_depth = 0
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
"""
解析流式数据块
Args:
chunk: 原始数据bytes 或 str
Returns:
解析后的事件列表
"""
if isinstance(chunk, bytes):
text = chunk.decode("utf-8")
else:
text = chunk
events: List[Dict[str, Any]] = []
for char in text:
if char == "[" and not self._in_array:
self._in_array = True
continue
if char == "]" and self._in_array and self._brace_depth == 0:
# 数组结束
self._in_array = False
if self._buffer.strip():
try:
obj = json.loads(self._buffer.strip().rstrip(","))
events.append(obj)
except json.JSONDecodeError:
pass
self._buffer = ""
continue
if self._in_array:
if char == "{":
self._brace_depth += 1
elif char == "}":
self._brace_depth -= 1
self._buffer += char
# 当 brace_depth 回到 0 时,说明一个完整的 JSON 对象结束
if self._brace_depth == 0 and self._buffer.strip():
try:
obj = json.loads(self._buffer.strip().rstrip(","))
events.append(obj)
self._buffer = ""
except json.JSONDecodeError:
# 可能还不完整,继续累积
pass
return events
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
"""
解析单行 JSON 数据
Args:
line: JSON 数据行
Returns:
解析后的事件字典,如果无法解析返回 None
"""
if not line or line.strip() in ["[", "]", ","]:
return None
try:
return json.loads(line.strip().rstrip(","))
except json.JSONDecodeError:
return None
def is_done_event(self, event: Dict[str, Any]) -> bool:
"""
判断是否为结束事件
Args:
event: 事件字典
Returns:
True 如果是结束事件
"""
candidates = event.get("candidates", [])
if candidates:
for candidate in candidates:
finish_reason = candidate.get("finishReason")
if finish_reason in (
self.FINISH_REASON_STOP,
self.FINISH_REASON_MAX_TOKENS,
self.FINISH_REASON_SAFETY,
self.FINISH_REASON_RECITATION,
self.FINISH_REASON_OTHER,
):
return True
return False
def is_error_event(self, event: Dict[str, Any]) -> bool:
"""
判断是否为错误事件
检测多种 Gemini 错误格式:
1. 顶层 error: {"error": {...}}
2. chunks 内嵌套 error: {"chunks": [{"error": {...}}]}
3. candidates 内的错误状态
Args:
event: 事件字典
Returns:
True 如果是错误事件
"""
# 顶层 error
if "error" in event:
return True
# chunks 内嵌套 error (某些 Gemini 响应格式)
chunks = event.get("chunks", [])
if chunks and isinstance(chunks, list):
for chunk in chunks:
if isinstance(chunk, dict) and "error" in chunk:
return True
return False
def extract_error_info(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
从事件中提取错误信息
Args:
event: 事件字典
Returns:
错误信息字典 {"code": int, "message": str, "status": str},无错误返回 None
"""
# 顶层 error
if "error" in event:
error = event["error"]
if isinstance(error, dict):
return {
"code": error.get("code"),
"message": error.get("message", str(error)),
"status": error.get("status"),
}
return {"code": None, "message": str(error), "status": None}
# chunks 内嵌套 error
chunks = event.get("chunks", [])
if chunks and isinstance(chunks, list):
for chunk in chunks:
if isinstance(chunk, dict) and "error" in chunk:
error = chunk["error"]
if isinstance(error, dict):
return {
"code": error.get("code"),
"message": error.get("message", str(error)),
"status": error.get("status"),
}
return {"code": None, "message": str(error), "status": None}
return None
def get_finish_reason(self, event: Dict[str, Any]) -> Optional[str]:
"""
获取结束原因
Args:
event: 事件字典
Returns:
结束原因字符串
"""
candidates = event.get("candidates", [])
if candidates:
return candidates[0].get("finishReason")
return None
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
"""
从响应中提取文本内容
Args:
event: 事件字典
Returns:
文本内容,如果没有文本返回 None
"""
candidates = event.get("candidates", [])
if not candidates:
return None
content = candidates[0].get("content", {})
parts = content.get("parts", [])
text_parts = []
for part in parts:
if "text" in part:
text_parts.append(part["text"])
return "".join(text_parts) if text_parts else None
def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
"""
从事件中提取 token 使用量
这是 Gemini token 提取的单一实现源,其他地方都应该调用此方法。
Args:
event: 事件字典(包含 usageMetadata
Returns:
使用量字典,如果没有完整的使用量信息返回 None
注意:
- 只有当 totalTokenCount 存在时才提取(确保是完整的 usage 数据)
- 输出 token = thoughtsTokenCount + candidatesTokenCount
"""
usage_metadata = event.get("usageMetadata", {})
if not usage_metadata or "totalTokenCount" not in usage_metadata:
return None
# 输出 token = thoughtsTokenCount + candidatesTokenCount
thoughts_tokens = usage_metadata.get("thoughtsTokenCount", 0)
candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
output_tokens = thoughts_tokens + candidates_tokens
return {
"input_tokens": usage_metadata.get("promptTokenCount", 0),
"output_tokens": output_tokens,
"total_tokens": usage_metadata.get("totalTokenCount", 0),
"cached_tokens": usage_metadata.get("cachedContentTokenCount", 0),
}
def extract_model_version(self, event: Dict[str, Any]) -> Optional[str]:
"""
从响应中提取模型版本
Args:
event: 事件字典
Returns:
模型版本,如果没有返回 None
"""
return event.get("modelVersion")
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
"""
从响应中提取安全评级
Args:
event: 事件字典
Returns:
安全评级列表,如果没有返回 None
"""
candidates = event.get("candidates", [])
if not candidates:
return None
return candidates[0].get("safetyRatings")
__all__ = ["GeminiStreamParser"]

View File

@@ -0,0 +1,12 @@
"""
Gemini CLI 透传处理器
"""
from src.api.handlers.gemini_cli.adapter import GeminiCliAdapter, build_gemini_cli_adapter
from src.api.handlers.gemini_cli.handler import GeminiCliMessageHandler
__all__ = [
"GeminiCliAdapter",
"GeminiCliMessageHandler",
"build_gemini_cli_adapter",
]

View File

@@ -0,0 +1,112 @@
"""
Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
继承 CliAdapterBase处理 Gemini CLI 格式的请求。
"""
from typing import Any, Dict, Optional, Type
from fastapi import Request
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
@register_cli_adapter
class GeminiCliAdapter(CliAdapterBase):
"""
Gemini CLI API 适配器
处理 Gemini CLI 格式的请求(透传模式,最小验证)。
"""
FORMAT_ID = "GEMINI_CLI"
name = "gemini.cli"
@property
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.gemini_cli.handler import GeminiCliMessageHandler
return GeminiCliMessageHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["GEMINI_CLI"])
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (x-goog-api-key)"""
return request.headers.get("x-goog-api-key")
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any] # noqa: ARG002
) -> Dict[str, Any]:
"""
合并 URL 路径参数到请求体 - Gemini CLI 特化版本
Gemini API 特点:
- model 不合并到请求体Gemini 原生请求体不含 model通过 URL 路径传递)
- stream 不合并到请求体Gemini API 通过 URL 端点区分流式/非流式)
基类已经从 path_params 获取 model 和 stream 用于日志和路由判断。
Args:
original_request_body: 原始请求体字典
path_params: URL 路径参数字典(包含 model、stream 等)
Returns:
原始请求体(不合并任何 path_params
"""
# Gemini: 不合并任何 path_params 到请求体
return original_request_body.copy()
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
"""Gemini CLI 使用 contents 字段"""
contents = payload.get("contents", [])
return len(contents) if isinstance(contents, list) else 0
def _build_audit_metadata(
self,
payload: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Gemini CLI 特定的审计元数据"""
# 从 path_params 获取 modelGemini 请求体不含 model
model = path_params.get("model", "unknown") if path_params else "unknown"
contents = payload.get("contents", [])
generation_config = payload.get("generation_config", {}) or {}
role_counts: Dict[str, int] = {}
for content in contents:
role = content.get("role", "unknown") if isinstance(content, dict) else "unknown"
role_counts[role] = role_counts.get(role, 0) + 1
return {
"action": "gemini_cli_request",
"model": model,
"stream": bool(payload.get("stream", False)),
"max_output_tokens": generation_config.get("max_output_tokens"),
"contents_count": len(contents),
"content_roles": role_counts,
"temperature": generation_config.get("temperature"),
"top_p": generation_config.get("top_p"),
"top_k": generation_config.get("top_k"),
"tools_count": len(payload.get("tools") or []),
"system_instruction_present": bool(payload.get("system_instruction")),
"safety_settings_count": len(payload.get("safety_settings") or []),
}
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
"""
构建 Gemini CLI 适配器
Args:
x_app_header: X-App 请求头值(预留扩展)
Returns:
GeminiCliAdapter 实例
"""
return GeminiCliAdapter()
__all__ = ["GeminiCliAdapter", "build_gemini_cli_adapter"]

View File

@@ -0,0 +1,210 @@
"""
Gemini CLI Message Handler - 基于通用 CLI Handler 基类的实现
继承 CliMessageHandlerBase处理 Gemini CLI API 格式的请求。
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase,
StreamContext,
)
class GeminiCliMessageHandler(CliMessageHandlerBase):
"""
Gemini CLI Message Handler - 处理 Gemini CLI API 格式
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
响应格式特点:
- Gemini 使用 JSON 数组格式流式响应(非 SSE
- 每个 chunk 包含 candidates、usageMetadata 等字段
- finish_reason: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
- Token 使用: promptTokenCount (输入), thoughtsTokenCount + candidatesTokenCount (输出), cachedContentTokenCount (缓存)
Gemini API 特殊处理:
- model 在 URL 路径中而非请求体,如 /v1beta/models/{model}:generateContent
- 请求体中的 model 字段用于内部路由,不发送给 API
"""
FORMAT_ID = "GEMINI_CLI"
def extract_model_from_request(
self,
request_body: Dict[str, Any], # noqa: ARG002 - 基类签名要求
path_params: Optional[Dict[str, Any]] = None,
) -> str:
"""
从请求中提取模型名 - Gemini 格式实现
Gemini API 的 model 在 URL 路径中而非请求体:
/v1beta/models/{model}:generateContent
Args:
request_body: 请求体Gemini 不包含 model
path_params: URL 路径参数(包含 model
Returns:
模型名,如果无法提取则返回 "unknown"
"""
# Gemini: model 从 URL 路径参数获取
if path_params and "model" in path_params:
return str(path_params["model"])
return "unknown"
def prepare_provider_request_body(
self,
request_body: Dict[str, Any],
) -> Dict[str, Any]:
"""
准备发送给 Gemini API 的请求体 - 移除 model 字段
Gemini API 要求 model 只在 URL 路径中,请求体中的 model 字段
会导致某些代理返回 404 错误。
Args:
request_body: 请求体
Returns:
不含 model 字段的请求体
"""
result = dict(request_body)
result.pop("model", None)
return result
def get_model_for_url(
self,
request_body: Dict[str, Any],
mapped_model: Optional[str],
) -> Optional[str]:
"""
Gemini 需要将 model 放入 URL 路径中
Args:
request_body: 请求体
mapped_model: 映射后的模型名(如果有)
Returns:
用于 URL 路径的模型名
"""
# 优先使用映射后的模型名,否则使用请求体中的
return mapped_model or request_body.get("model")
def _extract_usage_from_event(self, event: Dict[str, Any]) -> Dict[str, int]:
"""
从 Gemini 事件中提取 token 使用情况
调用 GeminiStreamParser.extract_usage 作为单一实现源
Args:
event: Gemini 流式响应事件
Returns:
包含 input_tokens, output_tokens, cached_tokens 的字典
"""
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
usage = GeminiStreamParser().extract_usage(event)
if not usage:
return {
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
}
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cached_tokens": usage.get("cached_tokens", 0),
}
def _process_event_data(
self,
ctx: StreamContext,
_event_type: str,
data: Dict[str, Any],
) -> None:
"""
处理 Gemini CLI 格式的流式事件
Gemini 的流式响应是 JSON 数组格式,每个元素结构如下:
{
"candidates": [{
"content": {"parts": [{"text": "..."}], "role": "model"},
"finishReason": "STOP",
"safetyRatings": [...]
}],
"usageMetadata": {
"promptTokenCount": 10,
"candidatesTokenCount": 20,
"totalTokenCount": 30,
"cachedContentTokenCount": 5
},
"modelVersion": "gemini-1.5-pro"
}
注意: Gemini 流解析器会将每个 JSON 对象作为一个"事件"传递
event_type 在这里可能为空或是自定义的标记
"""
# 提取候选响应
candidates = data.get("candidates", [])
if candidates:
candidate = candidates[0]
content = candidate.get("content", {})
# 提取文本内容
parts = content.get("parts", [])
for part in parts:
if "text" in part:
ctx.collected_text += part["text"]
# 检查结束原因
finish_reason = candidate.get("finishReason")
if finish_reason in ("STOP", "MAX_TOKENS", "SAFETY", "RECITATION", "OTHER"):
ctx.has_completion = True
ctx.final_response = data
# 提取使用量信息(复用 GeminiStreamParser.extract_usage
usage = self._extract_usage_from_event(data)
if usage["input_tokens"] > 0 or usage["output_tokens"] > 0:
ctx.input_tokens = usage["input_tokens"]
ctx.output_tokens = usage["output_tokens"]
ctx.cached_tokens = usage["cached_tokens"]
# 提取模型版本作为响应 ID
model_version = data.get("modelVersion")
if model_version:
if not ctx.response_id:
ctx.response_id = f"gemini-{model_version}"
# 存储到 response_metadata 供 Usage 记录使用
ctx.response_metadata["model_version"] = model_version
# 检查错误
if "error" in data:
ctx.has_completion = True
ctx.final_response = data
def _extract_response_metadata(
self,
response: Dict[str, Any],
) -> Dict[str, Any]:
"""
从 Gemini 响应中提取元数据
提取 modelVersion 字段,记录实际使用的模型版本。
Args:
response: Gemini API 响应
Returns:
包含 model_version 的元数据字典
"""
metadata: Dict[str, Any] = {}
model_version = response.get("modelVersion")
if model_version:
metadata["model_version"] = model_version
return metadata

View File

@@ -0,0 +1,11 @@
"""
OpenAI Chat API 处理器
"""
from src.api.handlers.openai.adapter import OpenAIChatAdapter
from src.api.handlers.openai.handler import OpenAIChatHandler
__all__ = [
"OpenAIChatAdapter",
"OpenAIChatHandler",
]

View File

@@ -0,0 +1,109 @@
"""
OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
"""
from typing import Any, Dict, Optional, Type
from fastapi import Request
from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.models.openai import OpenAIRequest
@register_adapter
class OpenAIChatAdapter(ChatAdapterBase):
"""
OpenAI Chat Completions API 适配器
处理 OpenAI Chat 格式的请求(/v1/chat/completions 端点)。
"""
FORMAT_ID = "OPENAI"
name = "openai.chat"
@property
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.openai.handler import OpenAIChatHandler
return OpenAIChatHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["OPENAI"])
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
return authorization.replace("Bearer ", "")
return None
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
"""验证请求体"""
if not isinstance(original_request_body, dict):
return self._error_response(
400, "Request body must be a JSON object", "invalid_request_error"
)
required_fields = ["model", "messages"]
missing = [f for f in required_fields if f not in original_request_body]
if missing:
return self._error_response(
400,
f"Missing required fields: {', '.join(missing)}",
"invalid_request_error",
)
try:
return OpenAIRequest.model_validate(original_request_body, strict=False)
except ValueError as e:
return self._error_response(400, str(e), "invalid_request_error")
except Exception as e:
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
return OpenAIRequest.model_construct(
model=original_request_body.get("model"),
messages=original_request_body.get("messages", []),
stream=original_request_body.get("stream", False),
max_tokens=original_request_body.get("max_tokens"),
)
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
"""构建 OpenAI Chat 特定的审计元数据"""
role_counts = {}
for message in request_obj.messages:
role_counts[message.role] = role_counts.get(message.role, 0) + 1
return {
"action": "openai_chat_completion",
"model": request_obj.model,
"stream": bool(request_obj.stream),
"max_tokens": request_obj.max_tokens,
"temperature": request_obj.temperature,
"top_p": request_obj.top_p,
"messages_count": len(request_obj.messages),
"message_roles": role_counts,
"tools_count": len(request_obj.tools or []),
"response_format": bool(request_obj.response_format),
"user_identifier": request_obj.user,
}
def _error_response(self, status_code: int, message: str, error_type: str) -> JSONResponse:
"""生成 OpenAI 格式的错误响应"""
return JSONResponse(
status_code=status_code,
content={
"error": {
"type": error_type,
"message": message,
"code": status_code,
}
},
)
__all__ = ["OpenAIChatAdapter"]

View File

@@ -0,0 +1,424 @@
"""
Claude -> OpenAI 格式转换器
将 Claude Messages API 格式转换为 OpenAI Chat Completions API 格式。
"""
from __future__ import annotations
import json
import time
import uuid
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
class ClaudeToOpenAIConverter:
"""
Claude -> OpenAI 格式转换器
支持:
- 请求转换Claude Request -> OpenAI Chat Request
- 响应转换Claude Response -> OpenAI Chat Response
- 流式转换Claude SSE -> OpenAI SSE
"""
# 内容类型常量
CONTENT_TYPE_TEXT = "text"
CONTENT_TYPE_IMAGE = "image"
CONTENT_TYPE_TOOL_USE = "tool_use"
CONTENT_TYPE_TOOL_RESULT = "tool_result"
# 停止原因映射
STOP_REASON_MAP = {
"end_turn": "stop",
"max_tokens": "length",
"stop_sequence": "stop",
"tool_use": "tool_calls",
}
def __init__(self, model_mapping: Optional[Dict[str, str]] = None):
"""
Args:
model_mapping: Claude 模型到 OpenAI 模型的映射
"""
self._model_mapping = model_mapping or {}
# ==================== 请求转换 ====================
def convert_request(self, request: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
"""
将 Claude 请求转换为 OpenAI 格式
Args:
request: Claude 请求Dict 或 Pydantic 模型)
Returns:
OpenAI 格式的请求字典
"""
if hasattr(request, "model_dump"):
data = request.model_dump(exclude_none=True)
else:
data = dict(request)
# 模型映射
model = data.get("model", "")
openai_model = self._model_mapping.get(model, model)
# 构建消息列表
messages: List[Dict[str, Any]] = []
# 处理 system 消息
system_content = self._extract_text_content(data.get("system"))
if system_content:
messages.append({"role": "system", "content": system_content})
# 处理对话消息
for message in data.get("messages", []):
converted = self._convert_message(message)
if converted:
messages.append(converted)
# 构建 OpenAI 请求
result: Dict[str, Any] = {
"model": openai_model,
"messages": messages,
}
# 可选参数
if data.get("max_tokens"):
result["max_tokens"] = data["max_tokens"]
if data.get("temperature") is not None:
result["temperature"] = data["temperature"]
if data.get("top_p") is not None:
result["top_p"] = data["top_p"]
if data.get("stream"):
result["stream"] = data["stream"]
if data.get("stop_sequences"):
result["stop"] = data["stop_sequences"]
# 工具转换
tools = self._convert_tools(data.get("tools"))
if tools:
result["tools"] = tools
tool_choice = self._convert_tool_choice(data.get("tool_choice"))
if tool_choice:
result["tool_choice"] = tool_choice
return result
def _convert_message(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""转换单条消息"""
role = message.get("role")
if role == "user":
return self._convert_user_message(message)
if role == "assistant":
return self._convert_assistant_message(message)
return None
def _convert_user_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换用户消息"""
content = message.get("content")
if isinstance(content, str):
return {"role": "user", "content": content}
openai_content: List[Dict[str, Any]] = []
for block in content or []:
block_type = block.get("type")
if block_type == self.CONTENT_TYPE_TEXT:
openai_content.append({"type": "text", "text": block.get("text", "")})
elif block_type == self.CONTENT_TYPE_IMAGE:
source = block.get("source", {})
media_type = source.get("media_type", "image/jpeg")
data = source.get("data", "")
openai_content.append(
{"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{data}"}}
)
elif block_type == self.CONTENT_TYPE_TOOL_RESULT:
tool_content = block.get("content", "")
rendered = self._render_tool_content(tool_content)
openai_content.append({"type": "text", "text": f"Tool result: {rendered}"})
# 简化单文本内容
if len(openai_content) == 1 and openai_content[0]["type"] == "text":
return {"role": "user", "content": openai_content[0]["text"]}
return {"role": "user", "content": openai_content or ""}
def _convert_assistant_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换助手消息"""
content = message.get("content")
text_parts: List[str] = []
tool_calls: List[Dict[str, Any]] = []
if isinstance(content, str):
text_parts.append(content)
else:
for idx, block in enumerate(content or []):
block_type = block.get("type")
if block_type == self.CONTENT_TYPE_TEXT:
text_parts.append(block.get("text", ""))
elif block_type == self.CONTENT_TYPE_TOOL_USE:
tool_calls.append(
{
"id": block.get("id", f"call_{idx}"),
"type": "function",
"function": {
"name": block.get("name", ""),
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False),
},
}
)
result: Dict[str, Any] = {"role": "assistant"}
message_content = "\n".join([p for p in text_parts if p]) or None
if message_content:
result["content"] = message_content
if tool_calls:
result["tool_calls"] = tool_calls
return result
def _convert_tools(
self, tools: Optional[List[Dict[str, Any]]]
) -> Optional[List[Dict[str, Any]]]:
"""转换工具定义"""
if not tools:
return None
result: List[Dict[str, Any]] = []
for tool in tools:
result.append(
{
"type": "function",
"function": {
"name": tool.get("name", ""),
"description": tool.get("description"),
"parameters": tool.get("input_schema", {}),
},
}
)
return result
def _convert_tool_choice(
self, tool_choice: Optional[Dict[str, Any]]
) -> Optional[Union[str, Dict[str, Any]]]:
"""转换工具选择"""
if tool_choice is None:
return None
choice_type = tool_choice.get("type")
if choice_type in ("tool", "tool_use"):
return {"type": "function", "function": {"name": tool_choice.get("name", "")}}
if choice_type == "any":
return "required"
if choice_type == "auto":
return "auto"
return tool_choice
# ==================== 响应转换 ====================
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""
将 Claude 响应转换为 OpenAI 格式
Args:
response: Claude 响应字典
Returns:
OpenAI 格式的响应字典
"""
# 提取内容
content_parts: List[str] = []
tool_calls: List[Dict[str, Any]] = []
for idx, block in enumerate(response.get("content", [])):
block_type = block.get("type")
if block_type == self.CONTENT_TYPE_TEXT:
content_parts.append(block.get("text", ""))
elif block_type == self.CONTENT_TYPE_TOOL_USE:
tool_calls.append(
{
"id": block.get("id", f"call_{idx}"),
"type": "function",
"function": {
"name": block.get("name", ""),
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False),
},
}
)
# 构建消息
message: Dict[str, Any] = {"role": "assistant"}
text_content = "\n".join([p for p in content_parts if p]) or None
if text_content:
message["content"] = text_content
if tool_calls:
message["tool_calls"] = tool_calls
# 转换停止原因
stop_reason = response.get("stop_reason")
finish_reason = self.STOP_REASON_MAP.get(stop_reason, stop_reason)
# 转换 usage
usage = response.get("usage", {})
openai_usage = {
"prompt_tokens": usage.get("input_tokens", 0),
"completion_tokens": usage.get("output_tokens", 0),
"total_tokens": (usage.get("input_tokens", 0) + usage.get("output_tokens", 0)),
}
return {
"id": f"chatcmpl-{response.get('id', uuid.uuid4().hex[:8])}",
"object": "chat.completion",
"created": int(time.time()),
"model": response.get("model", ""),
"choices": [
{
"index": 0,
"message": message,
"finish_reason": finish_reason,
}
],
"usage": openai_usage,
}
# ==================== 流式转换 ====================
def convert_stream_event(
self,
event: Dict[str, Any],
model: str = "",
message_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""
将 Claude SSE 事件转换为 OpenAI 格式
Args:
event: Claude SSE 事件
model: 模型名称
message_id: 消息 ID
Returns:
OpenAI 格式的 SSE chunk如果无法转换返回 None
"""
event_type = event.get("type")
chunk_id = f"chatcmpl-{(message_id or 'stream')[-8:]}"
if event_type == "message_start":
message = event.get("message", {})
return self._base_chunk(
chunk_id,
model or message.get("model", ""),
{"role": "assistant"},
)
if event_type == "content_block_start":
content_block = event.get("content_block", {})
if content_block.get("type") == self.CONTENT_TYPE_TOOL_USE:
delta = {
"tool_calls": [
{
"index": event.get("index", 0),
"id": content_block.get("id", ""),
"type": "function",
"function": {
"name": content_block.get("name", ""),
"arguments": "",
},
}
]
}
return self._base_chunk(chunk_id, model, delta)
return None
if event_type == "content_block_delta":
delta_payload = event.get("delta", {})
delta_type = delta_payload.get("type")
if delta_type == "text_delta":
delta = {"content": delta_payload.get("text", "")}
return self._base_chunk(chunk_id, model, delta)
if delta_type == "input_json_delta":
delta = {
"tool_calls": [
{
"index": event.get("index", 0),
"function": {"arguments": delta_payload.get("partial_json", "")},
}
]
}
return self._base_chunk(chunk_id, model, delta)
return None
if event_type == "message_delta":
delta = event.get("delta", {})
stop_reason = delta.get("stop_reason")
finish_reason = self.STOP_REASON_MAP.get(stop_reason, stop_reason)
return self._base_chunk(chunk_id, model, {}, finish_reason=finish_reason)
if event_type == "message_stop":
return self._base_chunk(chunk_id, model, {}, finish_reason="stop")
return None
def _base_chunk(
self,
chunk_id: str,
model: str,
delta: Dict[str, Any],
finish_reason: Optional[str] = None,
) -> Dict[str, Any]:
"""构建基础 OpenAI chunk"""
return {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"system_fingerprint": None,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
# ==================== 工具方法 ====================
def _extract_text_content(
self, content: Optional[Union[str, List[Dict[str, Any]]]]
) -> Optional[str]:
"""提取文本内容"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [
block.get("text", "")
for block in content
if block.get("type") == self.CONTENT_TYPE_TEXT
]
return "\n\n".join(filter(None, parts)) or None
return None
def _render_tool_content(self, tool_content: Any) -> str:
"""渲染工具内容"""
if isinstance(tool_content, list):
return json.dumps(tool_content, ensure_ascii=False)
return str(tool_content)
__all__ = ["ClaudeToOpenAIConverter"]

View File

@@ -0,0 +1,137 @@
"""
OpenAI Chat Handler - 基于通用 Chat Handler 基类的简化实现
继承 ChatHandlerBase只需覆盖格式特定的方法。
代码量从原来的 ~1315 行减少到 ~100 行。
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
class OpenAIChatHandler(ChatHandlerBase):
"""
OpenAI Chat Handler - 处理 OpenAI Chat Completions API 格式的请求
格式特点:
- 使用 prompt_tokens/completion_tokens
- 不支持 cache tokens
- 请求格式OpenAIRequest
"""
FORMAT_ID = "OPENAI"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> str:
"""
从请求中提取模型名 - OpenAI 格式实现
OpenAI API 的 model 在请求体顶级字段。
Args:
request_body: 请求体
path_params: URL 路径参数OpenAI 不使用)
Returns:
模型名
"""
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,
) -> Dict[str, Any]:
"""
将映射后的模型名应用到请求体
OpenAI API 的 model 在请求体顶级字段。
Args:
request_body: 原始请求体
mapped_model: 映射后的模型名
Returns:
更新了 model 字段的请求体
"""
result = dict(request_body)
result["model"] = mapped_model
return result
async def _convert_request(self, request):
"""
将请求转换为 OpenAI 格式
Args:
request: 原始请求对象
Returns:
OpenAIRequest 对象
"""
from src.api.handlers.openai.converter import ClaudeToOpenAIConverter
from src.models.claude import ClaudeMessagesRequest
from src.models.openai import OpenAIRequest
# 如果已经是 OpenAI 格式,直接返回
if isinstance(request, OpenAIRequest):
return request
# 如果是 Claude 格式,转换为 OpenAI 格式
if isinstance(request, ClaudeMessagesRequest):
converter = ClaudeToOpenAIConverter()
openai_dict = converter.convert_request(request.dict())
return OpenAIRequest(**openai_dict)
# 如果是字典,尝试判断格式
if isinstance(request, dict):
try:
return OpenAIRequest(**request)
except Exception:
try:
converter = ClaudeToOpenAIConverter()
openai_dict = converter.convert_request(request)
return OpenAIRequest(**openai_dict)
except Exception:
return OpenAIRequest(**request)
return request
def _extract_usage(self, response: Dict) -> Dict[str, int]:
"""
从 OpenAI 响应中提取 token 使用情况
OpenAI 格式使用:
- prompt_tokens / completion_tokens
- 不支持 cache tokens
"""
usage = response.get("usage", {})
return {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
}
def _normalize_response(self, response: Dict) -> Dict:
"""
规范化 OpenAI 响应
Args:
response: 原始响应
Returns:
规范化后的响应
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_openai_response(
response_data=response,
request_id=self.request_id,
strict=False,
)
return response

View File

@@ -0,0 +1,181 @@
"""
OpenAI SSE 流解析器
解析 OpenAI Chat Completions API 的 Server-Sent Events 流。
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
class OpenAIStreamParser:
"""
OpenAI SSE 流解析器
解析 OpenAI Chat Completions API 的 SSE 事件流。
OpenAI 流格式:
- 每个 chunk 是一个 JSON 对象,包含 choices 数组
- choices[0].delta 包含增量内容
- choices[0].finish_reason 表示结束原因
- 流结束时发送 data: [DONE]
"""
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
"""
解析 SSE 数据块
Args:
chunk: 原始 SSE 数据bytes 或 str
Returns:
解析后的 chunk 列表
"""
if isinstance(chunk, bytes):
text = chunk.decode("utf-8")
else:
text = chunk
chunks: List[Dict[str, Any]] = []
lines = text.strip().split("\n")
for line in lines:
line = line.strip()
if not line:
continue
# 解析数据行
if line.startswith("data: "):
data_str = line[6:]
# 处理 [DONE] 标记
if data_str == "[DONE]":
chunks.append({"__done__": True})
continue
try:
data = json.loads(data_str)
chunks.append(data)
except json.JSONDecodeError:
# 无法解析的数据,跳过
pass
return chunks
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
"""
解析单行 SSE 数据
Args:
line: SSE 数据行(已去除 "data: " 前缀)
Returns:
解析后的 chunk 字典,如果无法解析返回 None
"""
if not line or line == "[DONE]":
return None
try:
return json.loads(line)
except json.JSONDecodeError:
return None
def is_done_chunk(self, chunk: Dict[str, Any]) -> bool:
"""
判断是否为结束 chunk
Args:
chunk: chunk 字典
Returns:
True 如果是结束 chunk
"""
# 内部标记
if chunk.get("__done__"):
return True
# 检查 finish_reason
choices = chunk.get("choices", [])
if choices:
finish_reason = choices[0].get("finish_reason")
return finish_reason is not None
return False
def get_finish_reason(self, chunk: Dict[str, Any]) -> Optional[str]:
"""
获取结束原因
Args:
chunk: chunk 字典
Returns:
结束原因字符串
"""
choices = chunk.get("choices", [])
if choices:
return choices[0].get("finish_reason")
return None
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
"""
从 chunk 中提取文本增量
Args:
chunk: chunk 字典
Returns:
文本增量,如果没有返回 None
"""
choices = chunk.get("choices", [])
if not choices:
return None
delta = choices[0].get("delta", {})
content = delta.get("content")
if isinstance(content, str):
return content
return None
def extract_tool_calls_delta(self, chunk: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
"""
从 chunk 中提取工具调用增量
Args:
chunk: chunk 字典
Returns:
工具调用列表,如果没有返回 None
"""
choices = chunk.get("choices", [])
if not choices:
return None
delta = choices[0].get("delta", {})
return delta.get("tool_calls")
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
"""
从 chunk 中提取角色
通常只在第一个 chunk 中出现。
Args:
chunk: chunk 字典
Returns:
角色字符串
"""
choices = chunk.get("choices", [])
if not choices:
return None
delta = choices[0].get("delta", {})
return delta.get("role")
__all__ = ["OpenAIStreamParser"]

View File

@@ -0,0 +1,11 @@
"""
OpenAI CLI 透传处理器
"""
from src.api.handlers.openai_cli.adapter import OpenAICliAdapter
from src.api.handlers.openai_cli.handler import OpenAICliMessageHandler
__all__ = [
"OpenAICliAdapter",
"OpenAICliMessageHandler",
]

View File

@@ -0,0 +1,44 @@
"""
OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。
"""
from typing import Optional, Type
from fastapi import Request
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
@register_cli_adapter
class OpenAICliAdapter(CliAdapterBase):
"""
OpenAI CLI API 适配器
处理 /v1/responses 端点的请求。
"""
FORMAT_ID = "OPENAI_CLI"
name = "openai.cli"
@property
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.openai_cli.handler import OpenAICliMessageHandler
return OpenAICliMessageHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["OPENAI_CLI"])
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
return authorization.replace("Bearer ", "")
return None
__all__ = ["OpenAICliAdapter"]

View File

@@ -0,0 +1,211 @@
"""
OpenAI CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
继承 CliMessageHandlerBase只需覆盖格式特定的配置和事件处理逻辑。
代码量从原来的 900+ 行减少到 ~100 行。
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase,
StreamContext,
)
class OpenAICliMessageHandler(CliMessageHandlerBase):
"""
OpenAI CLI Message Handler - 处理 OpenAI CLI Responses API 格式
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
响应格式特点:
- 使用 output[] 数组而非 content[]
- 使用 output_text 类型而非普通 text
- 流式事件response.output_text.delta, response.output_text.done
模型字段:请求体顶级 model 字段
"""
FORMAT_ID = "OPENAI_CLI"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> str:
"""
从请求中提取模型名 - OpenAI 格式实现
OpenAI API 的 model 在请求体顶级字段。
Args:
request_body: 请求体
path_params: URL 路径参数OpenAI 不使用)
Returns:
模型名
"""
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,
) -> Dict[str, Any]:
"""
OpenAI CLI (Responses API) 的 model 在请求体顶级字段。
Args:
request_body: 原始请求体
mapped_model: 映射后的模型名
Returns:
更新了 model 字段的请求体
"""
result = dict(request_body)
result["model"] = mapped_model
return result
def _process_event_data(
self,
ctx: StreamContext,
event_type: str,
data: Dict[str, Any],
) -> None:
"""
处理 OpenAI CLI 格式的 SSE 事件
事件类型:
- response.output_text.delta: 文本增量
- response.completed: 响应完成(包含 usage
"""
# 提取 response_id
if not ctx.response_id:
response_obj = data.get("response")
if isinstance(response_obj, dict) and response_obj.get("id"):
ctx.response_id = response_obj["id"]
elif "id" in data:
ctx.response_id = data["id"]
# 处理文本增量
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
delta = data.get("delta")
if isinstance(delta, str):
ctx.collected_text += delta
elif isinstance(delta, dict) and "text" in delta:
ctx.collected_text += delta["text"]
# 处理完成事件
elif event_type == "response.completed":
ctx.has_completion = True
response_obj = data.get("response")
if isinstance(response_obj, dict):
ctx.final_response = response_obj
usage_obj = response_obj.get("usage")
if isinstance(usage_obj, dict):
ctx.final_usage = usage_obj
ctx.input_tokens = usage_obj.get("input_tokens", 0)
ctx.output_tokens = usage_obj.get("output_tokens", 0)
details = usage_obj.get("input_tokens_details")
if isinstance(details, dict):
ctx.cached_tokens = details.get("cached_tokens", 0)
# 如果没有收集到文本,从 output 中提取
if not ctx.collected_text and "output" in response_obj:
for output_item in response_obj.get("output", []):
if output_item.get("type") != "message":
continue
for content_item in output_item.get("content", []):
if content_item.get("type") == "output_text":
text = content_item.get("text", "")
if text:
ctx.collected_text += text
# 备用:从顶层 usage 提取
usage_obj = data.get("usage")
if isinstance(usage_obj, dict) and not ctx.final_usage:
ctx.final_usage = usage_obj
ctx.input_tokens = usage_obj.get("input_tokens", 0)
ctx.output_tokens = usage_obj.get("output_tokens", 0)
details = usage_obj.get("input_tokens_details")
if isinstance(details, dict):
ctx.cached_tokens = details.get("cached_tokens", 0)
# 备用:从 response 字段提取
response_obj = data.get("response")
if isinstance(response_obj, dict) and not ctx.final_response:
ctx.final_response = response_obj
def _extract_response_metadata(
self,
response: Dict[str, Any],
) -> Dict[str, Any]:
"""
从 OpenAI 响应中提取元数据
提取 model、status、response_id 等字段作为元数据。
Args:
response: OpenAI API 响应
Returns:
提取的元数据字典
"""
metadata: Dict[str, Any] = {}
# 提取模型名称(实际使用的模型)
if "model" in response:
metadata["model"] = response["model"]
# 提取响应 ID
if "id" in response:
metadata["response_id"] = response["id"]
# 提取状态
if "status" in response:
metadata["status"] = response["status"]
# 提取对象类型
if "object" in response:
metadata["object"] = response["object"]
# 提取系统指纹(如果存在)
if "system_fingerprint" in response:
metadata["system_fingerprint"] = response["system_fingerprint"]
return metadata
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
"""
从流上下文中提取最终元数据
在流传输完成后调用,从收集的事件中提取元数据。
Args:
ctx: 流上下文
"""
# 从 response_id 提取响应 ID
if ctx.response_id:
ctx.response_metadata["response_id"] = ctx.response_id
# 从 final_response 提取更多元数据
if ctx.final_response and isinstance(ctx.final_response, dict):
if "model" in ctx.final_response:
ctx.response_metadata["model"] = ctx.final_response["model"]
if "status" in ctx.final_response:
ctx.response_metadata["status"] = ctx.final_response["status"]
if "object" in ctx.final_response:
ctx.response_metadata["object"] = ctx.final_response["object"]
if "system_fingerprint" in ctx.final_response:
ctx.response_metadata["system_fingerprint"] = ctx.final_response["system_fingerprint"]
# 如果没有从响应中获取到 model使用上下文中的
if "model" not in ctx.response_metadata and ctx.model:
ctx.response_metadata["model"] = ctx.model

View File

@@ -0,0 +1,10 @@
"""User monitoring routers."""
from fastapi import APIRouter
from .user import router as monitoring_router
router = APIRouter()
router.include_router(monitoring_router)
__all__ = ["router"]

148
src/api/monitoring/user.py Normal file
View File

@@ -0,0 +1,148 @@
"""普通用户可访问的监控与审计端点。"""
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_query
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, AuditLog
from src.plugins.manager import get_plugin_manager
router = APIRouter(prefix="/api/monitoring", tags=["Monitoring"])
pipeline = ApiRequestPipeline()
@router.get("/my-audit-logs")
async def get_my_audit_logs(
request: Request,
event_type: Optional[str] = Query(None, description="事件类型筛选"),
days: int = Query(30, description="查询天数"),
limit: int = Query(50, description="返回数量限制"),
offset: int = Query(0, ge=0, description="偏移量"),
db: Session = Depends(get_db),
):
adapter = UserAuditLogsAdapter(event_type=event_type, days=days, limit=limit, offset=offset)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/rate-limit-status")
async def get_rate_limit_status(request: Request, db: Session = Depends(get_db)):
adapter = UserRateLimitStatusAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AuthenticatedApiAdapter(ApiAdapter):
"""需要用户登录的适配器基类。"""
mode = ApiMode.USER
def authorize(self, context): # type: ignore[override]
if not context.user:
raise HTTPException(status_code=401, detail="未登录")
@dataclass
class UserAuditLogsAdapter(AuthenticatedApiAdapter):
event_type: Optional[str]
days: int
limit: int
offset: int
async def handle(self, context): # type: ignore[override]
db = context.db
user = context.user
if not user:
raise HTTPException(status_code=401, detail="未登录")
query = db.query(AuditLog).filter(AuditLog.user_id == user.id)
if self.event_type:
query = query.filter(AuditLog.event_type == self.event_type)
cutoff_time = datetime.now(timezone.utc) - timedelta(days=self.days)
query = query.filter(AuditLog.created_at >= cutoff_time)
query = query.order_by(AuditLog.created_at.desc())
total, logs = paginate_query(query, self.limit, self.offset)
items = [
{
"id": log.id,
"event_type": log.event_type,
"description": log.description,
"ip_address": log.ip_address,
"status_code": log.status_code,
"created_at": log.created_at.isoformat() if log.created_at else None,
}
for log in logs
]
meta = PaginationMeta(
total=total,
limit=self.limit,
offset=self.offset,
count=len(items),
)
return build_pagination_payload(
items,
meta,
filters={
"event_type": self.event_type,
"days": self.days,
},
)
class UserRateLimitStatusAdapter(AuthenticatedApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
user = context.user
if not user:
raise HTTPException(status_code=401, detail="未登录")
rate_limiter = _get_rate_limit_plugin()
if not rate_limiter or not hasattr(rate_limiter, "get_rate_limit_headers"):
raise HTTPException(status_code=503, detail="速率限制插件未启用或不支持状态查询")
api_keys = (
db.query(ApiKey)
.filter(ApiKey.user_id == user.id, ApiKey.is_active.is_(True))
.order_by(ApiKey.created_at.desc())
.all()
)
rate_limit_info = []
for key in api_keys:
try:
headers = rate_limiter.get_rate_limit_headers(key)
except Exception as exc:
logger.warning(f"无法获取Key {key.id} 的限流信息: {exc}")
headers = {}
rate_limit_info.append(
{
"api_key_name": key.name or f"Key-{key.id}",
"limit": headers.get("X-RateLimit-Limit"),
"remaining": headers.get("X-RateLimit-Remaining"),
"reset_time": headers.get("X-RateLimit-Reset"),
"window": headers.get("X-RateLimit-Window"),
}
)
return {"user_id": user.id, "api_keys": rate_limit_info}
def _get_rate_limit_plugin():
try:
plugin_manager = get_plugin_manager()
return plugin_manager.get_plugin("rate_limit")
except Exception as exc:
logger.warning(f"获取速率限制插件失败: {exc}")
return None

View File

@@ -0,0 +1,20 @@
"""Public-facing API routers."""
from fastapi import APIRouter
from .capabilities import router as capabilities_router
from .catalog import router as catalog_router
from .claude import router as claude_router
from .gemini import router as gemini_router
from .openai import router as openai_router
from .system_catalog import router as system_catalog_router
router = APIRouter()
router.include_router(claude_router, tags=["Claude API"])
router.include_router(openai_router)
router.include_router(gemini_router, tags=["Gemini API"])
router.include_router(system_catalog_router, tags=["System Catalog"])
router.include_router(catalog_router)
router.include_router(capabilities_router)
__all__ = ["router"]

View File

@@ -0,0 +1,104 @@
"""
能力配置公共 API
提供系统支持的能力列表,供前端展示和配置使用。
"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from src.core.key_capabilities import (
get_all_capabilities,
get_user_configurable_capabilities,
)
from src.database import get_db
router = APIRouter(prefix="/api/capabilities", tags=["Capabilities"])
@router.get("")
async def list_capabilities():
"""获取所有能力定义"""
return {
"capabilities": [
{
"name": cap.name,
"display_name": cap.display_name,
"short_name": cap.short_name,
"description": cap.description,
"match_mode": cap.match_mode.value,
"config_mode": cap.config_mode.value,
}
for cap in get_all_capabilities()
]
}
@router.get("/user-configurable")
async def list_user_configurable_capabilities():
"""获取用户可配置的能力列表(用于前端展示配置选项)"""
return {
"capabilities": [
{
"name": cap.name,
"display_name": cap.display_name,
"short_name": cap.short_name,
"description": cap.description,
"match_mode": cap.match_mode.value,
"config_mode": cap.config_mode.value,
}
for cap in get_user_configurable_capabilities()
]
}
@router.get("/model/{model_name}")
async def get_model_supported_capabilities(
model_name: str,
db: Session = Depends(get_db),
):
"""
获取指定模型支持的能力列表
Args:
model_name: 模型名称(如 claude-sonnet-4-20250514
Returns:
模型支持的能力列表,以及每个能力的详细定义
"""
from src.services.model.mapping_resolver import get_model_mapping_resolver
mapping_resolver = get_model_mapping_resolver()
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
if not global_model:
return {
"model": model_name,
"supported_capabilities": [],
"capability_details": [],
"error": "模型不存在",
}
supported_caps = global_model.supported_capabilities or []
# 获取支持的能力详情
all_caps = {cap.name: cap for cap in get_all_capabilities()}
capability_details = []
for cap_name in supported_caps:
if cap_name in all_caps:
cap = all_caps[cap_name]
capability_details.append({
"name": cap.name,
"display_name": cap.display_name,
"description": cap.description,
"match_mode": cap.match_mode.value,
"config_mode": cap.config_mode.value,
})
return {
"model": model_name,
"global_model_id": str(global_model.id),
"global_model_name": global_model.name,
"supported_capabilities": supported_caps,
"capability_details": capability_details,
}

643
src/api/public/catalog.py Normal file
View File

@@ -0,0 +1,643 @@
"""
公开API端点 - 用户可查看的提供商和模型信息
不包含敏感信息,普通用户可访问
"""
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session, joinedload
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.api import (
ProviderStatsResponse,
PublicGlobalModelListResponse,
PublicGlobalModelResponse,
PublicModelMappingResponse,
PublicModelResponse,
PublicProviderResponse,
)
from src.models.database import (
GlobalModel,
Model,
ModelMapping,
Provider,
ProviderEndpoint,
RequestCandidate,
)
from src.models.endpoint_models import (
PublicApiFormatHealthMonitor,
PublicApiFormatHealthMonitorResponse,
PublicHealthEvent,
)
from src.services.health.endpoint import EndpointHealthService
router = APIRouter(prefix="/api/public", tags=["Public Catalog"])
pipeline = ApiRequestPipeline()
@router.get("/providers", response_model=List[PublicProviderResponse])
async def get_public_providers(
request: Request,
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
skip: int = Query(0, description="跳过记录数"),
limit: int = Query(100, description="返回记录数限制"),
db: Session = Depends(get_db),
):
"""获取提供商列表(用户视图)。"""
adapter = PublicProvidersAdapter(is_active=is_active, skip=skip, limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/models", response_model=List[PublicModelResponse])
async def get_public_models(
request: Request,
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
skip: int = Query(0, description="跳过记录数"),
limit: int = Query(100, description="返回记录数限制"),
db: Session = Depends(get_db),
):
adapter = PublicModelsAdapter(
provider_id=provider_id, is_active=is_active, skip=skip, limit=limit
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
async def get_public_model_mappings(
request: Request,
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
alias: Optional[str] = Query(None, description="别名过滤原source_model"),
skip: int = Query(0, description="跳过记录数"),
limit: int = Query(100, description="返回记录数限制"),
db: Session = Depends(get_db),
):
adapter = PublicModelMappingsAdapter(
provider_id=provider_id,
alias=alias,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/stats", response_model=ProviderStatsResponse)
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
adapter = PublicStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/search/models")
async def search_models(
request: Request,
q: str = Query(..., description="搜索关键词"),
provider_id: Optional[int] = Query(None, description="提供商ID过滤"),
limit: int = Query(20, description="返回记录数限制"),
db: Session = Depends(get_db),
):
adapter = PublicSearchModelsAdapter(query=q, provider_id=provider_id, limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/health/api-formats", response_model=PublicApiFormatHealthMonitorResponse)
async def get_public_api_format_health(
request: Request,
lookback_hours: int = Query(6, ge=1, le=168, description="回溯小时数"),
per_format_limit: int = Query(100, ge=10, le=500, description="每个格式的事件数限制"),
db: Session = Depends(get_db),
):
"""获取各 API 格式的健康监控数据(公开版,不含敏感信息)"""
adapter = PublicApiFormatHealthMonitorAdapter(
lookback_hours=lookback_hours,
per_format_limit=per_format_limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/global-models", response_model=PublicGlobalModelListResponse)
async def get_public_global_models(
request: Request,
skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回记录数限制"),
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
search: Optional[str] = Query(None, description="搜索关键词"),
db: Session = Depends(get_db),
):
"""获取 GlobalModel 列表(用户视图,只读)"""
adapter = PublicGlobalModelsAdapter(
skip=skip,
limit=limit,
is_active=is_active,
search=search,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
# -------- 公共适配器 --------
class PublicApiAdapter(ApiAdapter):
mode = ApiMode.PUBLIC
def authorize(self, context): # type: ignore[override]
return None
@dataclass
class PublicProvidersAdapter(PublicApiAdapter):
is_active: Optional[bool]
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
logger.debug("公共API请求提供商列表")
query = db.query(Provider)
if self.is_active is not None:
query = query.filter(Provider.is_active == self.is_active)
else:
query = query.filter(Provider.is_active.is_(True))
providers = query.offset(self.skip).limit(self.limit).all()
result = []
for provider in providers:
models_count = db.query(Model).filter(Model.provider_id == provider.id).count()
active_models_count = (
db.query(Model)
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
.count()
)
mappings_count = (
db.query(ModelMapping)
.filter(
and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True))
)
.count()
)
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
active_endpoints_count = (
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
)
provider_data = PublicProviderResponse(
id=provider.id,
name=provider.name,
display_name=provider.display_name,
description=provider.description,
is_active=provider.is_active,
provider_priority=provider.provider_priority,
models_count=models_count,
active_models_count=active_models_count,
mappings_count=mappings_count,
endpoints_count=endpoints_count,
active_endpoints_count=active_endpoints_count,
)
result.append(provider_data.model_dump())
logger.debug(f"返回 {len(result)} 个提供商信息")
return result
@dataclass
class PublicModelsAdapter(PublicApiAdapter):
provider_id: Optional[str]
is_active: Optional[bool]
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
logger.debug("公共API请求模型列表")
query = (
db.query(Model, Provider)
.options(joinedload(Model.global_model))
.join(Provider)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
)
if self.provider_id is not None:
query = query.filter(Model.provider_id == self.provider_id)
results = query.offset(self.skip).limit(self.limit).all()
response = []
for model, provider in results:
global_model = model.global_model
display_name = global_model.display_name if global_model else model.provider_model_name
unified_name = global_model.name if global_model else model.provider_model_name
model_data = PublicModelResponse(
id=model.id,
provider_id=model.provider_id,
provider_name=provider.name,
provider_display_name=provider.display_name,
name=unified_name,
display_name=display_name,
description=global_model.description if global_model else None,
tags=None,
icon_url=global_model.icon_url if global_model else None,
input_price_per_1m=model.get_effective_input_price(),
output_price_per_1m=model.get_effective_output_price(),
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
cache_read_price_per_1m=model.get_effective_cache_read_price(),
supports_vision=model.get_effective_supports_vision(),
supports_function_calling=model.get_effective_supports_function_calling(),
supports_streaming=model.get_effective_supports_streaming(),
is_active=model.is_active,
)
response.append(model_data.model_dump())
logger.debug(f"返回 {len(response)} 个模型信息")
return response
@dataclass
class PublicModelMappingsAdapter(PublicApiAdapter):
provider_id: Optional[str]
alias: Optional[str] # 原 source_model改为 alias
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
logger.debug("公共API请求模型映射列表")
query = (
db.query(ModelMapping, GlobalModel, Provider)
.join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
.outerjoin(Provider, ModelMapping.provider_id == Provider.id)
.filter(
and_(
ModelMapping.is_active.is_(True),
GlobalModel.is_active.is_(True),
)
)
)
if self.provider_id is not None:
provider_global_model_ids = (
db.query(Model.global_model_id)
.join(Provider, Model.provider_id == Provider.id)
.filter(
Provider.id == self.provider_id,
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
.distinct()
)
query = query.filter(
or_(
ModelMapping.provider_id == self.provider_id,
and_(
ModelMapping.provider_id.is_(None),
ModelMapping.target_global_model_id.in_(provider_global_model_ids),
),
)
)
else:
query = query.filter(ModelMapping.provider_id.is_(None))
if self.alias is not None:
query = query.filter(ModelMapping.source_model.ilike(f"%{self.alias}%"))
results = query.offset(self.skip).limit(self.limit).all()
response = []
for mapping, global_model, provider in results:
scope = "provider" if mapping.provider_id else "global"
mapping_data = PublicModelMappingResponse(
id=mapping.id,
source_model=mapping.source_model,
target_global_model_id=mapping.target_global_model_id,
target_global_model_name=global_model.name if global_model else None,
target_global_model_display_name=(
global_model.display_name if global_model else None
),
provider_id=mapping.provider_id,
scope=scope,
is_active=mapping.is_active,
)
response.append(mapping_data.model_dump())
logger.debug(f"返回 {len(response)} 个模型映射")
return response
class PublicStatsAdapter(PublicApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
logger.debug("公共API请求系统统计信息")
active_providers = db.query(Provider).filter(Provider.is_active.is_(True)).count()
active_models = (
db.query(Model)
.join(Provider)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
.count()
)
from ...models.database import ModelMapping
active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
formats = (
db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
)
supported_formats = [f.api_format for f in formats if f.api_format]
stats = ProviderStatsResponse(
total_providers=active_providers,
active_providers=active_providers,
total_models=active_models,
active_models=active_models,
total_mappings=active_mappings,
supported_formats=supported_formats,
)
logger.debug("返回系统统计信息")
return stats.model_dump()
@dataclass
class PublicSearchModelsAdapter(PublicApiAdapter):
query: str
provider_id: Optional[int]
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
logger.debug(f"公共API搜索模型: {self.query}")
query_stmt = (
db.query(Model, Provider)
.options(joinedload(Model.global_model))
.join(Provider)
.outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
)
search_filter = (
Model.provider_model_name.ilike(f"%{self.query}%")
| GlobalModel.name.ilike(f"%{self.query}%")
| GlobalModel.display_name.ilike(f"%{self.query}%")
| GlobalModel.description.ilike(f"%{self.query}%")
)
query_stmt = query_stmt.filter(search_filter)
if self.provider_id is not None:
query_stmt = query_stmt.filter(Model.provider_id == self.provider_id)
results = query_stmt.limit(self.limit).all()
response = []
for model, provider in results:
global_model = model.global_model
display_name = global_model.display_name if global_model else model.provider_model_name
unified_name = global_model.name if global_model else model.provider_model_name
model_data = PublicModelResponse(
id=model.id,
provider_id=model.provider_id,
provider_name=provider.name,
provider_display_name=provider.display_name,
name=unified_name,
display_name=display_name,
description=global_model.description if global_model else None,
tags=None,
icon_url=global_model.icon_url if global_model else None,
input_price_per_1m=model.get_effective_input_price(),
output_price_per_1m=model.get_effective_output_price(),
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
cache_read_price_per_1m=model.get_effective_cache_read_price(),
supports_vision=model.get_effective_supports_vision(),
supports_function_calling=model.get_effective_supports_function_calling(),
supports_streaming=model.get_effective_supports_streaming(),
is_active=model.is_active,
)
response.append(model_data.model_dump())
logger.debug(f"搜索 '{self.query}' 返回 {len(response)} 个结果")
return response
@dataclass
class PublicApiFormatHealthMonitorAdapter(PublicApiAdapter):
"""公开版 API 格式健康监控适配器(返回 events 数组,前端复用 EndpointHealthTimeline 组件)"""
lookback_hours: int
per_format_limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
now = datetime.now(timezone.utc)
since = now - timedelta(hours=self.lookback_hours)
# 1. 获取所有活跃的 API 格式
active_formats = (
db.query(ProviderEndpoint.api_format)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
)
.distinct()
.all()
)
all_formats: List[str] = []
for (api_format_enum,) in active_formats:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
all_formats.append(api_format)
# API 格式 -> Endpoint ID 映射(用于 Usage 时间线)
endpoint_rows = (
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
)
.all()
)
endpoint_map: Dict[str, List[str]] = defaultdict(list)
for api_format_enum, endpoint_id in endpoint_rows:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
endpoint_map[api_format].append(endpoint_id)
# 2. 获取最近一段时间的 RequestCandidate限制数量
# 只查询最终状态的记录success, failed, skipped
final_statuses = ["success", "failed", "skipped"]
limit_rows = max(500, self.per_format_limit * 10)
rows = (
db.query(
RequestCandidate,
ProviderEndpoint.api_format,
)
.join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
.filter(
RequestCandidate.created_at >= since,
RequestCandidate.status.in_(final_statuses),
)
.order_by(RequestCandidate.created_at.desc())
.limit(limit_rows)
.all()
)
grouped_candidates: Dict[str, List[RequestCandidate]] = {}
for candidate, api_format_enum in rows:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
if api_format not in grouped_candidates:
grouped_candidates[api_format] = []
if len(grouped_candidates[api_format]) < self.per_format_limit:
grouped_candidates[api_format].append(candidate)
# 3. 为所有活跃格式生成监控数据
monitors: List[PublicApiFormatHealthMonitor] = []
for api_format in all_formats:
candidates = grouped_candidates.get(api_format, [])
# 统计
success_count = sum(1 for c in candidates if c.status == "success")
failed_count = sum(1 for c in candidates if c.status == "failed")
skipped_count = sum(1 for c in candidates if c.status == "skipped")
total_attempts = len(candidates)
# 计算成功率 = success / (success + failed)
actual_completed = success_count + failed_count
success_rate = success_count / actual_completed if actual_completed > 0 else 1.0
# 转换为公开版事件列表(不含敏感信息如 provider_id, key_id
events: List[PublicHealthEvent] = []
for c in candidates:
event_time = c.finished_at or c.started_at or c.created_at
events.append(
PublicHealthEvent(
timestamp=event_time,
status=c.status,
status_code=c.status_code,
latency_ms=c.latency_ms,
error_type=c.error_type,
)
)
# 最后事件时间
last_event_at = None
if candidates:
last_event_at = (
candidates[0].finished_at
or candidates[0].started_at
or candidates[0].created_at
)
timeline_data = EndpointHealthService._generate_timeline_from_usage(
db=db,
endpoint_ids=endpoint_map.get(api_format, []),
now=now,
lookback_hours=self.lookback_hours,
)
# 获取本站入口路径
from src.core.api_format_metadata import get_local_path
from src.core.enums import APIFormat
try:
api_format_enum = APIFormat(api_format)
local_path = get_local_path(api_format_enum)
except ValueError:
local_path = "/"
monitors.append(
PublicApiFormatHealthMonitor(
api_format=api_format,
api_path=local_path,
total_attempts=total_attempts,
success_count=success_count,
failed_count=failed_count,
skipped_count=skipped_count,
success_rate=success_rate,
last_event_at=last_event_at,
events=events,
timeline=timeline_data.get("timeline", []),
time_range_start=timeline_data.get("time_range_start"),
time_range_end=timeline_data.get("time_range_end"),
)
)
response = PublicApiFormatHealthMonitorResponse(
generated_at=now,
formats=monitors,
)
logger.debug(f"公开健康监控: 返回 {len(monitors)} 个 API 格式的健康数据")
return response
@dataclass
class PublicGlobalModelsAdapter(PublicApiAdapter):
"""公开的 GlobalModel 列表适配器"""
skip: int
limit: int
is_active: Optional[bool]
search: Optional[str]
async def handle(self, context): # type: ignore[override]
db = context.db
logger.debug("公共API请求 GlobalModel 列表")
query = db.query(GlobalModel)
# 默认只返回活跃的模型
if self.is_active is not None:
query = query.filter(GlobalModel.is_active == self.is_active)
else:
query = query.filter(GlobalModel.is_active.is_(True))
# 搜索过滤
if self.search:
search_term = f"%{self.search}%"
query = query.filter(
or_(
GlobalModel.name.ilike(search_term),
GlobalModel.display_name.ilike(search_term),
GlobalModel.description.ilike(search_term),
)
)
# 统计总数
total = query.count()
# 分页
models = query.order_by(GlobalModel.name).offset(self.skip).limit(self.limit).all()
# 转换为响应格式
model_responses = []
for gm in models:
model_responses.append(
PublicGlobalModelResponse(
id=gm.id,
name=gm.name,
display_name=gm.display_name,
description=gm.description,
icon_url=gm.icon_url,
is_active=gm.is_active,
default_price_per_request=gm.default_price_per_request,
default_tiered_pricing=gm.default_tiered_pricing,
default_supports_vision=gm.default_supports_vision or False,
default_supports_function_calling=gm.default_supports_function_calling or False,
default_supports_streaming=(
gm.default_supports_streaming
if gm.default_supports_streaming is not None
else True
),
default_supports_extended_thinking=gm.default_supports_extended_thinking
or False,
supported_capabilities=gm.supported_capabilities,
)
)
logger.debug(f"返回 {len(model_responses)} 个 GlobalModel")
return PublicGlobalModelListResponse(models=model_responses, total=total)

52
src/api/public/claude.py Normal file
View File

@@ -0,0 +1,52 @@
"""
Claude API 端点
- /v1/messages - Claude Messages API
- /v1/messages/count_tokens - Token Count API
"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.pipeline import ApiRequestPipeline
from src.api.handlers.claude import (
ClaudeTokenCountAdapter,
build_claude_adapter,
)
from src.core.api_format_metadata import get_api_format_definition
from src.core.enums import APIFormat
from src.database import get_db
_claude_def = get_api_format_definition(APIFormat.CLAUDE)
router = APIRouter(tags=["Claude API"], prefix=_claude_def.path_prefix)
pipeline = ApiRequestPipeline()
@router.post("/v1/messages")
async def create_message(
http_request: Request,
db: Session = Depends(get_db),
):
"""统一入口:根据 x-app 自动在标准/Claude Code 之间切换。"""
adapter = build_claude_adapter(http_request.headers.get("x-app", ""))
return await pipeline.run(
adapter=adapter,
http_request=http_request,
db=db,
mode=adapter.mode,
api_format_hint=adapter.allowed_api_formats[0],
)
@router.post("/v1/messages/count_tokens")
async def count_tokens(
http_request: Request,
db: Session = Depends(get_db),
):
adapter = ClaudeTokenCountAdapter()
return await pipeline.run(
adapter=adapter,
http_request=http_request,
db=db,
mode=adapter.mode,
)

130
src/api/public/gemini.py Normal file
View File

@@ -0,0 +1,130 @@
"""
Gemini API 专属端点
托管 Gemini API 相关路由:
- /v1beta/models/{model}:generateContent
- /v1beta/models/{model}:streamGenerateContent
注意: Gemini API 的 model 在 URL 路径中,而不是请求体中
路径配置来源: src.core.api_format_metadata.APIFormat.GEMINI
- path_prefix: 本站路径前缀(如 /gemini通过 router prefix 配置
- default_path: 标准 API 路径模板
"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.pipeline import ApiRequestPipeline
from src.api.handlers.gemini import build_gemini_adapter
from src.api.handlers.gemini_cli import build_gemini_cli_adapter
from src.core.api_format_metadata import get_api_format_definition
from src.core.enums import APIFormat
from src.database import get_db
# 从配置获取路径前缀
_gemini_def = get_api_format_definition(APIFormat.GEMINI)
router = APIRouter(tags=["Gemini API"], prefix=_gemini_def.path_prefix)
pipeline = ApiRequestPipeline()
def _is_cli_request(request: Request) -> bool:
"""
判断是否为 CLI 请求
检查顺序:
1. x-app header 包含 "cli"
2. user-agent 包含 "GeminiCLI""gemini-cli"
"""
# 检查 x-app header
x_app = request.headers.get("x-app", "")
if "cli" in x_app.lower():
return True
# 检查 user-agent
user_agent = request.headers.get("user-agent", "")
user_agent_lower = user_agent.lower()
if "geminicli" in user_agent_lower or "gemini-cli" in user_agent_lower:
return True
return False
@router.post("/v1beta/models/{model}:generateContent")
async def generate_content(
model: str,
http_request: Request,
db: Session = Depends(get_db),
):
"""
Gemini generateContent 端点
非流式生成内容请求
"""
# 根据 user-agent 或 x-app header 选择适配器
if _is_cli_request(http_request):
adapter = build_gemini_cli_adapter()
else:
adapter = build_gemini_adapter()
return await pipeline.run(
adapter=adapter,
http_request=http_request,
db=db,
mode=adapter.mode,
api_format_hint=adapter.allowed_api_formats[0],
# 将 model 注入到请求体中stream 用于内部判断流式模式
path_params={"model": model, "stream": False},
)
@router.post("/v1beta/models/{model}:streamGenerateContent")
async def stream_generate_content(
model: str,
http_request: Request,
db: Session = Depends(get_db),
):
"""
Gemini streamGenerateContent 端点
流式生成内容请求
注意: Gemini API 通过 URL 端点区分流式/非流式,不需要在请求体中添加 stream 字段
"""
# 根据 user-agent 或 x-app header 选择适配器
if _is_cli_request(http_request):
adapter = build_gemini_cli_adapter()
else:
adapter = build_gemini_adapter()
return await pipeline.run(
adapter=adapter,
http_request=http_request,
db=db,
mode=adapter.mode,
api_format_hint=adapter.allowed_api_formats[0],
# model 注入到请求体stream 用于内部判断流式模式(不发送到 API
path_params={"model": model, "stream": True},
)
# 兼容 v1 路径(部分 SDK 可能使用)
@router.post("/v1/models/{model}:generateContent")
async def generate_content_v1(
model: str,
http_request: Request,
db: Session = Depends(get_db),
):
"""v1 兼容端点"""
return await generate_content(model, http_request, db)
@router.post("/v1/models/{model}:streamGenerateContent")
async def stream_generate_content_v1(
model: str,
http_request: Request,
db: Session = Depends(get_db),
):
"""v1 兼容端点"""
return await stream_generate_content(model, http_request, db)

50
src/api/public/openai.py Normal file
View File

@@ -0,0 +1,50 @@
"""
OpenAI API 端点
- /v1/chat/completions - OpenAI Chat API
- /v1/responses - OpenAI Responses API (CLI)
"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.pipeline import ApiRequestPipeline
from src.api.handlers.openai import OpenAIChatAdapter
from src.api.handlers.openai_cli import OpenAICliAdapter
from src.core.api_format_metadata import get_api_format_definition
from src.core.enums import APIFormat
from src.database import get_db
_openai_def = get_api_format_definition(APIFormat.OPENAI)
router = APIRouter(tags=["OpenAI API"], prefix=_openai_def.path_prefix)
pipeline = ApiRequestPipeline()
@router.post("/v1/chat/completions")
async def create_chat_completion(
http_request: Request,
db: Session = Depends(get_db),
):
adapter = OpenAIChatAdapter()
return await pipeline.run(
adapter=adapter,
http_request=http_request,
db=db,
mode=adapter.mode,
api_format_hint=adapter.allowed_api_formats[0],
)
@router.post("/v1/responses")
async def create_responses(
http_request: Request,
db: Session = Depends(get_db),
):
adapter = OpenAICliAdapter()
return await pipeline.run(
adapter=adapter,
http_request=http_request,
db=db,
mode=adapter.mode,
api_format_hint=adapter.allowed_api_formats[0],
)

View File

@@ -0,0 +1,306 @@
"""
System Catalog / 健康检查相关端点
这些是系统工具端点,不需要复杂的 Adapter 抽象。
"""
from datetime import datetime, timezone
from typing import Any, Dict, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session, selectinload
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.clients.redis_client import get_redis_client, get_redis_client_sync
from src.core.logger import logger
from src.database import get_db
from src.database.database import get_pool_status
from src.models.database import Model, Provider
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
from src.services.provider.transport import build_provider_url
router = APIRouter(tags=["System Catalog"])
# ============== 辅助函数 ==============
def _as_bool(value: Optional[str], default: bool) -> bool:
"""将字符串转换为布尔值"""
if value is None:
return default
return value.lower() in {"1", "true", "yes", "on"}
def _serialize_provider(
provider: Provider,
include_models: bool,
include_endpoints: bool,
) -> Dict[str, Any]:
"""序列化 Provider 对象"""
provider_data: Dict[str, Any] = {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"is_active": provider.is_active,
"provider_priority": provider.provider_priority,
}
if include_endpoints:
provider_data["endpoints"] = [
{
"id": endpoint.id,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
"is_active": endpoint.is_active,
}
for endpoint in provider.endpoints or []
]
if include_models:
provider_data["models"] = [
{
"id": model.id,
"name": (
model.global_model.name if model.global_model else model.provider_model_name
),
"display_name": (
model.global_model.display_name
if model.global_model
else model.provider_model_name
),
"is_active": model.is_active,
"supports_streaming": model.supports_streaming,
}
for model in provider.models or []
if model.is_active
]
return provider_data
def _select_provider(db: Session, provider_name: Optional[str]) -> Optional[Provider]:
"""选择 Provider按 provider_priority 优先级选择)"""
query = db.query(Provider).filter(Provider.is_active == True)
if provider_name:
provider = query.filter(Provider.name == provider_name).first()
if provider:
return provider
# 按优先级选择provider_priority 最小的优先)
return query.order_by(Provider.provider_priority.asc()).first()
# ============== 端点 ==============
@router.get("/v1/health")
async def service_health(db: Session = Depends(get_db)):
"""返回服务健康状态与依赖信息"""
active_providers = (
db.query(func.count(Provider.id)).filter(Provider.is_active == True).scalar() or 0
)
active_models = db.query(func.count(Model.id)).filter(Model.is_active == True).scalar() or 0
redis_info: Dict[str, Any] = {"status": "unknown"}
try:
redis = await get_redis_client()
if redis:
await redis.ping()
redis_info = {"status": "ok"}
else:
redis_info = {"status": "degraded", "message": "Redis client not initialized"}
except Exception as exc:
redis_info = {"status": "error", "message": str(exc)}
return {
"status": "ok",
"timestamp": datetime.now(timezone.utc).isoformat(),
"stats": {
"active_providers": active_providers,
"active_models": active_models,
},
"dependencies": {
"database": {"status": "ok"},
"redis": redis_info,
},
}
@router.get("/health")
async def health_check():
"""简单健康检查端点(无需认证)"""
try:
pool_status = get_pool_status()
pool_health = {
"checked_out": pool_status["checked_out"],
"pool_size": pool_status["pool_size"],
"overflow": pool_status["overflow"],
"max_capacity": pool_status["max_capacity"],
"usage_rate": (
f"{(pool_status['checked_out'] / pool_status['max_capacity'] * 100):.1f}%"
if pool_status["max_capacity"] > 0
else "0.0%"
),
}
except Exception as e:
pool_health = {"error": str(e)}
return {
"status": "healthy",
"timestamp": datetime.now(timezone.utc).isoformat(),
"database_pool": pool_health,
}
@router.get("/")
async def root(db: Session = Depends(get_db)):
"""Root endpoint - 服务信息概览"""
# 按优先级选择最高优先级的提供商
top_provider = (
db.query(Provider)
.filter(Provider.is_active == True)
.order_by(Provider.provider_priority.asc())
.first()
)
active_providers = db.query(Provider).filter(Provider.is_active == True).count()
return {
"message": "AI Proxy with Modular Architecture v4.0.0",
"status": "running",
"current_provider": top_provider.name if top_provider else "None",
"available_providers": active_providers,
"config": {},
"endpoints": {
"messages": "/v1/messages",
"count_tokens": "/v1/messages/count_tokens",
"health": "/v1/health",
"providers": "/v1/providers",
"test_connection": "/v1/test-connection",
},
}
@router.get("/v1/providers")
async def list_providers(
db: Session = Depends(get_db),
include_models: bool = Query(False),
include_endpoints: bool = Query(False),
active_only: bool = Query(True),
):
"""列出所有 Provider"""
load_options = []
if include_models:
load_options.append(selectinload(Provider.models).selectinload(Model.global_model))
if include_endpoints:
load_options.append(selectinload(Provider.endpoints))
base_query = db.query(Provider)
if load_options:
base_query = base_query.options(*load_options)
if active_only:
base_query = base_query.filter(Provider.is_active == True)
base_query = base_query.order_by(Provider.provider_priority.asc(), Provider.name.asc())
providers = base_query.all()
return {
"providers": [
_serialize_provider(provider, include_models, include_endpoints)
for provider in providers
]
}
@router.get("/v1/providers/{provider_identifier}")
async def provider_detail(
provider_identifier: str,
db: Session = Depends(get_db),
include_models: bool = Query(False),
include_endpoints: bool = Query(False),
):
"""获取单个 Provider 详情"""
load_options = []
if include_models:
load_options.append(selectinload(Provider.models).selectinload(Model.global_model))
if include_endpoints:
load_options.append(selectinload(Provider.endpoints))
base_query = db.query(Provider)
if load_options:
base_query = base_query.options(*load_options)
provider = base_query.filter(
(Provider.id == provider_identifier) | (Provider.name == provider_identifier)
).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
return _serialize_provider(provider, include_models, include_endpoints)
@router.get("/v1/test-connection")
@router.get("/test-connection")
async def test_connection(
request: Request,
db: Session = Depends(get_db),
provider: Optional[str] = Query(None),
model: str = Query("claude-3-haiku-20240307"),
api_format: Optional[str] = Query(None),
):
"""测试 Provider 连接"""
selected_provider = _select_provider(db, provider)
if not selected_provider:
raise HTTPException(status_code=503, detail="No active provider available")
# 构建测试请求体
payload = {
"model": model,
"messages": [{"role": "user", "content": "Health check"}],
"max_tokens": 5,
}
# 确定 API 格式
format_value = api_format or "CLAUDE"
# 创建 FallbackOrchestrator
redis_client = get_redis_client_sync()
orchestrator = FallbackOrchestrator(db, redis_client)
# 定义请求函数
async def test_request_func(_prov, endpoint, key):
request_builder = PassthroughRequestBuilder()
provider_payload, provider_headers = request_builder.build(
payload, {}, endpoint, key, is_stream=False
)
url = build_provider_url(
endpoint,
query_params=dict(request.query_params),
path_params={"model": model},
is_stream=False,
)
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(url, json=provider_payload, headers=provider_headers)
resp.raise_for_status()
return resp.json()
try:
response, actual_provider, *_ = await orchestrator.execute_with_fallback(
api_format=format_value,
model_name=model,
user_api_key=None,
request_func=test_request_func,
request_id=None,
)
return {
"status": "success",
"provider": actual_provider,
"timestamp": datetime.now(timezone.utc).isoformat(),
"response_id": response.get("id", "unknown"),
}
except Exception as exc:
logger.error(f"API connectivity test failed: {exc}")
raise HTTPException(status_code=503, detail=str(exc))

View File

@@ -0,0 +1,10 @@
"""Routes for authenticated user self-service APIs."""
from fastapi import APIRouter
from .routes import router as me_router
router = APIRouter()
router.include_router(me_router)
__all__ = ["router"]

1127
src/api/user_me/routes.py Normal file

File diff suppressed because it is too large Load Diff

11
src/clients/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
from .http_client import HTTPClientPool, close_http_clients, get_http_client
from .redis_client import close_redis_client, get_redis_client, get_redis_client_sync
__all__ = [
"HTTPClientPool",
"get_http_client",
"close_http_clients",
"get_redis_client",
"get_redis_client_sync",
"close_redis_client",
]

133
src/clients/http_client.py Normal file
View File

@@ -0,0 +1,133 @@
"""
全局HTTP客户端池管理
避免每次请求都创建新的AsyncClient,提高性能
"""
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
import httpx
from src.core.logger import logger
class HTTPClientPool:
"""
全局HTTP客户端池单例
管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接
"""
_instance: Optional["HTTPClientPool"] = None
_default_client: Optional[httpx.AsyncClient] = None
_clients: Dict[str, httpx.AsyncClient] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@classmethod
def get_default_client(cls) -> httpx.AsyncClient:
"""
获取默认的HTTP客户端
用于大多数HTTP请求,具有合理的默认配置
"""
if cls._default_client is None:
cls._default_client = httpx.AsyncClient(
http2=False, # 暂时禁用HTTP/2以提高兼容性
verify=True, # 启用SSL验证
timeout=httpx.Timeout(
connect=10.0, # 连接超时
read=300.0, # 读取超时(5分钟,适合流式响应)
write=10.0, # 写入超时
pool=5.0, # 连接池超时
),
limits=httpx.Limits(
max_connections=100, # 最大连接数
max_keepalive_connections=20, # 最大保活连接数
keepalive_expiry=30.0, # 保活过期时间(秒)
),
follow_redirects=True, # 跟随重定向
)
logger.info("全局HTTP客户端池已初始化")
return cls._default_client
@classmethod
def get_client(cls, name: str, **kwargs: Any) -> httpx.AsyncClient:
"""
获取或创建命名的HTTP客户端
用于需要特定配置的场景(如不同的超时设置、代理等)
Args:
name: 客户端标识符
**kwargs: httpx.AsyncClient的配置参数
"""
if name not in cls._clients:
# 合并默认配置和自定义配置
config = {
"http2": False,
"verify": True,
"timeout": httpx.Timeout(10.0, read=300.0),
"follow_redirects": True,
}
config.update(kwargs)
cls._clients[name] = httpx.AsyncClient(**config)
logger.debug(f"创建命名HTTP客户端: {name}")
return cls._clients[name]
@classmethod
async def close_all(cls):
"""关闭所有HTTP客户端"""
if cls._default_client is not None:
await cls._default_client.aclose()
cls._default_client = None
logger.info("默认HTTP客户端已关闭")
for name, client in cls._clients.items():
await client.aclose()
logger.debug(f"命名HTTP客户端已关闭: {name}")
cls._clients.clear()
logger.info("所有HTTP客户端已关闭")
@classmethod
@asynccontextmanager
async def get_temp_client(cls, **kwargs: Any):
"""
获取临时HTTP客户端(上下文管理器)
用于一次性请求,使用后自动关闭
用法:
async with HTTPClientPool.get_temp_client() as client:
response = await client.get('https://example.com')
"""
config = {
"http2": False,
"verify": True,
"timeout": httpx.Timeout(10.0),
}
config.update(kwargs)
client = httpx.AsyncClient(**config)
try:
yield client
finally:
await client.aclose()
# 便捷访问函数
def get_http_client() -> httpx.AsyncClient:
"""获取默认HTTP客户端的便捷函数"""
return HTTPClientPool.get_default_client()
async def close_http_clients():
"""关闭所有HTTP客户端的便捷函数"""
await HTTPClientPool.close_all()

346
src/clients/redis_client.py Normal file
View File

@@ -0,0 +1,346 @@
"""
全局Redis客户端管理
提供统一的Redis客户端访问确保所有服务使用同一个连接池
熔断器说明:
- 连续失败达到阈值后开启熔断
- 熔断期间返回明确的状态而非静默失败
- 调用方可以根据状态决定降级策略
"""
import os
import time
from enum import Enum
from typing import Optional
import redis.asyncio as aioredis
from src.core.logger import logger
from redis.asyncio import sentinel as redis_sentinel
class RedisState(Enum):
"""Redis 连接状态"""
NOT_INITIALIZED = "not_initialized" # 未初始化
CONNECTED = "connected" # 已连接
CIRCUIT_OPEN = "circuit_open" # 熔断中
DISCONNECTED = "disconnected" # 断开连接
class RedisClientManager:
"""
Redis客户端管理器单例
提供 Redis 连接管理、熔断器保护和状态监控。
"""
_instance: Optional["RedisClientManager"] = None
_redis: Optional[aioredis.Redis] = None
def __new__(cls):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
# 避免重复初始化
if getattr(self, "_initialized", False):
return
self._initialized = True
self._circuit_open_until: Optional[float] = None
self._consecutive_failures: int = 0
self._circuit_threshold = int(os.getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD", "3"))
self._circuit_reset_seconds = int(os.getenv("REDIS_CIRCUIT_BREAKER_RESET_SECONDS", "60"))
self._last_error: Optional[str] = None # 记录最后一次错误
def get_state(self) -> RedisState:
"""
获取 Redis 连接状态
Returns:
当前连接状态枚举值
"""
if self._redis is not None:
return RedisState.CONNECTED
if self._circuit_open_until and time.time() < self._circuit_open_until:
return RedisState.CIRCUIT_OPEN
if self._last_error:
return RedisState.DISCONNECTED
return RedisState.NOT_INITIALIZED
def get_circuit_info(self) -> dict:
"""
获取熔断器详细信息
Returns:
包含熔断器状态的字典
"""
state = self.get_state()
info = {
"state": state.value,
"consecutive_failures": self._consecutive_failures,
"circuit_threshold": self._circuit_threshold,
"last_error": self._last_error,
}
if state == RedisState.CIRCUIT_OPEN and self._circuit_open_until:
info["circuit_remaining_seconds"] = max(0, self._circuit_open_until - time.time())
return info
def reset_circuit_breaker(self) -> None:
"""
手动重置熔断器(用于管理后台紧急恢复)
"""
logger.info("Redis 熔断器手动重置")
self._circuit_open_until = None
self._consecutive_failures = 0
self._last_error = None
async def initialize(self, require_redis: bool = False) -> Optional[aioredis.Redis]:
"""
初始化Redis连接
Args:
require_redis: 是否强制要求Redis连接成功如果为True则连接失败时抛出异常
Returns:
Redis客户端实例如果连接失败返回None当require_redis=False时
Raises:
RuntimeError: 当require_redis=True且连接失败时
"""
if self._redis is not None:
return self._redis
# 检查熔断状态
if self._circuit_open_until and time.time() < self._circuit_open_until:
remaining = self._circuit_open_until - time.time()
logger.warning(
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
remaining,
self._last_error,
)
if require_redis:
raise RuntimeError(
f"Redis 处于熔断状态,剩余 {remaining:.1f} 秒。"
f"最后错误: {self._last_error}"
"使用管理 API 重置熔断器或等待自动恢复。"
)
return None
# 优先使用 REDIS_URL如果没有则根据密码构建 URL
redis_url = os.getenv("REDIS_URL")
redis_max_conn = int(os.getenv("REDIS_MAX_CONNECTIONS", "50"))
sentinel_hosts = os.getenv("REDIS_SENTINEL_HOSTS")
sentinel_service = os.getenv("REDIS_SENTINEL_SERVICE_NAME", "mymaster")
redis_password = os.getenv("REDIS_PASSWORD")
if not redis_url and not sentinel_hosts:
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
if redis_password:
redis_url = f"redis://:{redis_password}@localhost:6379/0"
else:
redis_url = "redis://localhost:6379/0"
try:
if sentinel_hosts:
sentinel_list = []
for host in sentinel_hosts.split(","):
host = host.strip()
if not host:
continue
if ":" in host:
hostname, port = host.split(":", 1)
sentinel_list.append((hostname, int(port)))
else:
sentinel_list.append((host, 26379))
sentinel_kwargs = {
"password": redis_password,
"socket_timeout": 5.0,
}
sentinel = redis_sentinel.Sentinel(
sentinel_list,
**sentinel_kwargs,
)
self._redis = sentinel.master_for(
service_name=sentinel_service,
max_connections=redis_max_conn,
decode_responses=True,
socket_connect_timeout=5.0,
)
safe_url = f"sentinel://{sentinel_service}"
else:
self._redis = await aioredis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_timeout=5.0,
socket_connect_timeout=5.0,
max_connections=redis_max_conn,
)
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
# 测试连接
await self._redis.ping()
logger.info(f"[OK] 全局Redis客户端初始化成功: {safe_url}")
self._consecutive_failures = 0
self._circuit_open_until = None
return self._redis
except Exception as e:
error_msg = str(e)
self._last_error = error_msg
logger.error(f"[ERROR] Redis连接失败: {error_msg}")
self._consecutive_failures += 1
if self._consecutive_failures >= self._circuit_threshold:
self._circuit_open_until = time.time() + self._circuit_reset_seconds
logger.warning(
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
self._consecutive_failures,
self._circuit_reset_seconds,
)
if require_redis:
# 强制要求Redis时抛出异常拒绝启动
raise RuntimeError(
f"Redis连接失败: {error_msg}\n"
"缓存亲和性功能需要Redis支持请确保Redis服务正常运行。\n"
"检查事项:\n"
"1. Redis服务是否已启动docker-compose up -d redis\n"
"2. 环境变量 REDIS_URL 或 REDIS_PASSWORD 是否配置正确\n"
"3. Redis端口默认6379是否可访问"
) from e
logger.warning(
"[WARN] Redis 不可用,以下功能将降级运行(仅在单实例环境下安全):\n"
" - 缓存亲和性: 禁用(每次请求随机选择 Endpoint\n"
" - 分布式并发控制: 降级为本地计数\n"
" - RPM 限流: 降级为本地限流"
)
self._redis = None
return None
async def close(self) -> None:
"""关闭Redis连接"""
if self._redis:
await self._redis.close()
self._redis = None
logger.info("全局Redis客户端已关闭")
def get_client(self) -> Optional[aioredis.Redis]:
"""
获取Redis客户端非异步
注意必须先调用initialize()初始化
Returns:
Redis客户端实例或None
"""
return self._redis
# 全局单例
_redis_manager: Optional[RedisClientManager] = None
async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Redis]:
"""
获取全局Redis客户端
Args:
require_redis: 是否强制要求Redis连接成功如果为True则连接失败时抛出异常
Returns:
Redis客户端实例如果未初始化或连接失败返回None当require_redis=False时
Raises:
RuntimeError: 当require_redis=True且连接失败时
"""
global _redis_manager
if _redis_manager is None:
_redis_manager = RedisClientManager()
await _redis_manager.initialize(require_redis=require_redis)
return _redis_manager.get_client()
def get_redis_client_sync() -> Optional[aioredis.Redis]:
"""
同步获取Redis客户端不会初始化
Returns:
Redis客户端实例或None
"""
global _redis_manager
if _redis_manager is None:
return None
return _redis_manager.get_client()
async def close_redis_client() -> None:
"""关闭全局Redis客户端"""
global _redis_manager
if _redis_manager:
await _redis_manager.close()
def get_redis_state() -> RedisState:
"""
获取 Redis 连接状态(同步方法)
Returns:
Redis 连接状态枚举
"""
global _redis_manager
if _redis_manager is None:
return RedisState.NOT_INITIALIZED
return _redis_manager.get_state()
def get_redis_circuit_info() -> dict:
"""
获取 Redis 熔断器详细信息(同步方法)
Returns:
熔断器状态字典
"""
global _redis_manager
if _redis_manager is None:
return {
"state": RedisState.NOT_INITIALIZED.value,
"consecutive_failures": 0,
"circuit_threshold": 3,
"last_error": None,
}
return _redis_manager.get_circuit_info()
def reset_redis_circuit_breaker() -> bool:
"""
手动重置 Redis 熔断器(同步方法)
Returns:
是否成功重置
"""
global _redis_manager
if _redis_manager is None:
return False
_redis_manager.reset_circuit_breaker()
return True

3
src/config/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .settings import Config, config
__all__ = ["Config", "config"]

235
src/config/constants.py Normal file
View File

@@ -0,0 +1,235 @@
# Constants for better maintainability
# ==============================================================================
# 缓存相关常量
# ==============================================================================
# 缓存 TTL
class CacheTTL:
"""缓存过期时间配置(秒)"""
# 用户缓存 - 用户信息变更较频繁
USER = 60 # 1分钟
# Provider/Model 缓存 - 配置变更不频繁
PROVIDER = 300 # 5分钟
MODEL = 300 # 5分钟
MODEL_MAPPING = 300 # 5分钟
# 缓存亲和性 - 对应 provider_api_key.cache_ttl_minutes 默认值
CACHE_AFFINITY = 300 # 5分钟
# L1 本地缓存(用于减少 Redis 访问)
L1_LOCAL = 3 # 3秒
# 并发锁 TTL - 防止死锁
CONCURRENCY_LOCK = 600 # 10分钟
# 缓存容量限制
class CacheSize:
"""缓存容量配置"""
# 默认 LRU 缓存大小
DEFAULT = 1000
# ModelMapping 缓存(可能有较多别名)
MODEL_MAPPING = 2000
# ==============================================================================
# 并发和限流常量
# ==============================================================================
class ConcurrencyDefaults:
"""并发控制默认值"""
# 自适应并发初始限制(保守值)
INITIAL_LIMIT = 3
# 429错误后的冷却时间分钟- 在此期间不会增加并发限制
COOLDOWN_AFTER_429_MINUTES = 5
# 探测间隔上限(分钟)- 用于长期探测策略
MAX_PROBE_INTERVAL_MINUTES = 60
# === 基于滑动窗口的扩容参数 ===
# 滑动窗口大小(采样点数量)
UTILIZATION_WINDOW_SIZE = 20
# 滑动窗口时间范围(秒)- 只保留最近这段时间内的采样
UTILIZATION_WINDOW_SECONDS = 120 # 2分钟
# 利用率阈值 - 窗口内平均利用率 >= 此值时考虑扩容
UTILIZATION_THRESHOLD = 0.7 # 70%
# 高利用率采样比例 - 窗口内超过阈值的采样点比例 >= 此值时触发扩容
HIGH_UTILIZATION_RATIO = 0.6 # 60% 的采样点高于阈值
# 最小采样数 - 窗口内至少需要这么多采样才能做出扩容决策
MIN_SAMPLES_FOR_DECISION = 5
# 扩容步长 - 每次扩容增加的并发数
INCREASE_STEP = 1
# 缩容乘数 - 遇到 429 时的缩容比例
DECREASE_MULTIPLIER = 0.7
# 最大并发限制上限
MAX_CONCURRENT_LIMIT = 100
# 最小并发限制下限
MIN_CONCURRENT_LIMIT = 1
# === 探测性扩容参数 ===
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
PROBE_INCREASE_INTERVAL_MINUTES = 30
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
PROBE_INCREASE_MIN_REQUESTS = 10
class CircuitBreakerDefaults:
"""熔断器配置默认值(滑动窗口 + 半开状态模式)
新的熔断器基于滑动窗口错误率,而不是累计健康度。
支持半开状态,允许少量请求验证服务是否恢复。
"""
# === 滑动窗口配置 ===
# 滑动窗口大小(最近 N 次请求)
WINDOW_SIZE = 20
# 滑动窗口时间范围(秒)- 只保留最近这段时间内的请求记录
WINDOW_SECONDS = 300 # 5分钟
# 最小请求数 - 窗口内至少需要这么多请求才能做出熔断决策
MIN_REQUESTS_FOR_DECISION = 5
# 错误率阈值 - 窗口内错误率超过此值时触发熔断
ERROR_RATE_THRESHOLD = 0.5 # 50%
# === 半开状态配置 ===
# 半开状态持续时间(秒)- 在此期间允许少量请求通过
HALF_OPEN_DURATION_SECONDS = 30
# 半开状态成功阈值 - 达到此成功次数则关闭熔断器
HALF_OPEN_SUCCESS_THRESHOLD = 3
# 半开状态失败阈值 - 达到此失败次数则重新打开熔断器
HALF_OPEN_FAILURE_THRESHOLD = 2
# === 熔断恢复配置 ===
# 初始探测间隔(秒)- 熔断后多久进入半开状态
INITIAL_RECOVERY_SECONDS = 30
# 探测间隔退避倍数
RECOVERY_BACKOFF_MULTIPLIER = 2
# 最大探测间隔(秒)
MAX_RECOVERY_SECONDS = 300 # 5分钟
# === 旧参数(向后兼容,仍用于展示健康度)===
# 成功时健康度增量
SUCCESS_INCREMENT = 0.15
# 失败时健康度减量
FAILURE_DECREMENT = 0.03
# 探测成功后的快速恢复健康度
PROBE_RECOVERY_SCORE = 0.5
class AdaptiveReservationDefaults:
"""动态预留比例配置默认值
动态预留机制根据学习置信度和负载自动调整缓存用户预留比例,
解决固定 30% 预留在学习初期和负载变化时的不适应问题。
"""
# 探测阶段配置
PROBE_PHASE_REQUESTS = 100 # 探测阶段请求数阈值
PROBE_RESERVATION = 0.1 # 探测阶段预留比例10%
# 稳定阶段配置
STABLE_MIN_RESERVATION = 0.1 # 稳定阶段最小预留10%
STABLE_MAX_RESERVATION = 0.35 # 稳定阶段最大预留35%
# 置信度计算参数
SUCCESS_COUNT_FOR_FULL_CONFIDENCE = 50 # 连续成功多少次达到满置信
COOLDOWN_HOURS_FOR_FULL_CONFIDENCE = 24 # 429后多少小时达到满置信
# 负载阈值
LOW_LOAD_THRESHOLD = 0.5 # 低负载阈值50%
HIGH_LOAD_THRESHOLD = 0.8 # 高负载阈值80%
# ==============================================================================
# 超时和重试常量
# ==============================================================================
class TimeoutDefaults:
"""超时配置默认值(秒)"""
# HTTP 请求默认超时
HTTP_REQUEST = 300 # 5分钟
# 数据库连接池获取超时
DB_POOL = 30
# Redis 操作超时
REDIS_OPERATION = 5
class RetryDefaults:
"""重试配置默认值"""
# 最大重试次数
MAX_RETRIES = 3
# 重试基础延迟(秒)
BASE_DELAY = 1.0
# 重试延迟倍数(指数退避)
DELAY_MULTIPLIER = 2.0
# ==============================================================================
# 消息格式常量
# ==============================================================================
# 角色常量
ROLE_USER = "user"
ROLE_ASSISTANT = "assistant"
ROLE_SYSTEM = "system"
ROLE_TOOL = "tool"
# 内容类型常量
CONTENT_TEXT = "text"
CONTENT_IMAGE = "image"
CONTENT_TOOL_USE = "tool_use"
CONTENT_TOOL_RESULT = "tool_result"
# 工具常量
TOOL_FUNCTION = "function"
# 停止原因常量
STOP_END_TURN = "end_turn"
STOP_MAX_TOKENS = "max_tokens"
STOP_TOOL_USE = "tool_use"
STOP_ERROR = "error"
# 事件类型常量
EVENT_MESSAGE_START = "message_start"
EVENT_MESSAGE_STOP = "message_stop"
EVENT_MESSAGE_DELTA = "message_delta"
EVENT_CONTENT_BLOCK_START = "content_block_start"
EVENT_CONTENT_BLOCK_STOP = "content_block_stop"
EVENT_CONTENT_BLOCK_DELTA = "content_block_delta"
EVENT_PING = "ping"
# Delta类型常量
DELTA_TEXT = "text_delta"
DELTA_INPUT_JSON = "input_json_delta"

259
src/config/settings.py Normal file
View File

@@ -0,0 +1,259 @@
"""
服务器配置
从环境变量或 .env 文件加载配置
"""
import os
from pathlib import Path
# 尝试加载 .env 文件
try:
from dotenv import load_dotenv
env_file = Path(".env")
if env_file.exists():
load_dotenv(env_file)
except ImportError:
# 如果没有安装 python-dotenv仍然可以从环境变量读取
pass
class Config:
def __init__(self) -> None:
# 服务器配置
self.host = os.getenv("HOST", "0.0.0.0")
self.port = int(os.getenv("PORT", "8084"))
self.log_level = os.getenv("LOG_LEVEL", "INFO")
self.worker_processes = int(
os.getenv("WEB_CONCURRENCY", os.getenv("GUNICORN_WORKERS", "4"))
)
# PostgreSQL 连接池计算相关配置
# PG_MAX_CONNECTIONS: PostgreSQL 的 max_connections 设置(默认 100
# PG_RESERVED_CONNECTIONS: 为其他应用/管理工具预留的连接数(默认 10
self.pg_max_connections = int(os.getenv("PG_MAX_CONNECTIONS", "100"))
self.pg_reserved_connections = int(os.getenv("PG_RESERVED_CONNECTIONS", "10"))
# 数据库配置 - 延迟验证,支持测试环境覆盖
self._database_url = os.getenv("DATABASE_URL")
# JWT配置
self.jwt_secret_key = os.getenv("JWT_SECRET_KEY", None)
self.jwt_algorithm = os.getenv("JWT_ALGORITHM", "HS256")
self.jwt_expiration_hours = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
# 加密密钥配置独立于JWT密钥用于敏感数据加密
self.encryption_key = os.getenv("ENCRYPTION_KEY", None)
# 环境配置 - 智能检测
# Docker 部署默认为生产环境,本地开发默认为开发环境
is_docker = (
os.path.exists("/.dockerenv")
or os.environ.get("DOCKER_CONTAINER", "false").lower() == "true"
)
default_env = "production" if is_docker else "development"
self.environment = os.getenv("ENVIRONMENT", default_env)
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
redis_required_env = os.getenv("REDIS_REQUIRED")
if redis_required_env is None:
self.require_redis = self.environment not in {"development", "test", "testing"}
else:
self.require_redis = redis_required_env.lower() == "true"
# CORS配置 - 使用环境变量配置允许的源
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
cors_origins = os.getenv("CORS_ORIGINS", "")
if cors_origins:
self.cors_origins = [
origin.strip() for origin in cors_origins.split(",") if origin.strip()
]
else:
# 默认: 开发环境允许本地前端,生产环境不允许任何跨域
if self.environment == "development":
self.cors_origins = [
"http://localhost:3000",
"http://localhost:5173", # Vite 默认端口
"http://127.0.0.1:3000",
"http://127.0.0.1:5173",
]
else:
# 生产环境默认不允许跨域,必须显式配置
self.cors_origins = []
# CORS是否允许凭证(Cookie/Authorization header)
# 注意: allow_credentials=True 时不能使用 allow_origins=["*"]
self.cors_allow_credentials = os.getenv("CORS_ALLOW_CREDENTIALS", "true").lower() == "true"
# 管理员账户配置(用于初始化)
self.admin_email = os.getenv("ADMIN_EMAIL", "admin@localhost")
self.admin_username = os.getenv("ADMIN_USERNAME", "admin")
# 管理员密码 - 必须在环境变量中设置
admin_password_env = os.getenv("ADMIN_PASSWORD")
if admin_password_env:
self.admin_password = admin_password_env
else:
# 未设置密码,启动时会报错
self.admin_password = ""
self._missing_admin_password = True
# API Key 配置
self.api_key_prefix = os.getenv("API_KEY_PREFIX", "sk")
# LLM API 速率限制配置(每分钟请求数)
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
# 异常处理配置
# 设置为 True 时ProxyException 会传播到路由层以便记录 provider_request_headers
# 设置为 False 时,使用全局异常处理器统一处理
self.propagate_provider_exceptions = os.getenv(
"PROPAGATE_PROVIDER_EXCEPTIONS", "true"
).lower() == "true"
# 数据库连接池配置 - 智能自动调整
# 系统会根据 Worker 数量和 PostgreSQL 限制自动计算安全值
self.db_pool_size = int(os.getenv("DB_POOL_SIZE") or self._auto_pool_size())
self.db_max_overflow = int(os.getenv("DB_MAX_OVERFLOW") or self._auto_max_overflow())
self.db_pool_timeout = int(os.getenv("DB_POOL_TIMEOUT", "60"))
self.db_pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600"))
self.db_pool_warn_threshold = int(os.getenv("DB_POOL_WARN_THRESHOLD", "70"))
# 验证连接池配置
self._validate_pool_config()
def _auto_pool_size(self) -> int:
"""
智能计算连接池大小 - 根据 Worker 数量和 PostgreSQL 限制计算
公式: (pg_max_connections - reserved) / workers / 2
除以 2 是因为还要预留 max_overflow 的空间
"""
available_connections = self.pg_max_connections - self.pg_reserved_connections
# 每个 Worker 可用的连接数pool_size + max_overflow
per_worker_total = available_connections // max(self.worker_processes, 1)
# pool_size 取总数的一半,另一半留给 overflow
pool_size = max(per_worker_total // 2, 5) # 最小 5 个连接
return min(pool_size, 30) # 最大 30 个连接
def _auto_max_overflow(self) -> int:
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
return self.db_pool_size
def _validate_pool_config(self) -> None:
"""验证连接池配置是否安全"""
total_per_worker = self.db_pool_size + self.db_max_overflow
total_all_workers = total_per_worker * self.worker_processes
safe_limit = self.pg_max_connections - self.pg_reserved_connections
if total_all_workers > safe_limit:
# 记录警告(不抛出异常,避免阻止启动)
self._pool_config_warning = (
f"[WARN] 数据库连接池配置可能超过 PostgreSQL 限制: "
f"{self.worker_processes} workers x {total_per_worker} connections = "
f"{total_all_workers} > {safe_limit} (pg_max_connections - reserved). "
f"建议调整 DB_POOL_SIZE 或 PG_MAX_CONNECTIONS 环境变量。"
)
else:
self._pool_config_warning = None
@property
def database_url(self) -> str:
"""
数据库 URL延迟验证
在测试环境中可以通过依赖注入覆盖,而不会在导入时崩溃
"""
if not self._database_url:
raise ValueError(
"DATABASE_URL environment variable is required. "
"Example: postgresql://username:password@localhost:5432/dbname"
)
return self._database_url
@database_url.setter
def database_url(self, value: str):
"""允许在测试中设置数据库 URL"""
self._database_url = value
def log_startup_warnings(self) -> None:
"""
记录启动时的安全警告
这个方法应该在 logger 初始化后调用
"""
from src.core.logger import logger
# 连接池配置警告
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
logger.warning(self._pool_config_warning)
# 管理员密码检查(必须在环境变量中设置)
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")
raise ValueError("ADMIN_PASSWORD environment variable must be set!")
# JWT 密钥警告
if not self.jwt_secret_key:
if self.environment == "production":
logger.error(
"生产环境未设置 JWT_SECRET_KEY! 这是严重的安全漏洞。"
"使用 'python generate_keys.py' 生成安全密钥。"
)
else:
logger.warning("JWT_SECRET_KEY 未设置,将使用默认密钥(仅限开发环境)")
# 加密密钥警告
if not self.encryption_key and self.environment != "production":
logger.warning(
"ENCRYPTION_KEY 未设置,使用开发环境默认密钥。生产环境必须设置。"
)
# CORS 配置警告(生产环境)
if self.environment == "production" and not self.cors_origins:
logger.warning("生产环境 CORS 未配置,前端将无法访问 API。请设置 CORS_ORIGINS。")
def validate_security_config(self) -> list[str]:
"""
验证安全配置,返回错误列表
生产环境会阻止启动,开发环境仅警告
Returns:
错误消息列表(空列表表示验证通过)
"""
errors: list[str] = []
if self.environment == "production":
# 生产环境必须设置 JWT 密钥
if not self.jwt_secret_key:
errors.append(
"JWT_SECRET_KEY must be set in production. "
"Use 'python generate_keys.py' to generate a secure key."
)
elif len(self.jwt_secret_key) < 32:
errors.append("JWT_SECRET_KEY must be at least 32 characters in production.")
# 生产环境必须设置加密密钥
if not self.encryption_key:
errors.append(
"ENCRYPTION_KEY must be set in production. "
"Use 'python generate_keys.py' to generate a secure key."
)
return errors
def __repr__(self):
"""配置信息字符串表示"""
return f"""
Configuration:
Server: {self.host}:{self.port}
Log Level: {self.log_level}
Environment: {self.environment}
"""
# 创建全局配置实例
config = Config()
# 在调试模式下记录配置(延迟到日志系统初始化后)
# 这个配置信息会在应用启动时通过日志系统输出

0
src/core/__init__.py Normal file
View File

View File

@@ -0,0 +1,272 @@
"""
集中维护 API 格式的元数据,避免新增格式时到处修改常量。
此模块与 src/formats/ 的 FormatProtocol 系统配合使用:
- api_format_metadata: 定义格式的元数据(别名、默认路径)
- src/formats/: 定义格式的协议实现(解析、转换、验证)
使用方式:
# 解析格式别名
from src.core.api_format_metadata import resolve_api_format
api_format = resolve_api_format("claude") # -> APIFormat.CLAUDE
# 获取格式协议
from src.core.api_format_metadata import get_format_protocol
protocol = get_format_protocol(APIFormat.CLAUDE) # -> ClaudeProtocol
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from functools import lru_cache
from types import MappingProxyType
from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Union
from .enums import APIFormat
@dataclass(frozen=True)
class ApiFormatDefinition:
"""
描述一个 API 格式的所有通用信息。
- aliases: 用于 detect_api_format 的 provider 别名或快捷名称
- default_path: 上游默认请求路径(如 /v1/messages可通过 Endpoint.custom_path 覆盖
- path_prefix: 本站路径前缀(如 /claude, /openai为空表示无前缀
- auth_header: 认证头名称 (如 "x-api-key", "x-goog-api-key")
- auth_type: 认证类型 ("header" 直接放值, "bearer" 加 Bearer 前缀)
"""
api_format: APIFormat
aliases: Sequence[str] = field(default_factory=tuple)
default_path: str = "/" # 上游默认请求路径
path_prefix: str = "" # 本站路径前缀,为空表示无前缀
auth_header: str = "Authorization"
auth_type: str = "bearer" # "bearer" or "header"
def iter_aliases(self) -> Iterable[str]:
"""返回大小写统一后的别名集合,包含枚举名本身。"""
yield normalize_alias_value(self.api_format.value)
for alias in self.aliases:
normalized = normalize_alias_value(alias)
if normalized:
yield normalized
_DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
APIFormat.CLAUDE: ApiFormatDefinition(
api_format=APIFormat.CLAUDE,
aliases=("claude", "anthropic", "claude_compatible"),
default_path="/v1/messages",
path_prefix="", # 本站路径前缀,可配置如 "/claude"
auth_header="x-api-key",
auth_type="header",
),
APIFormat.CLAUDE_CLI: ApiFormatDefinition(
api_format=APIFormat.CLAUDE_CLI,
aliases=("claude_cli", "claude-cli"),
default_path="/v1/messages",
path_prefix="", # 与 CLAUDE 共享入口,通过 header 区分
auth_header="authorization",
auth_type="bearer",
),
APIFormat.OPENAI: ApiFormatDefinition(
api_format=APIFormat.OPENAI,
aliases=(
"openai",
"deepseek",
"grok",
"moonshot",
"zhipu",
"qwen",
"baichuan",
"minimax",
"openai_compatible",
),
default_path="/v1/chat/completions",
path_prefix="", # 本站路径前缀,可配置如 "/openai"
auth_header="Authorization",
auth_type="bearer",
),
APIFormat.OPENAI_CLI: ApiFormatDefinition(
api_format=APIFormat.OPENAI_CLI,
aliases=("openai_cli", "responses"),
default_path="/responses",
path_prefix="",
auth_header="Authorization",
auth_type="bearer",
),
APIFormat.GEMINI: ApiFormatDefinition(
api_format=APIFormat.GEMINI,
aliases=("gemini", "google", "vertex"),
default_path="/v1beta/models/{model}:{action}",
path_prefix="", # 本站路径前缀,可配置如 "/gemini"
auth_header="x-goog-api-key",
auth_type="header",
),
APIFormat.GEMINI_CLI: ApiFormatDefinition(
api_format=APIFormat.GEMINI_CLI,
aliases=("gemini_cli", "gemini-cli"),
default_path="/v1beta/models/{model}:{action}",
path_prefix="", # 与 GEMINI 共享入口
auth_header="x-goog-api-key",
auth_type="header",
),
}
# 对外只暴露只读视图,避免被随意修改
API_FORMAT_DEFINITIONS: Mapping[APIFormat, ApiFormatDefinition] = MappingProxyType(_DEFINITIONS)
def get_api_format_definition(api_format: APIFormat) -> ApiFormatDefinition:
"""获取指定格式的定义,不存在时抛出 KeyError。"""
return API_FORMAT_DEFINITIONS[api_format]
def list_api_format_definitions() -> List[ApiFormatDefinition]:
"""返回所有定义的浅拷贝列表,供遍历使用。"""
return list(API_FORMAT_DEFINITIONS.values())
def build_alias_lookup() -> Dict[str, APIFormat]:
"""
构建 alias -> APIFormat 的查找表。
每次调用都会返回新的 dict避免可变全局引发并发问题。
"""
lookup: MutableMapping[str, APIFormat] = {}
for definition in API_FORMAT_DEFINITIONS.values():
for alias in definition.iter_aliases():
lookup.setdefault(alias, definition.api_format)
return dict(lookup)
def get_default_path(api_format: APIFormat) -> str:
"""
获取该格式的上游默认请求路径。
可通过 Endpoint.custom_path 覆盖。
"""
definition = API_FORMAT_DEFINITIONS.get(api_format)
return definition.default_path if definition else "/"
def get_local_path(api_format: APIFormat) -> str:
"""
获取该格式的本站入口路径。
本站入口路径 = path_prefix + default_path
例如path_prefix="/openai" + default_path="/v1/chat/completions" -> "/openai/v1/chat/completions"
"""
definition = API_FORMAT_DEFINITIONS.get(api_format)
if definition:
prefix = definition.path_prefix or ""
return prefix + definition.default_path
return "/"
def get_auth_config(api_format: APIFormat) -> tuple[str, str]:
"""
获取该格式的认证配置。
Returns:
(auth_header, auth_type) 元组
- auth_header: 认证头名称
- auth_type: "bearer""header"
"""
definition = API_FORMAT_DEFINITIONS.get(api_format)
if definition:
return definition.auth_header, definition.auth_type
return "Authorization", "bearer"
@lru_cache(maxsize=1)
def _alias_lookup_cache() -> Dict[str, APIFormat]:
"""缓存 alias -> APIFormat 查找表,减少重复构建。"""
return build_alias_lookup()
def resolve_api_format_alias(value: str) -> Optional[APIFormat]:
"""根据别名查找 APIFormat找不到时返回 None。"""
if not value:
return None
normalized = normalize_alias_value(value)
if not normalized:
return None
return _alias_lookup_cache().get(normalized)
def resolve_api_format(
value: Union[str, APIFormat, None],
default: Optional[APIFormat] = None,
) -> Optional[APIFormat]:
"""
将任意字符串/枚举值解析为 APIFormat。
Args:
value: 可以是 APIFormat 或任意字符串/别名
default: 未解析成功时返回的默认值
"""
if isinstance(value, APIFormat):
return value
if isinstance(value, str):
stripped = value.strip()
if not stripped:
return default
upper = stripped.upper()
if upper in APIFormat.__members__:
return APIFormat[upper]
alias = resolve_api_format_alias(stripped)
if alias:
return alias
return default
def register_api_format_definition(definition: ApiFormatDefinition, *, override: bool = False):
"""
注册或覆盖 API 格式定义,允许运行时扩展。
Args:
definition: 要注册的定义
override: 若目标枚举已存在,是否允许覆盖
"""
existing = _DEFINITIONS.get(definition.api_format)
if existing and not override:
raise ValueError(f"{definition.api_format.value} 已存在,如需覆盖请设置 override=True")
_DEFINITIONS[definition.api_format] = definition
_refresh_metadata_cache()
def _refresh_metadata_cache():
"""更新别名缓存,供注册函数调用。"""
_alias_lookup_cache.cache_clear()
def normalize_alias_value(value: str) -> str:
"""统一别名格式:去空白、转小写,并将非字母数字转为单个下划线。"""
if value is None:
return ""
text = value.strip().lower()
# 将所有非字母数字字符替换为下划线,并折叠连续的下划线
text = re.sub(r"[^a-z0-9]+", "_", text)
return text.strip("_")
# =============================================================================
# 格式判断工具
# =============================================================================
def is_cli_api_format(api_format: APIFormat) -> bool:
"""
判断是否为 CLI 透传格式。
Args:
api_format: APIFormat 枚举值
Returns:
True 如果是 CLI 格式
"""
from src.api.handlers.base.parsers import is_cli_format
return is_cli_format(api_format.value)

115
src/core/batch_committer.py Normal file
View File

@@ -0,0 +1,115 @@
"""
批量提交器 - 减少数据库 commit 次数,提升并发能力
核心思想:
- 非关键数据(监控、统计)不立即 commit
- 在后台定期批量 commit
- 关键数据(计费)仍然立即 commit
"""
import asyncio
from typing import Set
from src.core.logger import logger
from sqlalchemy.orm import Session
class BatchCommitter:
"""批量提交管理器"""
def __init__(self, interval_seconds: float = 1.0):
"""
Args:
interval_seconds: 批量提交间隔(秒)
"""
self.interval_seconds = interval_seconds
self._pending_sessions: Set[Session] = set()
self._lock = asyncio.Lock()
self._task = None
async def start(self):
"""启动后台批量提交任务"""
if self._task is None:
self._task = asyncio.create_task(self._batch_commit_loop())
logger.info(f"批量提交器已启动,间隔: {self.interval_seconds}s")
async def stop(self):
"""停止后台任务"""
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
self._task = None
logger.info("批量提交器已停止")
def mark_dirty(self, session: Session):
"""标记 Session 有待提交的更改"""
self._pending_sessions.add(session)
async def _batch_commit_loop(self):
"""后台批量提交循环"""
while True:
try:
await asyncio.sleep(self.interval_seconds)
await self._commit_all()
except asyncio.CancelledError:
# 关闭前提交所有待处理的
await self._commit_all()
raise
except Exception as e:
logger.error(f"批量提交出错: {e}")
async def _commit_all(self):
"""提交所有待处理的 Session"""
async with self._lock:
if not self._pending_sessions:
return
sessions_to_commit = list(self._pending_sessions)
self._pending_sessions.clear()
committed = 0
failed = 0
for session in sessions_to_commit:
try:
session.commit()
committed += 1
except Exception as e:
logger.error(f"提交 Session 失败: {e}")
try:
session.rollback()
except:
pass
failed += 1
if committed > 0:
logger.debug(f"批量提交完成: {committed} 个 Session")
if failed > 0:
logger.warning(f"批量提交失败: {failed} 个 Session")
# 全局单例
_batch_committer: BatchCommitter = None
def get_batch_committer() -> BatchCommitter:
"""获取全局批量提交器"""
global _batch_committer
if _batch_committer is None:
_batch_committer = BatchCommitter(interval_seconds=1.0)
return _batch_committer
async def init_batch_committer():
"""初始化并启动批量提交器"""
committer = get_batch_committer()
await committer.start()
async def shutdown_batch_committer():
"""关闭批量提交器"""
committer = get_batch_committer()
await committer.stop()

174
src/core/cache_service.py Normal file
View File

@@ -0,0 +1,174 @@
"""
缓存服务 - 统一的缓存抽象层
"""
import json
from datetime import timedelta
from typing import Any, Optional
from src.clients.redis_client import get_redis_client
from src.core.logger import logger
class CacheService:
"""缓存服务"""
@staticmethod
async def get(key: str) -> Optional[Any]:
"""
从缓存获取数据
Args:
key: 缓存键
Returns:
缓存的值,如果不存在则返回 None
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return None
value = await redis.get(key)
if value:
# 尝试 JSON 反序列化
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return value
return None
except Exception as e:
logger.warning(f"缓存读取失败: {key} - {e}")
return None
@staticmethod
async def set(key: str, value: Any, ttl_seconds: int = 60) -> bool:
"""
设置缓存
Args:
key: 缓存键
value: 缓存值
ttl_seconds: 过期时间(秒),默认 60 秒
Returns:
是否设置成功
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return False
# JSON 序列化
if isinstance(value, (dict, list)):
value = json.dumps(value)
elif not isinstance(value, (str, bytes)):
value = str(value)
await redis.setex(key, ttl_seconds, value)
return True
except Exception as e:
logger.warning(f"缓存写入失败: {key} - {e}")
return False
@staticmethod
async def delete(key: str) -> bool:
"""
删除缓存
Args:
key: 缓存键
Returns:
是否删除成功
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return False
await redis.delete(key)
return True
except Exception as e:
logger.warning(f"缓存删除失败: {key} - {e}")
return False
@staticmethod
async def exists(key: str) -> bool:
"""
检查缓存是否存在
Args:
key: 缓存键
Returns:
是否存在
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return False
return await redis.exists(key) > 0
except Exception as e:
logger.warning(f"缓存检查失败: {key} - {e}")
return False
# 缓存键前缀
class CacheKeys:
"""缓存键定义"""
# User 缓存TTL 60秒
USER_BY_ID = "user:id:{user_id}"
USER_BY_EMAIL = "user:email:{email}"
# API Key 缓存TTL 30秒
APIKEY_HASH = "apikey:hash:{key_hash}"
APIKEY_AUTH = "apikey:auth:{key_hash}" # 认证结果缓存
# Provider 配置缓存TTL 300秒
PROVIDER_BY_ID = "provider:id:{provider_id}"
ENDPOINT_BY_ID = "endpoint:id:{endpoint_id}"
API_KEY_BY_ID = "api_key:id:{api_key_id}"
@staticmethod
def user_by_id(user_id: str) -> str:
"""User ID 缓存键"""
return CacheKeys.USER_BY_ID.format(user_id=user_id)
@staticmethod
def user_by_email(email: str) -> str:
"""User Email 缓存键"""
return CacheKeys.USER_BY_EMAIL.format(email=email)
@staticmethod
def apikey_hash(key_hash: str) -> str:
"""API Key Hash 缓存键"""
return CacheKeys.APIKEY_HASH.format(key_hash=key_hash)
@staticmethod
def apikey_auth(key_hash: str) -> str:
"""API Key 认证结果缓存键"""
return CacheKeys.APIKEY_AUTH.format(key_hash=key_hash)
@staticmethod
def provider_by_id(provider_id: str) -> str:
"""Provider ID 缓存键"""
return CacheKeys.PROVIDER_BY_ID.format(provider_id=provider_id)
@staticmethod
def endpoint_by_id(endpoint_id: str) -> str:
"""Endpoint ID 缓存键"""
return CacheKeys.ENDPOINT_BY_ID.format(endpoint_id=endpoint_id)
@staticmethod
def api_key_by_id(api_key_id: str) -> str:
"""API Key ID 缓存键"""
return CacheKeys.API_KEY_BY_ID.format(api_key_id=api_key_id)

Some files were not shown because too many files have changed in this diff Show More