mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
Initial commit
This commit is contained in:
11
src/__init__.py
Normal file
11
src/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""AI Proxy
|
||||
|
||||
A proxy server that enables AI models to work with multiple API providers.
|
||||
"""
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
__version__ = "9.1.0"
|
||||
__author__ = "AI Proxy"
|
||||
0
src/api/__init__.py
Normal file
0
src/api/__init__.py
Normal file
30
src/api/admin/__init__.py
Normal file
30
src/api/admin/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Admin API routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .adaptive import router as adaptive_router
|
||||
from .api_keys import router as api_keys_router
|
||||
from .endpoints import router as endpoints_router
|
||||
from .models import router as models_router
|
||||
from .monitoring import router as monitoring_router
|
||||
from .provider_strategy import router as provider_strategy_router
|
||||
from .providers import router as providers_router
|
||||
from .security import router as security_router
|
||||
from .system import router as system_router
|
||||
from .usage import router as usage_router
|
||||
from .users import router as users_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(system_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(providers_router)
|
||||
router.include_router(api_keys_router)
|
||||
router.include_router(usage_router)
|
||||
router.include_router(monitoring_router)
|
||||
router.include_router(endpoints_router)
|
||||
router.include_router(provider_strategy_router)
|
||||
router.include_router(adaptive_router)
|
||||
router.include_router(models_router)
|
||||
router.include_router(security_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
377
src/api/admin/adaptive.py
Normal file
377
src/api/admin/adaptive.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
自适应并发管理 API 端点
|
||||
|
||||
设计原则:
|
||||
- 自适应模式由 max_concurrent 字段决定:
|
||||
- max_concurrent = NULL:启用自适应模式,系统自动学习并调整并发限制
|
||||
- max_concurrent = 数字:固定限制模式,使用用户指定的并发限制
|
||||
- learned_max_concurrent:自适应模式下学习到的并发限制值
|
||||
- adaptive_mode 是计算字段,基于 max_concurrent 是否为 NULL
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.database import get_db
|
||||
from src.models.database import ProviderAPIKey
|
||||
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
|
||||
|
||||
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ==================== Pydantic Models ====================
|
||||
|
||||
|
||||
class EnableAdaptiveRequest(BaseModel):
|
||||
"""启用自适应模式请求"""
|
||||
|
||||
enabled: bool = Field(..., description="是否启用自适应模式(true=自适应,false=固定限制)")
|
||||
fixed_limit: Optional[int] = Field(
|
||||
None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)"
|
||||
)
|
||||
|
||||
|
||||
class AdaptiveStatsResponse(BaseModel):
|
||||
"""自适应统计响应"""
|
||||
|
||||
adaptive_mode: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
|
||||
max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制(NULL=自适应)")
|
||||
effective_limit: Optional[int] = Field(
|
||||
None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
|
||||
)
|
||||
learned_limit: Optional[int] = Field(None, description="学习到的并发限制")
|
||||
concurrent_429_count: int
|
||||
rpm_429_count: int
|
||||
last_429_at: Optional[str]
|
||||
last_429_type: Optional[str]
|
||||
adjustment_count: int
|
||||
recent_adjustments: List[dict]
|
||||
|
||||
|
||||
class KeyListItem(BaseModel):
|
||||
"""Key 列表项"""
|
||||
|
||||
id: str
|
||||
name: Optional[str]
|
||||
endpoint_id: str
|
||||
is_adaptive: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
|
||||
max_concurrent: Optional[int] = Field(None, description="固定并发限制(NULL=自适应)")
|
||||
effective_limit: Optional[int] = Field(None, description="当前有效限制")
|
||||
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
|
||||
concurrent_429_count: int
|
||||
rpm_429_count: int
|
||||
|
||||
|
||||
# ==================== API Endpoints ====================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/keys",
|
||||
response_model=List[KeyListItem],
|
||||
summary="获取所有启用自适应模式的Key",
|
||||
)
|
||||
async def list_adaptive_keys(
|
||||
request: Request,
|
||||
endpoint_id: Optional[str] = Query(None, description="按 Endpoint 过滤"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取所有启用自适应模式的Key列表
|
||||
|
||||
可选参数:
|
||||
- endpoint_id: 按 Endpoint 过滤
|
||||
"""
|
||||
adapter = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/keys/{key_id}/mode",
|
||||
summary="Toggle key's concurrency control mode",
|
||||
)
|
||||
async def toggle_adaptive_mode(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Toggle the concurrency control mode for a specific key
|
||||
|
||||
Parameters:
|
||||
- enabled: true=adaptive mode (max_concurrent=NULL), false=fixed limit mode
|
||||
- fixed_limit: fixed limit value (required when enabled=false)
|
||||
"""
|
||||
adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/keys/{key_id}/stats",
|
||||
response_model=AdaptiveStatsResponse,
|
||||
summary="获取Key的自适应统计",
|
||||
)
|
||||
async def get_adaptive_stats(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定Key的自适应并发统计信息
|
||||
|
||||
包括:
|
||||
- 当前配置
|
||||
- 学习到的限制
|
||||
- 429错误统计
|
||||
- 调整历史
|
||||
"""
|
||||
adapter = GetAdaptiveStatsAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/keys/{key_id}/learning",
|
||||
summary="Reset key's learning state",
|
||||
)
|
||||
async def reset_adaptive_learning(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Reset the adaptive learning state for a specific key
|
||||
|
||||
Clears:
|
||||
- Learned concurrency limit (learned_max_concurrent)
|
||||
- 429 error counts
|
||||
- Adjustment history
|
||||
|
||||
Does not change:
|
||||
- max_concurrent config (determines adaptive mode)
|
||||
"""
|
||||
adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/keys/{key_id}/limit",
|
||||
summary="Set key to fixed concurrency limit mode",
|
||||
)
|
||||
async def set_concurrent_limit(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Set key to fixed concurrency limit mode
|
||||
|
||||
Note:
|
||||
- After setting this value, key switches to fixed limit mode and won't auto-adjust
|
||||
- To restore adaptive mode, use PATCH /keys/{key_id}/mode
|
||||
"""
|
||||
adapter = SetConcurrentLimitAdapter(key_id=key_id, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/summary",
|
||||
summary="获取自适应并发的全局统计",
|
||||
)
|
||||
async def get_adaptive_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取自适应并发的全局统计摘要
|
||||
|
||||
包括:
|
||||
- 启用自适应模式的Key数量
|
||||
- 总429错误数
|
||||
- 并发限制调整次数
|
||||
"""
|
||||
adapter = AdaptiveSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ==================== Pipeline 适配器 ====================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListAdaptiveKeysAdapter(AdminApiAdapter):
|
||||
endpoint_id: Optional[str] = None
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
# 自适应模式:max_concurrent = NULL
|
||||
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None))
|
||||
if self.endpoint_id:
|
||||
query = query.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
|
||||
|
||||
keys = query.all()
|
||||
return [
|
||||
KeyListItem(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
endpoint_id=key.endpoint_id,
|
||||
is_adaptive=key.max_concurrent is None,
|
||||
max_concurrent=key.max_concurrent,
|
||||
effective_limit=(
|
||||
key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
|
||||
),
|
||||
learned_max_concurrent=key.learned_max_concurrent,
|
||||
concurrent_429_count=key.concurrent_429_count or 0,
|
||||
rpm_429_count=key.rpm_429_count or 0,
|
||||
)
|
||||
for key in keys
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToggleAdaptiveModeAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
body = EnableAdaptiveRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
if body.enabled:
|
||||
# 启用自适应模式:将 max_concurrent 设为 NULL
|
||||
key.max_concurrent = None
|
||||
message = "已切换为自适应模式,系统将自动学习并调整并发限制"
|
||||
else:
|
||||
# 禁用自适应模式:设置固定限制
|
||||
if body.fixed_limit is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
|
||||
)
|
||||
key.max_concurrent = body.fixed_limit
|
||||
message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}"
|
||||
|
||||
context.db.commit()
|
||||
context.db.refresh(key)
|
||||
|
||||
is_adaptive = key.max_concurrent is None
|
||||
return {
|
||||
"message": message,
|
||||
"key_id": key.id,
|
||||
"is_adaptive": is_adaptive,
|
||||
"max_concurrent": key.max_concurrent,
|
||||
"effective_limit": key.learned_max_concurrent if is_adaptive else key.max_concurrent,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetAdaptiveStatsAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
adaptive_manager = get_adaptive_manager()
|
||||
stats = adaptive_manager.get_adjustment_stats(key)
|
||||
|
||||
# 转换字段名以匹配响应模型
|
||||
return AdaptiveStatsResponse(
|
||||
adaptive_mode=stats["adaptive_mode"],
|
||||
max_concurrent=stats["max_concurrent"],
|
||||
effective_limit=stats["effective_limit"],
|
||||
learned_limit=stats["learned_limit"],
|
||||
concurrent_429_count=stats["concurrent_429_count"],
|
||||
rpm_429_count=stats["rpm_429_count"],
|
||||
last_429_at=stats["last_429_at"],
|
||||
last_429_type=stats["last_429_type"],
|
||||
adjustment_count=stats["adjustment_count"],
|
||||
recent_adjustments=stats["recent_adjustments"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetAdaptiveLearningAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
adaptive_manager = get_adaptive_manager()
|
||||
adaptive_manager.reset_learning(context.db, key)
|
||||
return {"message": "学习状态已重置", "key_id": key.id}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetConcurrentLimitAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
was_adaptive = key.max_concurrent is None
|
||||
key.max_concurrent = self.limit
|
||||
context.db.commit()
|
||||
context.db.refresh(key)
|
||||
|
||||
return {
|
||||
"message": f"已设置为固定限制模式,并发限制为 {self.limit}",
|
||||
"key_id": key.id,
|
||||
"is_adaptive": False,
|
||||
"max_concurrent": key.max_concurrent,
|
||||
"previous_mode": "adaptive" if was_adaptive else "fixed",
|
||||
}
|
||||
|
||||
|
||||
class AdaptiveSummaryAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
# 自适应模式:max_concurrent = NULL
|
||||
adaptive_keys = (
|
||||
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None)).all()
|
||||
)
|
||||
|
||||
total_keys = len(adaptive_keys)
|
||||
total_concurrent_429 = sum(key.concurrent_429_count or 0 for key in adaptive_keys)
|
||||
total_rpm_429 = sum(key.rpm_429_count or 0 for key in adaptive_keys)
|
||||
total_adjustments = sum(len(key.adjustment_history or []) for key in adaptive_keys)
|
||||
|
||||
recent_adjustments = []
|
||||
for key in adaptive_keys:
|
||||
if key.adjustment_history:
|
||||
for adj in key.adjustment_history[-3:]:
|
||||
recent_adjustments.append(
|
||||
{
|
||||
"key_id": key.id,
|
||||
"key_name": key.name,
|
||||
**adj,
|
||||
}
|
||||
)
|
||||
|
||||
recent_adjustments.sort(key=lambda item: item.get("timestamp", ""), reverse=True)
|
||||
|
||||
return {
|
||||
"total_adaptive_keys": total_keys,
|
||||
"total_concurrent_429_errors": total_concurrent_429,
|
||||
"total_rpm_429_errors": total_rpm_429,
|
||||
"total_adjustments": total_adjustments,
|
||||
"recent_adjustments": recent_adjustments[:10],
|
||||
}
|
||||
5
src/api/admin/api_keys/__init__.py
Normal file
5
src/api/admin/api_keys/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""API key admin routes export."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
497
src/api/admin/api_keys/routes.py
Normal file
497
src/api/admin/api_keys/routes.py
Normal file
@@ -0,0 +1,497 @@
|
||||
"""管理员独立余额 API Key 管理路由。
|
||||
|
||||
独立余额Key:不关联用户配额,有独立余额限制,用于给非注册用户使用。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import CreateApiKeyRequest
|
||||
from src.models.database import ApiKey, User
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_standalone_api_keys(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
is_active: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列出所有独立余额API Keys"""
|
||||
adapter = AdminListStandaloneKeysAdapter(skip=skip, limit=limit, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_standalone_api_key(
|
||||
request: Request,
|
||||
key_data: CreateApiKeyRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建独立余额API Key(必须设置余额限制)"""
|
||||
adapter = AdminCreateStandaloneKeyAdapter(key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{key_id}")
|
||||
async def update_api_key(
|
||||
key_id: str, request: Request, key_data: CreateApiKeyRequest, db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新独立余额Key(可修改名称、过期时间、余额限制等)"""
|
||||
adapter = AdminUpdateApiKeyAdapter(key_id=key_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{key_id}")
|
||||
async def toggle_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Toggle API key active status (PATCH with is_active in body)"""
|
||||
adapter = AdminToggleApiKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{key_id}")
|
||||
async def delete_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminDeleteApiKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{key_id}/balance")
|
||||
async def add_balance_to_key(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Adjust balance for standalone API key (positive to add, negative to deduct)"""
|
||||
# 从请求体获取调整金额
|
||||
body = await request.json()
|
||||
amount_usd = body.get("amount_usd")
|
||||
|
||||
# 参数校验
|
||||
if amount_usd is None:
|
||||
raise HTTPException(status_code=400, detail="缺少必需参数: amount_usd")
|
||||
|
||||
if amount_usd == 0:
|
||||
raise HTTPException(status_code=400, detail="调整金额不能为 0")
|
||||
|
||||
# 类型校验
|
||||
try:
|
||||
amount_usd = float(amount_usd)
|
||||
except (ValueError, TypeError):
|
||||
raise HTTPException(status_code=400, detail="调整金额必须是有效数字")
|
||||
|
||||
# 如果是扣除操作,检查Key是否存在以及余额是否充足
|
||||
if amount_usd < 0:
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API密钥不存在")
|
||||
|
||||
if not api_key.is_standalone:
|
||||
raise HTTPException(status_code=400, detail="只能为独立余额Key调整余额")
|
||||
|
||||
if api_key.current_balance_usd is not None:
|
||||
if abs(amount_usd) > api_key.current_balance_usd:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"扣除金额 ${abs(amount_usd):.2f} 超过当前余额 ${api_key.current_balance_usd:.2f}",
|
||||
)
|
||||
|
||||
adapter = AdminAddBalanceAdapter(key_id=key_id, amount_usd=amount_usd)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{key_id}")
|
||||
async def get_api_key_detail(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
include_key: bool = Query(False, description="Include full decrypted key in response"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get API key detail, optionally include full key"""
|
||||
if include_key:
|
||||
adapter = AdminGetFullKeyAdapter(key_id=key_id)
|
||||
else:
|
||||
# Return basic key info without full key
|
||||
adapter = AdminGetKeyDetailAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AdminListStandaloneKeysAdapter(AdminApiAdapter):
|
||||
"""列出独立余额Keys"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
skip: int,
|
||||
limit: int,
|
||||
is_active: Optional[bool],
|
||||
):
|
||||
self.skip = skip
|
||||
self.limit = limit
|
||||
self.is_active = is_active
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
# 只查询独立余额Keys
|
||||
query = db.query(ApiKey).filter(ApiKey.is_standalone == True)
|
||||
|
||||
if self.is_active is not None:
|
||||
query = query.filter(ApiKey.is_active == self.is_active)
|
||||
|
||||
total = query.count()
|
||||
api_keys = (
|
||||
query.order_by(ApiKey.created_at.desc()).offset(self.skip).limit(self.limit).all()
|
||||
)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="list_standalone_api_keys",
|
||||
filter_is_active=self.is_active,
|
||||
limit=self.limit,
|
||||
skip=self.skip,
|
||||
total=total,
|
||||
)
|
||||
|
||||
return {
|
||||
"api_keys": [
|
||||
{
|
||||
"id": api_key.id,
|
||||
"user_id": api_key.user_id, # 创建者ID
|
||||
"name": api_key.name,
|
||||
"key_display": api_key.get_display_key(),
|
||||
"is_active": api_key.is_active,
|
||||
"is_standalone": api_key.is_standalone,
|
||||
"current_balance_usd": api_key.current_balance_usd,
|
||||
"balance_used_usd": float(api_key.balance_used_usd or 0),
|
||||
"total_requests": api_key.total_requests,
|
||||
"total_cost_usd": float(api_key.total_cost_usd or 0),
|
||||
"rate_limit": api_key.rate_limit,
|
||||
"allowed_providers": api_key.allowed_providers,
|
||||
"allowed_api_formats": api_key.allowed_api_formats,
|
||||
"allowed_models": api_key.allowed_models,
|
||||
"last_used_at": (
|
||||
api_key.last_used_at.isoformat() if api_key.last_used_at else None
|
||||
),
|
||||
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
|
||||
"created_at": api_key.created_at.isoformat(),
|
||||
"updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
|
||||
"auto_delete_on_expiry": api_key.auto_delete_on_expiry,
|
||||
}
|
||||
for api_key in api_keys
|
||||
],
|
||||
"total": total,
|
||||
"limit": self.limit,
|
||||
"skip": self.skip,
|
||||
}
|
||||
|
||||
|
||||
class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
||||
"""创建独立余额Key"""
|
||||
|
||||
def __init__(self, key_data: CreateApiKeyRequest):
|
||||
self.key_data = key_data
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 独立Key必须设置初始余额
|
||||
if not self.key_data.initial_balance_usd or self.key_data.initial_balance_usd <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="创建独立余额Key必须设置有效的初始余额(initial_balance_usd > 0)",
|
||||
)
|
||||
|
||||
# 独立Key需要关联到管理员用户(从context获取)
|
||||
admin_user_id = context.user.id
|
||||
|
||||
# 创建独立Key
|
||||
api_key, plain_key = ApiKeyService.create_api_key(
|
||||
db=db,
|
||||
user_id=admin_user_id, # 关联到创建者
|
||||
name=self.key_data.name,
|
||||
allowed_providers=self.key_data.allowed_providers,
|
||||
allowed_api_formats=self.key_data.allowed_api_formats,
|
||||
allowed_models=self.key_data.allowed_models,
|
||||
rate_limit=self.key_data.rate_limit or 100,
|
||||
expire_days=self.key_data.expire_days,
|
||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||
is_standalone=True, # 标记为独立Key
|
||||
auto_delete_on_expiry=self.key_data.auto_delete_on_expiry,
|
||||
)
|
||||
|
||||
logger.info(f"管理员创建独立余额Key: ID {api_key.id}, 初始余额 ${self.key_data.initial_balance_usd}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="create_standalone_api_key",
|
||||
key_id=api_key.id,
|
||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"key": plain_key, # 只在创建时返回完整密钥
|
||||
"name": api_key.name,
|
||||
"key_display": api_key.get_display_key(),
|
||||
"is_standalone": True,
|
||||
"current_balance_usd": api_key.current_balance_usd,
|
||||
"balance_used_usd": 0.0,
|
||||
"rate_limit": api_key.rate_limit,
|
||||
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
|
||||
"created_at": api_key.created_at.isoformat(),
|
||||
"message": "独立余额Key创建成功,请妥善保存完整密钥,后续将无法查看",
|
||||
}
|
||||
|
||||
|
||||
class AdminUpdateApiKeyAdapter(AdminApiAdapter):
|
||||
"""更新独立余额Key"""
|
||||
|
||||
def __init__(self, key_id: str, key_data: CreateApiKeyRequest):
|
||||
self.key_id = key_id
|
||||
self.key_data = key_data
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
|
||||
if not api_key:
|
||||
raise NotFoundException("API密钥不存在", "api_key")
|
||||
|
||||
# 构建更新数据
|
||||
update_data = {}
|
||||
if self.key_data.name is not None:
|
||||
update_data["name"] = self.key_data.name
|
||||
if self.key_data.rate_limit is not None:
|
||||
update_data["rate_limit"] = self.key_data.rate_limit
|
||||
if (
|
||||
hasattr(self.key_data, "auto_delete_on_expiry")
|
||||
and self.key_data.auto_delete_on_expiry is not None
|
||||
):
|
||||
update_data["auto_delete_on_expiry"] = self.key_data.auto_delete_on_expiry
|
||||
|
||||
# 访问限制配置(允许设置为空数组来清除限制)
|
||||
if hasattr(self.key_data, "allowed_providers"):
|
||||
update_data["allowed_providers"] = self.key_data.allowed_providers
|
||||
if hasattr(self.key_data, "allowed_api_formats"):
|
||||
update_data["allowed_api_formats"] = self.key_data.allowed_api_formats
|
||||
if hasattr(self.key_data, "allowed_models"):
|
||||
update_data["allowed_models"] = self.key_data.allowed_models
|
||||
|
||||
# 处理过期时间
|
||||
if self.key_data.expire_days is not None:
|
||||
if self.key_data.expire_days > 0:
|
||||
from datetime import timedelta
|
||||
|
||||
update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
|
||||
days=self.key_data.expire_days
|
||||
)
|
||||
else:
|
||||
# expire_days = 0 或负数表示永不过期
|
||||
update_data["expires_at"] = None
|
||||
elif hasattr(self.key_data, "expire_days") and self.key_data.expire_days is None:
|
||||
# 明确传递 None,设为永不过期
|
||||
update_data["expires_at"] = None
|
||||
|
||||
# 使用 ApiKeyService 更新
|
||||
updated_key = ApiKeyService.update_api_key(db, self.key_id, **update_data)
|
||||
if not updated_key:
|
||||
raise NotFoundException("更新失败", "api_key")
|
||||
|
||||
logger.info(f"管理员更新独立余额Key: ID {self.key_id}, 更新字段 {list(update_data.keys())}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="update_standalone_api_key",
|
||||
key_id=self.key_id,
|
||||
updated_fields=list(update_data.keys()),
|
||||
)
|
||||
|
||||
return {
|
||||
"id": updated_key.id,
|
||||
"name": updated_key.name,
|
||||
"key_display": updated_key.get_display_key(),
|
||||
"is_active": updated_key.is_active,
|
||||
"current_balance_usd": updated_key.current_balance_usd,
|
||||
"balance_used_usd": float(updated_key.balance_used_usd or 0),
|
||||
"rate_limit": updated_key.rate_limit,
|
||||
"expires_at": updated_key.expires_at.isoformat() if updated_key.expires_at else None,
|
||||
"updated_at": updated_key.updated_at.isoformat() if updated_key.updated_at else None,
|
||||
"message": "API密钥已更新",
|
||||
}
|
||||
|
||||
|
||||
class AdminToggleApiKeyAdapter(AdminApiAdapter):
|
||||
def __init__(self, key_id: str):
|
||||
self.key_id = key_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
|
||||
if not api_key:
|
||||
raise NotFoundException("API密钥不存在", "api_key")
|
||||
|
||||
api_key.is_active = not api_key.is_active
|
||||
api_key.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
logger.info(f"管理员切换API密钥状态: Key ID {self.key_id}, 新状态 {'启用' if api_key.is_active else '禁用'}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="toggle_api_key",
|
||||
target_key_id=api_key.id,
|
||||
user_id=api_key.user_id,
|
||||
new_status="enabled" if api_key.is_active else "disabled",
|
||||
)
|
||||
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"is_active": api_key.is_active,
|
||||
"message": f"API密钥已{'启用' if api_key.is_active else '禁用'}",
|
||||
}
|
||||
|
||||
|
||||
class AdminDeleteApiKeyAdapter(AdminApiAdapter):
|
||||
def __init__(self, key_id: str):
|
||||
self.key_id = key_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API密钥不存在")
|
||||
|
||||
user = api_key.user
|
||||
db.delete(api_key)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"管理员删除API密钥: Key ID {self.key_id}, 用户 {user.email if user else '未知'}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="delete_api_key",
|
||||
target_key_id=self.key_id,
|
||||
user_id=user.id if user else None,
|
||||
user_email=user.email if user else None,
|
||||
)
|
||||
return {"message": "API密钥已删除"}
|
||||
|
||||
|
||||
class AdminAddBalanceAdapter(AdminApiAdapter):
|
||||
"""为独立余额Key增加余额"""
|
||||
|
||||
def __init__(self, key_id: str, amount_usd: float):
|
||||
self.key_id = key_id
|
||||
self.amount_usd = amount_usd
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 使用 ApiKeyService 增加余额
|
||||
updated_key = ApiKeyService.add_balance(db, self.key_id, self.amount_usd)
|
||||
|
||||
if not updated_key:
|
||||
raise NotFoundException("余额充值失败:Key不存在或不是独立余额Key", "api_key")
|
||||
|
||||
logger.info(f"管理员为独立余额Key充值: ID {self.key_id}, 充值 ${self.amount_usd:.4f}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="add_balance_to_key",
|
||||
key_id=self.key_id,
|
||||
amount_usd=self.amount_usd,
|
||||
new_current_balance=updated_key.current_balance_usd,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": updated_key.id,
|
||||
"name": updated_key.name,
|
||||
"current_balance_usd": updated_key.current_balance_usd,
|
||||
"balance_used_usd": float(updated_key.balance_used_usd or 0),
|
||||
"message": f"余额充值成功,充值 ${self.amount_usd:.2f},当前余额 ${updated_key.current_balance_usd:.2f}",
|
||||
}
|
||||
|
||||
|
||||
class AdminGetFullKeyAdapter(AdminApiAdapter):
|
||||
"""获取完整的API密钥"""
|
||||
|
||||
def __init__(self, key_id: str):
|
||||
self.key_id = key_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
db = context.db
|
||||
|
||||
# 查找API密钥
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
|
||||
if not api_key:
|
||||
raise NotFoundException("API密钥不存在", "api_key")
|
||||
|
||||
# 解密完整密钥
|
||||
if not api_key.key_encrypted:
|
||||
raise HTTPException(status_code=400, detail="该密钥没有存储完整密钥信息")
|
||||
|
||||
try:
|
||||
full_key = crypto_service.decrypt(api_key.key_encrypted)
|
||||
except Exception as e:
|
||||
logger.error(f"解密API密钥失败: Key ID {self.key_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail="解密密钥失败")
|
||||
|
||||
logger.info(f"管理员查看完整API密钥: Key ID {self.key_id}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="view_full_api_key",
|
||||
key_id=self.key_id,
|
||||
key_name=api_key.name,
|
||||
)
|
||||
|
||||
return {
|
||||
"key": full_key,
|
||||
}
|
||||
|
||||
|
||||
class AdminGetKeyDetailAdapter(AdminApiAdapter):
|
||||
"""Get API key detail without full key"""
|
||||
|
||||
def __init__(self, key_id: str):
|
||||
self.key_id = key_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
|
||||
if not api_key:
|
||||
raise NotFoundException("API密钥不存在", "api_key")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="get_api_key_detail",
|
||||
key_id=self.key_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"user_id": api_key.user_id,
|
||||
"name": api_key.name,
|
||||
"key_display": api_key.get_display_key(),
|
||||
"is_active": api_key.is_active,
|
||||
"is_standalone": api_key.is_standalone,
|
||||
"current_balance_usd": api_key.current_balance_usd,
|
||||
"balance_used_usd": float(api_key.balance_used_usd or 0),
|
||||
"total_requests": api_key.total_requests,
|
||||
"total_cost_usd": float(api_key.total_cost_usd or 0),
|
||||
"rate_limit": api_key.rate_limit,
|
||||
"allowed_providers": api_key.allowed_providers,
|
||||
"allowed_api_formats": api_key.allowed_api_formats,
|
||||
"allowed_models": api_key.allowed_models,
|
||||
"last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None,
|
||||
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
|
||||
"created_at": api_key.created_at.isoformat(),
|
||||
"updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
|
||||
}
|
||||
24
src/api/admin/endpoints/__init__.py
Normal file
24
src/api/admin/endpoints/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Endpoint management API routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .concurrency import router as concurrency_router
|
||||
from .health import router as health_router
|
||||
from .keys import router as keys_router
|
||||
from .routes import router as routes_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/endpoints", tags=["Endpoint Management"])
|
||||
|
||||
# Endpoint CRUD
|
||||
router.include_router(routes_router)
|
||||
|
||||
# Endpoint Keys management
|
||||
router.include_router(keys_router)
|
||||
|
||||
# Health monitoring
|
||||
router.include_router(health_router)
|
||||
|
||||
# Concurrency control
|
||||
router.include_router(concurrency_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
116
src/api/admin/endpoints/concurrency.py
Normal file
116
src/api/admin/endpoints/concurrency.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Endpoint 并发控制管理 API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
ConcurrencyStatusResponse,
|
||||
ResetConcurrencyRequest,
|
||||
)
|
||||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||||
|
||||
router = APIRouter(tags=["Concurrency Control"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
|
||||
async def get_endpoint_concurrency(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""获取 Endpoint 当前并发状态"""
|
||||
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
|
||||
async def get_key_concurrency(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""获取 Key 当前并发状态"""
|
||||
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/concurrency")
|
||||
async def reset_concurrency(
|
||||
request: ResetConcurrencyRequest,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Reset concurrency counters (admin function, use with caution)"""
|
||||
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
endpoint_count, _ = await concurrency_manager.get_current_concurrency(
|
||||
endpoint_id=self.endpoint_id
|
||||
)
|
||||
|
||||
return ConcurrencyStatusResponse(
|
||||
endpoint_id=self.endpoint_id,
|
||||
endpoint_current_concurrency=endpoint_count,
|
||||
endpoint_max_concurrent=endpoint.max_concurrent,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
_, key_count = await concurrency_manager.get_current_concurrency(key_id=self.key_id)
|
||||
|
||||
return ConcurrencyStatusResponse(
|
||||
key_id=self.key_id,
|
||||
key_current_concurrency=key_count,
|
||||
key_max_concurrent=key.max_concurrent,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminResetConcurrencyAdapter(AdminApiAdapter):
|
||||
endpoint_id: Optional[str]
|
||||
key_id: Optional[str]
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
await concurrency_manager.reset_concurrency(
|
||||
endpoint_id=self.endpoint_id, key_id=self.key_id
|
||||
)
|
||||
return {"message": "并发计数已重置"}
|
||||
476
src/api/admin/endpoints/health.py
Normal file
476
src/api/admin/endpoints/health.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
Endpoint 健康监控 API
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate
|
||||
from src.models.endpoint_models import (
|
||||
ApiFormatHealthMonitor,
|
||||
ApiFormatHealthMonitorResponse,
|
||||
EndpointHealthEvent,
|
||||
HealthStatusResponse,
|
||||
HealthSummaryResponse,
|
||||
)
|
||||
from src.services.health.endpoint import EndpointHealthService
|
||||
from src.services.health.monitor import health_monitor
|
||||
|
||||
router = APIRouter(tags=["Endpoint Health"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/health/summary", response_model=HealthSummaryResponse)
|
||||
async def get_health_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthSummaryResponse:
|
||||
"""获取健康状态摘要"""
|
||||
adapter = AdminHealthSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/status")
|
||||
async def get_endpoint_health_status(
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取端点健康状态(简化视图,与用户端点统一)
|
||||
|
||||
与 /health/api-formats 的区别:
|
||||
- /health/status: 返回聚合的时间线状态(50个时间段),基于 Usage 表
|
||||
- /health/api-formats: 返回详细的事件列表,基于 RequestCandidate 表
|
||||
"""
|
||||
adapter = AdminEndpointHealthStatusAdapter(lookback_hours=lookback_hours)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/api-formats", response_model=ApiFormatHealthMonitorResponse)
|
||||
async def get_api_format_health_monitor(
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
|
||||
per_format_limit: int = Query(60, ge=10, le=200, description="每个 API 格式的事件数量"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ApiFormatHealthMonitorResponse:
|
||||
"""获取按 API 格式聚合的健康监控时间线(详细事件列表)"""
|
||||
adapter = AdminApiFormatHealthMonitorAdapter(
|
||||
lookback_hours=lookback_hours,
|
||||
per_format_limit=per_format_limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/key/{key_id}", response_model=HealthStatusResponse)
|
||||
async def get_key_health(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthStatusResponse:
|
||||
"""获取 Key 健康状态"""
|
||||
adapter = AdminKeyHealthAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/health/keys/{key_id}")
|
||||
async def recover_key_health(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Recover key health status
|
||||
|
||||
Resets health_score to 1.0, closes circuit breaker,
|
||||
cancels auto-disable, and resets all failure counts.
|
||||
|
||||
Parameters:
|
||||
- key_id: Key ID (path parameter)
|
||||
"""
|
||||
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/health/keys")
|
||||
async def recover_all_keys_health(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Batch recover all circuit-broken keys
|
||||
|
||||
Finds all keys with circuit_breaker_open=True and:
|
||||
1. Resets health_score to 1.0
|
||||
2. Closes circuit breaker
|
||||
3. Resets failure counts
|
||||
"""
|
||||
adapter = AdminRecoverAllKeysHealthAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
class AdminHealthSummaryAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
summary = health_monitor.get_all_health_status(context.db)
|
||||
return HealthSummaryResponse(**summary)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminEndpointHealthStatusAdapter(AdminApiAdapter):
|
||||
"""管理员端点健康状态适配器(与用户端点统一,但包含管理员字段)"""
|
||||
|
||||
lookback_hours: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.services.health.endpoint import EndpointHealthService
|
||||
|
||||
db = context.db
|
||||
|
||||
# 使用共享服务获取健康状态(管理员视图)
|
||||
result = EndpointHealthService.get_endpoint_health_by_format(
|
||||
db=db,
|
||||
lookback_hours=self.lookback_hours,
|
||||
include_admin_fields=True, # 包含管理员字段
|
||||
use_cache=False, # 管理员不使用缓存,确保实时性
|
||||
)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="endpoint_health_status",
|
||||
format_count=len(result),
|
||||
lookback_hours=self.lookback_hours,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
|
||||
lookback_hours: int
|
||||
per_format_limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
now = datetime.now(timezone.utc)
|
||||
since = now - timedelta(hours=self.lookback_hours)
|
||||
|
||||
# 1. 获取所有活跃的 API 格式及其 Provider 数量
|
||||
active_formats = (
|
||||
db.query(
|
||||
ProviderEndpoint.api_format,
|
||||
func.count(func.distinct(ProviderEndpoint.provider_id)).label("provider_count"),
|
||||
)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.group_by(ProviderEndpoint.api_format)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构建所有格式的 provider_count 映射
|
||||
all_formats: Dict[str, int] = {}
|
||||
for api_format_enum, provider_count in active_formats:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
all_formats[api_format] = provider_count
|
||||
|
||||
# 1.1 获取所有活跃的 API 格式及其 API Key 数量
|
||||
active_keys = (
|
||||
db.query(
|
||||
ProviderEndpoint.api_format,
|
||||
func.count(ProviderAPIKey.id).label("key_count"),
|
||||
)
|
||||
.join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
.group_by(ProviderEndpoint.api_format)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构建所有格式的 key_count 映射
|
||||
key_counts: Dict[str, int] = {}
|
||||
for api_format_enum, key_count in active_keys:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
key_counts[api_format] = key_count
|
||||
|
||||
# 1.2 建立每个 API 格式对应的 Endpoint ID 列表,供 Usage 时间线生成使用
|
||||
endpoint_rows = (
|
||||
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
endpoint_map: Dict[str, List[str]] = defaultdict(list)
|
||||
for api_format_enum, endpoint_id in endpoint_rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
endpoint_map[api_format].append(endpoint_id)
|
||||
|
||||
# 2. 统计窗口内每个 API 格式的请求状态分布(真实统计)
|
||||
# 只统计最终状态:success, failed, skipped
|
||||
final_statuses = ["success", "failed", "skipped"]
|
||||
status_counts_query = (
|
||||
db.query(
|
||||
ProviderEndpoint.api_format,
|
||||
RequestCandidate.status,
|
||||
func.count(RequestCandidate.id).label("count"),
|
||||
)
|
||||
.join(RequestCandidate, ProviderEndpoint.id == RequestCandidate.endpoint_id)
|
||||
.filter(
|
||||
RequestCandidate.created_at >= since,
|
||||
RequestCandidate.status.in_(final_statuses),
|
||||
)
|
||||
.group_by(ProviderEndpoint.api_format, RequestCandidate.status)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构建每个格式的状态统计
|
||||
status_counts: Dict[str, Dict[str, int]] = {}
|
||||
for api_format_enum, status, count in status_counts_query:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
if api_format not in status_counts:
|
||||
status_counts[api_format] = {"success": 0, "failed": 0, "skipped": 0}
|
||||
status_counts[api_format][status] = count
|
||||
|
||||
# 3. 获取最近一段时间的 RequestCandidate(限制数量)
|
||||
# 使用上面定义的 final_statuses,排除中间状态
|
||||
limit_rows = max(500, self.per_format_limit * 10)
|
||||
rows = (
|
||||
db.query(
|
||||
RequestCandidate,
|
||||
ProviderEndpoint.api_format,
|
||||
ProviderEndpoint.provider_id,
|
||||
)
|
||||
.join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
|
||||
.filter(
|
||||
RequestCandidate.created_at >= since,
|
||||
RequestCandidate.status.in_(final_statuses),
|
||||
)
|
||||
.order_by(RequestCandidate.created_at.desc())
|
||||
.limit(limit_rows)
|
||||
.all()
|
||||
)
|
||||
|
||||
grouped_attempts: Dict[str, List[RequestCandidate]] = {}
|
||||
|
||||
for attempt, api_format_enum, provider_id in rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
if api_format not in grouped_attempts:
|
||||
grouped_attempts[api_format] = []
|
||||
|
||||
# 只保留每个 API 格式最近 per_format_limit 条记录
|
||||
if len(grouped_attempts[api_format]) < self.per_format_limit:
|
||||
grouped_attempts[api_format].append(attempt)
|
||||
|
||||
# 4. 为所有活跃格式生成监控数据(包括没有请求记录的)
|
||||
monitors: List[ApiFormatHealthMonitor] = []
|
||||
for api_format in all_formats:
|
||||
attempts = grouped_attempts.get(api_format, [])
|
||||
# 获取窗口内的真实统计数据
|
||||
# 只统计最终状态:success, failed, skipped
|
||||
# 中间状态(available, pending, used, started)不计入统计
|
||||
format_stats = status_counts.get(api_format, {"success": 0, "failed": 0, "skipped": 0})
|
||||
real_success_count = format_stats.get("success", 0)
|
||||
real_failed_count = format_stats.get("failed", 0)
|
||||
real_skipped_count = format_stats.get("skipped", 0)
|
||||
# total_attempts 只包含最终状态的请求数
|
||||
total_attempts = real_success_count + real_failed_count + real_skipped_count
|
||||
|
||||
# 时间线按时间正序
|
||||
attempts_sorted = list(reversed(attempts))
|
||||
events: List[EndpointHealthEvent] = []
|
||||
for attempt in attempts_sorted:
|
||||
event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
|
||||
events.append(
|
||||
EndpointHealthEvent(
|
||||
timestamp=event_timestamp,
|
||||
status=attempt.status,
|
||||
status_code=attempt.status_code,
|
||||
latency_ms=attempt.latency_ms,
|
||||
error_type=attempt.error_type,
|
||||
error_message=attempt.error_message,
|
||||
)
|
||||
)
|
||||
|
||||
# 成功率 = success / (success + failed)
|
||||
# skipped 不算失败,不计入成功率分母
|
||||
# 无实际完成请求时成功率为 1.0(灰色状态)
|
||||
actual_completed = real_success_count + real_failed_count
|
||||
success_rate = real_success_count / actual_completed if actual_completed > 0 else 1.0
|
||||
last_event_at = events[-1].timestamp if events else None
|
||||
|
||||
# 生成 Usage 基于时间窗口的健康时间线
|
||||
timeline_data = EndpointHealthService._generate_timeline_from_usage(
|
||||
db=db,
|
||||
endpoint_ids=endpoint_map.get(api_format, []),
|
||||
now=now,
|
||||
lookback_hours=self.lookback_hours,
|
||||
)
|
||||
|
||||
monitors.append(
|
||||
ApiFormatHealthMonitor(
|
||||
api_format=api_format,
|
||||
total_attempts=total_attempts, # 真实总请求数
|
||||
success_count=real_success_count, # 真实成功数
|
||||
failed_count=real_failed_count, # 真实失败数
|
||||
skipped_count=real_skipped_count, # 真实跳过数
|
||||
success_rate=success_rate, # 基于真实统计的成功率
|
||||
provider_count=all_formats[api_format],
|
||||
key_count=key_counts.get(api_format, 0),
|
||||
last_event_at=last_event_at,
|
||||
events=events, # 限制为 per_format_limit 条(用于时间线显示)
|
||||
timeline=timeline_data.get("timeline", []),
|
||||
time_range_start=timeline_data.get("time_range_start"),
|
||||
time_range_end=timeline_data.get("time_range_end"),
|
||||
)
|
||||
)
|
||||
|
||||
response = ApiFormatHealthMonitorResponse(
|
||||
generated_at=now,
|
||||
formats=monitors,
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="api_format_health_monitor",
|
||||
format_count=len(monitors),
|
||||
lookback_hours=self.lookback_hours,
|
||||
per_format_limit=self.per_format_limit,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminKeyHealthAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
health_data = health_monitor.get_key_health(context.db, self.key_id)
|
||||
if not health_data:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
return HealthStatusResponse(
|
||||
key_id=health_data["key_id"],
|
||||
key_health_score=health_data["health_score"],
|
||||
key_consecutive_failures=health_data["consecutive_failures"],
|
||||
key_last_failure_at=health_data["last_failure_at"],
|
||||
key_is_active=health_data["is_active"],
|
||||
key_statistics=health_data["statistics"],
|
||||
circuit_breaker_open=health_data["circuit_breaker_open"],
|
||||
circuit_breaker_open_at=health_data["circuit_breaker_open_at"],
|
||||
next_probe_at=health_data["next_probe_at"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
key.health_score = 1.0
|
||||
key.consecutive_failures = 0
|
||||
key.last_failure_at = None
|
||||
key.circuit_breaker_open = False
|
||||
key.circuit_breaker_open_at = None
|
||||
key.next_probe_at = None
|
||||
if not key.is_active:
|
||||
key.is_active = True
|
||||
|
||||
db.commit()
|
||||
|
||||
admin_name = context.user.username if context.user else "admin"
|
||||
logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")
|
||||
|
||||
return {
|
||||
"message": "Key已完全恢复",
|
||||
"details": {
|
||||
"health_score": 1.0,
|
||||
"circuit_breaker_open": False,
|
||||
"is_active": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
||||
"""批量恢复所有熔断 Key 的健康状态"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 查找所有熔断的 Key
|
||||
circuit_open_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open == True).all()
|
||||
)
|
||||
|
||||
if not circuit_open_keys:
|
||||
return {
|
||||
"message": "没有需要恢复的 Key",
|
||||
"recovered_count": 0,
|
||||
"recovered_keys": [],
|
||||
}
|
||||
|
||||
recovered_keys = []
|
||||
for key in circuit_open_keys:
|
||||
key.health_score = 1.0
|
||||
key.consecutive_failures = 0
|
||||
key.last_failure_at = None
|
||||
key.circuit_breaker_open = False
|
||||
key.circuit_breaker_open_at = None
|
||||
key.next_probe_at = None
|
||||
recovered_keys.append(
|
||||
{
|
||||
"key_id": key.id,
|
||||
"key_name": key.name,
|
||||
"endpoint_id": key.endpoint_id,
|
||||
}
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 重置健康监控器的计数
|
||||
from src.services.health.monitor import HealthMonitor, health_open_circuits
|
||||
|
||||
HealthMonitor._open_circuit_keys = 0
|
||||
health_open_circuits.set(0)
|
||||
|
||||
admin_name = context.user.username if context.user else "admin"
|
||||
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
|
||||
|
||||
return {
|
||||
"message": f"已恢复 {len(recovered_keys)} 个 Key",
|
||||
"recovered_count": len(recovered_keys),
|
||||
"recovered_keys": recovered_keys,
|
||||
}
|
||||
425
src/api/admin/endpoints/keys.py
Normal file
425
src/api/admin/endpoints/keys.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Endpoint API Keys 管理
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.key_capabilities import get_capability
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
BatchUpdateKeyPriorityRequest,
|
||||
EndpointAPIKeyCreate,
|
||||
EndpointAPIKeyResponse,
|
||||
EndpointAPIKeyUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Endpoint Keys"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
|
||||
async def list_endpoint_keys(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[EndpointAPIKeyResponse]:
|
||||
"""获取 Endpoint 的所有 Keys"""
|
||||
adapter = AdminListEndpointKeysAdapter(
|
||||
endpoint_id=endpoint_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
|
||||
async def add_endpoint_key(
|
||||
endpoint_id: str,
|
||||
key_data: EndpointAPIKeyCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""为 Endpoint 添加 Key"""
|
||||
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
|
||||
async def update_endpoint_key(
|
||||
key_id: str,
|
||||
key_data: EndpointAPIKeyUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""更新 Endpoint Key"""
|
||||
adapter = AdminUpdateEndpointKeyAdapter(key_id=key_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/keys/grouped-by-format")
|
||||
async def get_keys_grouped_by_format(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取按 API 格式分组的所有 Keys(用于全局优先级管理)"""
|
||||
adapter = AdminGetKeysGroupedByFormatAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/keys/{key_id}")
|
||||
async def delete_endpoint_key(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除 Endpoint Key"""
|
||||
adapter = AdminDeleteEndpointKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{endpoint_id}/keys/batch-priority")
|
||||
async def batch_update_key_priority(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
priority_data: BatchUpdateKeyPriorityRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""批量更新 Endpoint 下 Keys 的优先级(用于拖动排序)"""
|
||||
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListEndpointKeysAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
|
||||
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
||||
.offset(self.skip)
|
||||
.limit(self.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
result: List[EndpointAPIKeyResponse] = []
|
||||
for key in keys:
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
||||
avg_response_time_ms = (
|
||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
is_adaptive = key.max_concurrent is None
|
||||
key_dict = key.__dict__.copy()
|
||||
key_dict.pop("_sa_instance_state", None)
|
||||
key_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": None,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
result.append(EndpointAPIKeyResponse(**key_dict))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
key_data: EndpointAPIKeyCreate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
if self.key_data.endpoint_id != self.endpoint_id:
|
||||
raise InvalidRequestException("endpoint_id 不匹配")
|
||||
|
||||
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
|
||||
now = datetime.now(timezone.utc)
|
||||
# max_concurrent=NULL 表示自适应模式,数字表示固定限制
|
||||
new_key = ProviderAPIKey(
|
||||
id=str(uuid.uuid4()),
|
||||
endpoint_id=self.endpoint_id,
|
||||
api_key=encrypted_key,
|
||||
name=self.key_data.name,
|
||||
note=self.key_data.note,
|
||||
rate_multiplier=self.key_data.rate_multiplier,
|
||||
internal_priority=self.key_data.internal_priority,
|
||||
max_concurrent=self.key_data.max_concurrent, # NULL=自适应模式
|
||||
rate_limit=self.key_data.rate_limit,
|
||||
daily_limit=self.key_data.daily_limit,
|
||||
monthly_limit=self.key_data.monthly_limit,
|
||||
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
|
||||
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
|
||||
request_count=0,
|
||||
success_count=0,
|
||||
error_count=0,
|
||||
total_response_time_ms=0,
|
||||
is_active=True,
|
||||
last_used_at=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
db.add(new_key)
|
||||
db.commit()
|
||||
db.refresh(new_key)
|
||||
|
||||
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
|
||||
|
||||
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
|
||||
is_adaptive = new_key.max_concurrent is None
|
||||
response_dict = new_key.__dict__.copy()
|
||||
response_dict.pop("_sa_instance_state", None)
|
||||
response_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": self.key_data.api_key,
|
||||
"success_rate": 0.0,
|
||||
"avg_response_time_ms": 0.0,
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return EndpointAPIKeyResponse(**response_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
key_data: EndpointAPIKeyUpdate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
update_data = self.key_data.model_dump(exclude_unset=True)
|
||||
if "api_key" in update_data:
|
||||
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(key, field, value)
|
||||
key.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
|
||||
logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")
|
||||
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
||||
avg_response_time_ms = (
|
||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
is_adaptive = key.max_concurrent is None
|
||||
response_dict = key.__dict__.copy()
|
||||
response_dict.pop("_sa_instance_state", None)
|
||||
response_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": None,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
return EndpointAPIKeyResponse(**response_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
endpoint_id = key.endpoint_id
|
||||
try:
|
||||
db.delete(key)
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
|
||||
raise
|
||||
|
||||
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
|
||||
return {"message": f"Key {self.key_id} 已删除"}
|
||||
|
||||
|
||||
class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
keys = (
|
||||
db.query(ProviderAPIKey, ProviderEndpoint, Provider)
|
||||
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.order_by(
|
||||
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
grouped: Dict[str, List[dict]] = {}
|
||||
for key, endpoint, provider in keys:
|
||||
api_format = endpoint.api_format
|
||||
if api_format not in grouped:
|
||||
grouped[api_format] = []
|
||||
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
# 计算健康度指标
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else None
|
||||
avg_response_time_ms = (
|
||||
round(key.total_response_time_ms / key.success_count, 2)
|
||||
if key.success_count > 0
|
||||
else None
|
||||
)
|
||||
|
||||
# 将 capabilities dict 转换为启用的能力简短名称列表
|
||||
caps_list = []
|
||||
if key.capabilities:
|
||||
for cap_name, enabled in key.capabilities.items():
|
||||
if enabled:
|
||||
cap_def = get_capability(cap_name)
|
||||
caps_list.append(cap_def.short_name if cap_def else cap_name)
|
||||
|
||||
grouped[api_format].append(
|
||||
{
|
||||
"id": key.id,
|
||||
"name": key.name,
|
||||
"api_key_masked": masked_key,
|
||||
"internal_priority": key.internal_priority,
|
||||
"global_priority": key.global_priority,
|
||||
"rate_multiplier": key.rate_multiplier,
|
||||
"is_active": key.is_active,
|
||||
"circuit_breaker_open": key.circuit_breaker_open,
|
||||
"provider_name": provider.display_name or provider.name,
|
||||
"endpoint_base_url": endpoint.base_url,
|
||||
"api_format": api_format,
|
||||
"capabilities": caps_list,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": avg_response_time_ms,
|
||||
"request_count": key.request_count,
|
||||
}
|
||||
)
|
||||
|
||||
# 直接返回分组对象,供前端使用
|
||||
return grouped
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
priority_data: BatchUpdateKeyPriorityRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
# 获取所有需要更新的 Key ID
|
||||
key_ids = [item.key_id for item in self.priority_data.priorities]
|
||||
|
||||
# 验证所有 Key 都属于该 Endpoint
|
||||
keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
ProviderAPIKey.id.in_(key_ids),
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(keys) != len(key_ids):
|
||||
found_ids = {k.id for k in keys}
|
||||
missing_ids = set(key_ids) - found_ids
|
||||
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
|
||||
|
||||
# 批量更新优先级
|
||||
key_map = {k.id: k for k in keys}
|
||||
updated_count = 0
|
||||
for item in self.priority_data.priorities:
|
||||
key = key_map.get(item.key_id)
|
||||
if key and key.internal_priority != item.internal_priority:
|
||||
key.internal_priority = item.internal_priority
|
||||
key.updated_at = datetime.now(timezone.utc)
|
||||
updated_count += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
|
||||
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}
|
||||
345
src/api/admin/endpoints/routes.py
Normal file
345
src/api/admin/endpoints/routes.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
ProviderEndpoint CRUD 管理 API
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
ProviderEndpointCreate,
|
||||
ProviderEndpointResponse,
|
||||
ProviderEndpointUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Endpoint Management"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
|
||||
async def list_provider_endpoints(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ProviderEndpointResponse]:
|
||||
"""获取指定 Provider 的所有 Endpoints"""
|
||||
adapter = AdminListProviderEndpointsAdapter(
|
||||
provider_id=provider_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id}/endpoints", response_model=ProviderEndpointResponse)
|
||||
async def create_provider_endpoint(
|
||||
provider_id: str,
|
||||
endpoint_data: ProviderEndpointCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""为 Provider 创建新的 Endpoint"""
|
||||
adapter = AdminCreateProviderEndpointAdapter(
|
||||
provider_id=provider_id,
|
||||
endpoint_data=endpoint_data,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{endpoint_id}", response_model=ProviderEndpointResponse)
|
||||
async def get_endpoint(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""获取 Endpoint 详情"""
|
||||
adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{endpoint_id}", response_model=ProviderEndpointResponse)
|
||||
async def update_endpoint(
|
||||
endpoint_id: str,
|
||||
endpoint_data: ProviderEndpointUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""更新 Endpoint"""
|
||||
adapter = AdminUpdateProviderEndpointAdapter(
|
||||
endpoint_id=endpoint_id,
|
||||
endpoint_data=endpoint_data,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{endpoint_id}")
|
||||
async def delete_endpoint(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除 Endpoint(级联删除所有关联的 Keys)"""
|
||||
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListProviderEndpointsAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||
|
||||
endpoints = (
|
||||
db.query(ProviderEndpoint)
|
||||
.filter(ProviderEndpoint.provider_id == self.provider_id)
|
||||
.order_by(ProviderEndpoint.created_at.desc())
|
||||
.offset(self.skip)
|
||||
.limit(self.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
endpoint_ids = [ep.id for ep in endpoints]
|
||||
total_keys_map = {}
|
||||
active_keys_map = {}
|
||||
if endpoint_ids:
|
||||
total_rows = (
|
||||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
|
||||
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
|
||||
.group_by(ProviderAPIKey.endpoint_id)
|
||||
.all()
|
||||
)
|
||||
total_keys_map = {row.endpoint_id: row.total for row in total_rows}
|
||||
|
||||
active_rows = (
|
||||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.group_by(ProviderAPIKey.endpoint_id)
|
||||
.all()
|
||||
)
|
||||
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
|
||||
|
||||
result: List[ProviderEndpointResponse] = []
|
||||
for endpoint in endpoints:
|
||||
endpoint_dict = {
|
||||
**endpoint.__dict__,
|
||||
"provider_name": provider.name,
|
||||
"api_format": endpoint.api_format,
|
||||
"total_keys": total_keys_map.get(endpoint.id, 0),
|
||||
"active_keys": active_keys_map.get(endpoint.id, 0),
|
||||
}
|
||||
endpoint_dict.pop("_sa_instance_state", None)
|
||||
result.append(ProviderEndpointResponse(**endpoint_dict))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
endpoint_data: ProviderEndpointCreate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||
|
||||
if self.endpoint_data.provider_id != self.provider_id:
|
||||
raise InvalidRequestException("provider_id 不匹配")
|
||||
|
||||
existing = (
|
||||
db.query(ProviderEndpoint)
|
||||
.filter(
|
||||
and_(
|
||||
ProviderEndpoint.provider_id == self.provider_id,
|
||||
ProviderEndpoint.api_format == self.endpoint_data.api_format,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise InvalidRequestException(
|
||||
f"Provider {provider.name} 已存在 {self.endpoint_data.api_format} 格式的 Endpoint"
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
new_endpoint = ProviderEndpoint(
|
||||
id=str(uuid.uuid4()),
|
||||
provider_id=self.provider_id,
|
||||
api_format=self.endpoint_data.api_format,
|
||||
base_url=self.endpoint_data.base_url,
|
||||
headers=self.endpoint_data.headers,
|
||||
timeout=self.endpoint_data.timeout,
|
||||
max_retries=self.endpoint_data.max_retries,
|
||||
max_concurrent=self.endpoint_data.max_concurrent,
|
||||
rate_limit=self.endpoint_data.rate_limit,
|
||||
is_active=True,
|
||||
config=self.endpoint_data.config,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
db.add(new_endpoint)
|
||||
db.commit()
|
||||
db.refresh(new_endpoint)
|
||||
|
||||
logger.info(f"[OK] 创建 Endpoint: Provider={provider.name}, Format={self.endpoint_data.api_format}, ID={new_endpoint.id}")
|
||||
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in new_endpoint.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name,
|
||||
api_format=new_endpoint.api_format,
|
||||
total_keys=0,
|
||||
active_keys=0,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint, Provider)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(ProviderEndpoint.id == self.endpoint_id)
|
||||
.first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
endpoint_obj, provider = endpoint
|
||||
total_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
)
|
||||
active_keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in endpoint_obj.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name,
|
||||
api_format=endpoint_obj.api_format,
|
||||
total_keys=total_keys,
|
||||
active_keys=active_keys,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
endpoint_data: ProviderEndpointUpdate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
update_data = self.endpoint_data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(endpoint, field, value)
|
||||
endpoint.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
db.refresh(endpoint)
|
||||
|
||||
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
|
||||
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
|
||||
|
||||
total_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
)
|
||||
active_keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in endpoint.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name if provider else "Unknown",
|
||||
api_format=endpoint.api_format,
|
||||
total_keys=total_keys,
|
||||
active_keys=active_keys,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
keys_count = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
)
|
||||
db.delete(endpoint)
|
||||
db.commit()
|
||||
|
||||
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
|
||||
|
||||
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
|
||||
16
src/api/admin/models/__init__.py
Normal file
16
src/api/admin/models/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
模型管理相关 Admin API
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .catalog import router as catalog_router
|
||||
from .global_models import router as global_models_router
|
||||
from .mappings import router as mappings_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
||||
|
||||
# 挂载子路由
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(global_models_router)
|
||||
router.include_router(mappings_router)
|
||||
432
src/api/admin/models/catalog.py
Normal file
432
src/api/admin/models/catalog.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""
|
||||
统一模型目录 Admin API
|
||||
|
||||
阶段一:基于 ModelMapping 和 Model 的聚合视图
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignError,
|
||||
BatchAssignModelMappingRequest,
|
||||
BatchAssignModelMappingResponse,
|
||||
BatchAssignProviderResult,
|
||||
DeleteModelMappingResponse,
|
||||
ModelCapabilities,
|
||||
ModelCatalogItem,
|
||||
ModelCatalogProviderDetail,
|
||||
ModelCatalogResponse,
|
||||
ModelPriceRange,
|
||||
OrphanedModel,
|
||||
UpdateModelMappingRequest,
|
||||
UpdateModelMappingResponse,
|
||||
)
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("", response_model=ModelCatalogResponse)
|
||||
async def get_model_catalog(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelCatalogResponse:
|
||||
adapter = AdminGetModelCatalogAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
|
||||
async def batch_assign_model_mappings(
|
||||
request: Request,
|
||||
payload: BatchAssignModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignModelMappingResponse:
|
||||
adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
"""管理员查询统一模型目录
|
||||
|
||||
新架构说明:
|
||||
1. 以 GlobalModel 为中心聚合数据
|
||||
2. ModelMapping 表提供别名信息(provider_id=NULL 表示全局)
|
||||
3. Model 表提供关联提供商和价格
|
||||
"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
# 1. 获取所有活跃的 GlobalModel
|
||||
global_models: List[GlobalModel] = (
|
||||
db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
|
||||
)
|
||||
|
||||
# 2. 获取所有活跃的别名(含全局和 Provider 特定)
|
||||
aliases_rows: List[ModelMapping] = (
|
||||
db.query(ModelMapping)
|
||||
.options(joinedload(ModelMapping.target_global_model))
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 按 GlobalModel ID 组织别名
|
||||
aliases_by_global_model: Dict[str, List[str]] = {}
|
||||
for alias_row in aliases_rows:
|
||||
if not alias_row.target_global_model_id:
|
||||
continue
|
||||
gm_id = alias_row.target_global_model_id
|
||||
if gm_id not in aliases_by_global_model:
|
||||
aliases_by_global_model[gm_id] = []
|
||||
if alias_row.source_model not in aliases_by_global_model[gm_id]:
|
||||
aliases_by_global_model[gm_id].append(alias_row.source_model)
|
||||
|
||||
# 3. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
models: List[Model] = (
|
||||
db.query(Model)
|
||||
.options(joinedload(Model.provider), joinedload(Model.global_model))
|
||||
.filter(Model.is_active == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 按 GlobalModel ID 组织关联提供商
|
||||
models_by_global_model: Dict[str, List[Model]] = {}
|
||||
for model in models:
|
||||
if model.global_model_id:
|
||||
models_by_global_model.setdefault(model.global_model_id, []).append(model)
|
||||
|
||||
# 4. 为每个 GlobalModel 构建 catalog item
|
||||
catalog_items: List[ModelCatalogItem] = []
|
||||
|
||||
for gm in global_models:
|
||||
gm_id = gm.id
|
||||
provider_entries: List[ModelCatalogProviderDetail] = []
|
||||
capability_flags = {
|
||||
"supports_vision": gm.default_supports_vision or False,
|
||||
"supports_function_calling": gm.default_supports_function_calling or False,
|
||||
"supports_streaming": gm.default_supports_streaming or False,
|
||||
}
|
||||
|
||||
# 遍历该 GlobalModel 的所有关联提供商
|
||||
for model in models_by_global_model.get(gm_id, []):
|
||||
provider = model.provider
|
||||
if not provider:
|
||||
continue
|
||||
|
||||
# 使用有效价格(考虑 GlobalModel 默认值)
|
||||
effective_input = model.get_effective_input_price()
|
||||
effective_output = model.get_effective_output_price()
|
||||
effective_tiered = model.get_effective_tiered_pricing()
|
||||
tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1
|
||||
|
||||
# 使用有效能力值
|
||||
capability_flags["supports_vision"] = (
|
||||
capability_flags["supports_vision"] or model.get_effective_supports_vision()
|
||||
)
|
||||
capability_flags["supports_function_calling"] = (
|
||||
capability_flags["supports_function_calling"]
|
||||
or model.get_effective_supports_function_calling()
|
||||
)
|
||||
capability_flags["supports_streaming"] = (
|
||||
capability_flags["supports_streaming"]
|
||||
or model.get_effective_supports_streaming()
|
||||
)
|
||||
|
||||
provider_entries.append(
|
||||
ModelCatalogProviderDetail(
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.name,
|
||||
provider_display_name=provider.display_name,
|
||||
model_id=model.id,
|
||||
target_model=model.provider_model_name,
|
||||
# 显示有效价格
|
||||
input_price_per_1m=effective_input,
|
||||
output_price_per_1m=effective_output,
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
cache_read_price_per_1m=model.get_effective_cache_read_price(),
|
||||
cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
|
||||
price_per_request=model.get_effective_price_per_request(),
|
||||
effective_tiered_pricing=effective_tiered,
|
||||
tier_count=tier_count,
|
||||
supports_vision=model.get_effective_supports_vision(),
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=bool(model.is_active),
|
||||
mapping_id=None, # 新架构中不再有 mapping_id
|
||||
)
|
||||
)
|
||||
|
||||
# 模型目录显示 GlobalModel 的第一个阶梯价格(不是 Provider 聚合价格)
|
||||
tiered = gm.default_tiered_pricing or {}
|
||||
first_tier = tiered.get("tiers", [{}])[0] if tiered.get("tiers") else {}
|
||||
price_range = ModelPriceRange(
|
||||
min_input=first_tier.get("input_price_per_1m", 0),
|
||||
max_input=first_tier.get("input_price_per_1m", 0),
|
||||
min_output=first_tier.get("output_price_per_1m", 0),
|
||||
max_output=first_tier.get("output_price_per_1m", 0),
|
||||
)
|
||||
|
||||
catalog_items.append(
|
||||
ModelCatalogItem(
|
||||
global_model_name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
aliases=aliases_by_global_model.get(gm_id, []),
|
||||
providers=provider_entries,
|
||||
price_range=price_range,
|
||||
total_providers=len(provider_entries),
|
||||
capabilities=ModelCapabilities(**capability_flags),
|
||||
)
|
||||
)
|
||||
|
||||
# 5. 查找孤立的别名(别名指向的 GlobalModel 不存在或不活跃)
|
||||
orphaned_rows = (
|
||||
db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
|
||||
.outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
or_(GlobalModel.id == None, GlobalModel.is_active == False),
|
||||
)
|
||||
.group_by(ModelMapping.source_model, GlobalModel.name)
|
||||
.all()
|
||||
)
|
||||
orphaned_models = [
|
||||
OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
|
||||
for row in orphaned_rows
|
||||
if row[0]
|
||||
]
|
||||
|
||||
return ModelCatalogResponse(
|
||||
models=catalog_items,
|
||||
total=len(catalog_items),
|
||||
orphaned_models=orphaned_models,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
|
||||
payload: BatchAssignModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
created: List[BatchAssignProviderResult] = []
|
||||
errors: List[BatchAssignError] = []
|
||||
|
||||
for provider_config in self.payload.providers:
|
||||
provider_id = provider_config.provider_id
|
||||
try:
|
||||
provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="Provider 不存在")
|
||||
)
|
||||
continue
|
||||
|
||||
model_id: Optional[str] = None
|
||||
created_model = False
|
||||
|
||||
if provider_config.create_model:
|
||||
model_data = provider_config.model_data
|
||||
if not model_data:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
|
||||
)
|
||||
continue
|
||||
|
||||
existing_model = ModelService.get_model_by_name(
|
||||
db, provider_id, model_data.provider_model_name
|
||||
)
|
||||
if existing_model:
|
||||
model_id = existing_model.id
|
||||
logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
|
||||
model_data.provider_model_name,
|
||||
provider.name,
|
||||
)
|
||||
else:
|
||||
model = ModelService.create_model(db, provider_id, model_data)
|
||||
model_id = model.id
|
||||
created_model = True
|
||||
else:
|
||||
model_id = provider_config.model_id
|
||||
if not model_id:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_id")
|
||||
)
|
||||
continue
|
||||
model = (
|
||||
db.query(Model)
|
||||
.filter(Model.id == model_id, Model.provider_id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id, error="模型不存在或不属于当前 Provider")
|
||||
)
|
||||
continue
|
||||
|
||||
# 批量分配功能需要适配 GlobalModel 架构
|
||||
# 参见 docs/optimization-backlog.md 中的待办项
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id,
|
||||
error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
logger.error("批量添加模型映射失败(需要适配新架构)")
|
||||
errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))
|
||||
|
||||
return BatchAssignModelMappingResponse(
|
||||
success=len(created) > 0,
|
||||
created_mappings=created,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
|
||||
async def update_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
payload: UpdateModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> UpdateModelMappingResponse:
|
||||
"""更新模型映射"""
|
||||
adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
|
||||
async def delete_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> DeleteModelMappingResponse:
|
||||
"""删除模型映射"""
|
||||
adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
|
||||
"""更新模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
payload: UpdateModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
update_data = self.payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "provider_id" in update_data:
|
||||
new_provider_id = update_data["provider_id"]
|
||||
if new_provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider 不存在")
|
||||
mapping.provider_id = new_provider_id
|
||||
|
||||
if "target_global_model_id" in update_data:
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == update_data["target_global_model_id"],
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
|
||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
||||
|
||||
if "source_model" in update_data:
|
||||
new_source = update_data["source_model"].strip()
|
||||
if not new_source:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
mapping.source_model = new_source
|
||||
|
||||
if "is_active" in update_data:
|
||||
mapping.is_active = update_data["is_active"]
|
||||
|
||||
duplicate = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == mapping.source_model,
|
||||
ModelMapping.provider_id == mapping.provider_id,
|
||||
ModelMapping.id != mapping.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
||||
|
||||
return UpdateModelMappingResponse(
|
||||
success=True,
|
||||
mapping_id=mapping.id,
|
||||
message="映射更新成功",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
|
||||
"""删除模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
source_model = mapping.source_model
|
||||
provider_id = mapping.provider_id
|
||||
|
||||
db.delete(mapping)
|
||||
db.commit()
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return DeleteModelMappingResponse(
|
||||
success=True,
|
||||
message=f"映射 {self.mapping_id} 已删除",
|
||||
)
|
||||
292
src/api/admin/models/global_models.py
Normal file
292
src/api/admin/models/global_models.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
GlobalModel Admin API
|
||||
|
||||
提供 GlobalModel 的 CRUD 操作接口
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignToProvidersRequest,
|
||||
BatchAssignToProvidersResponse,
|
||||
GlobalModelCreate,
|
||||
GlobalModelListResponse,
|
||||
GlobalModelResponse,
|
||||
GlobalModelUpdate,
|
||||
GlobalModelWithStats,
|
||||
)
|
||||
from src.services.model.global_model import GlobalModelService
|
||||
|
||||
router = APIRouter(prefix="/global", tags=["Admin - Global Models"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("", response_model=GlobalModelListResponse)
|
||||
async def list_global_models(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
is_active: Optional[bool] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelListResponse:
|
||||
"""获取 GlobalModel 列表"""
|
||||
adapter = AdminListGlobalModelsAdapter(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
search=search,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{global_model_id}", response_model=GlobalModelWithStats)
|
||||
async def get_global_model(
|
||||
request: Request,
|
||||
global_model_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelWithStats:
|
||||
"""获取单个 GlobalModel 详情(含统计信息)"""
|
||||
adapter = AdminGetGlobalModelAdapter(global_model_id=global_model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("", response_model=GlobalModelResponse, status_code=201)
|
||||
async def create_global_model(
|
||||
request: Request,
|
||||
payload: GlobalModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelResponse:
|
||||
"""创建 GlobalModel"""
|
||||
adapter = AdminCreateGlobalModelAdapter(payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{global_model_id}", response_model=GlobalModelResponse)
|
||||
async def update_global_model(
|
||||
request: Request,
|
||||
global_model_id: str,
|
||||
payload: GlobalModelUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelResponse:
|
||||
"""更新 GlobalModel"""
|
||||
adapter = AdminUpdateGlobalModelAdapter(global_model_id=global_model_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{global_model_id}", status_code=204)
|
||||
async def delete_global_model(
|
||||
request: Request,
|
||||
global_model_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除 GlobalModel(级联删除所有关联的 Provider 模型实现)"""
|
||||
adapter = AdminDeleteGlobalModelAdapter(global_model_id=global_model_id)
|
||||
await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
return None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{global_model_id}/assign-to-providers", response_model=BatchAssignToProvidersResponse
|
||||
)
|
||||
async def batch_assign_to_providers(
|
||||
request: Request,
|
||||
global_model_id: str,
|
||||
payload: BatchAssignToProvidersRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignToProvidersResponse:
|
||||
"""批量为多个 Provider 添加 GlobalModel 实现"""
|
||||
adapter = AdminBatchAssignToProvidersAdapter(global_model_id=global_model_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ========== Adapters ==========
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
||||
"""列出 GlobalModel"""
|
||||
|
||||
skip: int
|
||||
limit: int
|
||||
is_active: Optional[bool]
|
||||
search: Optional[str]
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy import func
|
||||
|
||||
from src.models.database import Model, ModelMapping
|
||||
|
||||
models = GlobalModelService.list_global_models(
|
||||
db=context.db,
|
||||
skip=self.skip,
|
||||
limit=self.limit,
|
||||
is_active=self.is_active,
|
||||
search=self.search,
|
||||
)
|
||||
|
||||
# 为每个 GlobalModel 添加统计数据
|
||||
model_responses = []
|
||||
for gm in models:
|
||||
# 统计关联的 Model 数量(去重 Provider)
|
||||
provider_count = (
|
||||
context.db.query(func.count(func.distinct(Model.provider_id)))
|
||||
.filter(Model.global_model_id == gm.id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# 统计别名数量
|
||||
alias_count = (
|
||||
context.db.query(func.count(ModelMapping.id))
|
||||
.filter(ModelMapping.target_global_model_id == gm.id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
response = GlobalModelResponse.model_validate(gm)
|
||||
response.provider_count = provider_count
|
||||
response.alias_count = alias_count
|
||||
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
|
||||
model_responses.append(response)
|
||||
|
||||
return GlobalModelListResponse(
|
||||
models=model_responses,
|
||||
total=len(models),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetGlobalModelAdapter(AdminApiAdapter):
|
||||
"""获取单个 GlobalModel"""
|
||||
|
||||
global_model_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
global_model = GlobalModelService.get_global_model(context.db, self.global_model_id)
|
||||
stats = GlobalModelService.get_global_model_stats(context.db, self.global_model_id)
|
||||
|
||||
return GlobalModelWithStats(
|
||||
**GlobalModelResponse.model_validate(global_model).model_dump(),
|
||||
total_models=stats["total_models"],
|
||||
total_providers=stats["total_providers"],
|
||||
price_range=stats["price_range"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminCreateGlobalModelAdapter(AdminApiAdapter):
|
||||
"""创建 GlobalModel"""
|
||||
|
||||
payload: GlobalModelCreate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
# 将 TieredPricingConfig 转换为 dict
|
||||
tiered_pricing_dict = self.payload.default_tiered_pricing.model_dump()
|
||||
|
||||
global_model = GlobalModelService.create_global_model(
|
||||
db=context.db,
|
||||
name=self.payload.name,
|
||||
display_name=self.payload.display_name,
|
||||
description=self.payload.description,
|
||||
official_url=self.payload.official_url,
|
||||
icon_url=self.payload.icon_url,
|
||||
is_active=self.payload.is_active,
|
||||
# 按次计费配置
|
||||
default_price_per_request=self.payload.default_price_per_request,
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing=tiered_pricing_dict,
|
||||
# 默认能力配置
|
||||
default_supports_vision=self.payload.default_supports_vision,
|
||||
default_supports_function_calling=self.payload.default_supports_function_calling,
|
||||
default_supports_streaming=self.payload.default_supports_streaming,
|
||||
default_supports_extended_thinking=self.payload.default_supports_extended_thinking,
|
||||
# Key 能力配置
|
||||
supported_capabilities=self.payload.supported_capabilities,
|
||||
)
|
||||
|
||||
logger.info(f"GlobalModel 已创建: id={global_model.id} name={global_model.name}")
|
||||
|
||||
return GlobalModelResponse.model_validate(global_model)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateGlobalModelAdapter(AdminApiAdapter):
|
||||
"""更新 GlobalModel"""
|
||||
|
||||
global_model_id: str
|
||||
payload: GlobalModelUpdate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
global_model = GlobalModelService.update_global_model(
|
||||
db=context.db,
|
||||
global_model_id=self.global_model_id,
|
||||
update_data=self.payload,
|
||||
)
|
||||
|
||||
logger.info(f"GlobalModel 已更新: id={global_model.id} name={global_model.name}")
|
||||
|
||||
# 失效相关缓存
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_global_model_changed(global_model.name)
|
||||
|
||||
return GlobalModelResponse.model_validate(global_model)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteGlobalModelAdapter(AdminApiAdapter):
|
||||
"""删除 GlobalModel(级联删除所有关联的 Provider 模型实现)"""
|
||||
|
||||
global_model_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
# 先获取 GlobalModel 信息(用于失效缓存)
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
global_model = (
|
||||
context.db.query(GlobalModel).filter(GlobalModel.id == self.global_model_id).first()
|
||||
)
|
||||
model_name = global_model.name if global_model else None
|
||||
|
||||
GlobalModelService.delete_global_model(context.db, self.global_model_id)
|
||||
|
||||
logger.info(f"GlobalModel 已删除: id={self.global_model_id}")
|
||||
|
||||
# 失效相关缓存
|
||||
if model_name:
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_global_model_changed(model_name)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
|
||||
"""批量为 Provider 添加 GlobalModel 实现"""
|
||||
|
||||
global_model_id: str
|
||||
payload: BatchAssignToProvidersRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
result = GlobalModelService.batch_assign_to_providers(
|
||||
db=context.db,
|
||||
global_model_id=self.global_model_id,
|
||||
provider_ids=self.payload.provider_ids,
|
||||
create_models=self.payload.create_models,
|
||||
)
|
||||
|
||||
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
|
||||
|
||||
return BatchAssignToProvidersResponse(**result)
|
||||
303
src/api/admin/models/mappings.py
Normal file
303
src/api/admin/models/mappings.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""模型映射管理 API
|
||||
|
||||
提供模型映射的 CRUD 操作。
|
||||
|
||||
模型映射(Mapping)用于将源模型映射到目标模型,例如:
|
||||
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
|
||||
- 用于处理 Provider 不支持请求模型的情况
|
||||
|
||||
映射必须关联到特定的 Provider。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ModelMappingCreate,
|
||||
ModelMappingResponse,
|
||||
ModelMappingUpdate,
|
||||
)
|
||||
from src.models.database import GlobalModel, ModelMapping, Provider, User
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
|
||||
|
||||
|
||||
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
|
||||
target = mapping.target_global_model
|
||||
provider = mapping.provider
|
||||
scope = "provider" if mapping.provider_id else "global"
|
||||
return ModelMappingResponse(
|
||||
id=mapping.id,
|
||||
source_model=mapping.source_model,
|
||||
target_global_model_id=mapping.target_global_model_id,
|
||||
target_global_model_name=target.name if target else None,
|
||||
target_global_model_display_name=target.display_name if target else None,
|
||||
provider_id=mapping.provider_id,
|
||||
provider_name=provider.name if provider else None,
|
||||
scope=scope,
|
||||
mapping_type=mapping.mapping_type,
|
||||
is_active=mapping.is_active,
|
||||
created_at=mapping.created_at,
|
||||
updated_at=mapping.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[ModelMappingResponse])
|
||||
async def list_mappings(
|
||||
provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
|
||||
source_model: Optional[str] = Query(None, description="按源模型名筛选"),
|
||||
target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
|
||||
scope: Optional[str] = Query(None, description="global 或 provider"),
|
||||
mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
|
||||
is_active: Optional[bool] = Query(None, description="按状态筛选"),
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取模型映射列表"""
|
||||
query = db.query(ModelMapping).options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
|
||||
if provider_id is not None:
|
||||
query = query.filter(ModelMapping.provider_id == provider_id)
|
||||
if scope == "global":
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
elif scope == "provider":
|
||||
query = query.filter(ModelMapping.provider_id.isnot(None))
|
||||
if mapping_type is not None:
|
||||
query = query.filter(ModelMapping.mapping_type == mapping_type)
|
||||
if source_model:
|
||||
query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%"))
|
||||
if target_global_model_id is not None:
|
||||
query = query.filter(ModelMapping.target_global_model_id == target_global_model_id)
|
||||
if is_active is not None:
|
||||
query = query.filter(ModelMapping.is_active == is_active)
|
||||
|
||||
mappings = query.offset(skip).limit(limit).all()
|
||||
return [_serialize_mapping(mapping) for mapping in mappings]
|
||||
|
||||
|
||||
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
|
||||
async def get_mapping(
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取单个模型映射"""
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.post("", response_model=ModelMappingResponse, status_code=201)
|
||||
async def create_mapping(
|
||||
data: ModelMappingCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建模型映射"""
|
||||
source_model = data.source_model.strip()
|
||||
if not source_model:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
|
||||
# 验证 mapping_type
|
||||
if data.mapping_type not in ("alias", "mapping"):
|
||||
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
|
||||
|
||||
# 验证目标 GlobalModel 存在
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
|
||||
)
|
||||
|
||||
# 验证 Provider 存在
|
||||
provider = None
|
||||
provider_id = data.provider_id
|
||||
if provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")
|
||||
|
||||
# 检查映射是否已存在(全局或同一 Provider 下不可重复)
|
||||
existing = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
# 创建映射
|
||||
mapping = ModelMapping(
|
||||
id=str(uuid.uuid4()),
|
||||
source_model=source_model,
|
||||
target_global_model_id=data.target_global_model_id,
|
||||
provider_id=provider_id,
|
||||
mapping_type=data.mapping_type,
|
||||
is_active=data.is_active,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
db.add(mapping)
|
||||
db.commit()
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
logger.info(f"创建模型映射: {source_model} -> {target_model.name} "
|
||||
f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})")
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
|
||||
async def update_mapping(
|
||||
mapping_id: str,
|
||||
data: ModelMappingUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新模型映射"""
|
||||
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# 更新 Provider
|
||||
if "provider_id" in update_data:
|
||||
new_provider_id = update_data["provider_id"]
|
||||
if new_provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")
|
||||
mapping.provider_id = new_provider_id
|
||||
|
||||
# 更新目标模型
|
||||
if "target_global_model_id" in update_data:
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == update_data["target_global_model_id"],
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
|
||||
)
|
||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
||||
|
||||
# 更新源模型名
|
||||
if "source_model" in update_data:
|
||||
new_source = update_data["source_model"].strip()
|
||||
if not new_source:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
mapping.source_model = new_source
|
||||
|
||||
# 检查唯一约束
|
||||
duplicate = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == mapping.source_model,
|
||||
ModelMapping.provider_id == mapping.provider_id,
|
||||
ModelMapping.id != mapping_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
# 更新映射类型
|
||||
if "mapping_type" in update_data:
|
||||
if update_data["mapping_type"] not in ("alias", "mapping"):
|
||||
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
|
||||
mapping.mapping_type = update_data["mapping_type"]
|
||||
|
||||
# 更新状态
|
||||
if "is_active" in update_data:
|
||||
mapping.is_active = update_data["is_active"]
|
||||
|
||||
mapping.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
|
||||
logger.info(f"更新模型映射 (ID: {mapping.id})")
|
||||
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.delete("/{mapping_id}", status_code=204)
|
||||
async def delete_mapping(
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除模型映射"""
|
||||
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
source_model = mapping.source_model
|
||||
provider_id = mapping.provider_id
|
||||
|
||||
logger.info(f"删除模型映射: {source_model} -> {mapping.target_global_model_id} (ID: {mapping.id})")
|
||||
|
||||
db.delete(mapping)
|
||||
db.commit()
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return None
|
||||
14
src/api/admin/monitoring/__init__.py
Normal file
14
src/api/admin/monitoring/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Admin monitoring router合集。"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .audit import router as audit_router
|
||||
from .cache import router as cache_router
|
||||
from .trace import router as trace_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(audit_router)
|
||||
router.include_router(cache_router)
|
||||
router.include_router(trace_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
399
src/api/admin/monitoring/audit.py
Normal file
399
src/api/admin/monitoring/audit.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""管理员监控与审计端点。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_query
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
AuditEventType,
|
||||
AuditLog,
|
||||
Provider,
|
||||
Usage,
|
||||
)
|
||||
from src.models.database import User as DBUser
|
||||
from src.services.health.monitor import HealthMonitor
|
||||
from src.services.system.audit import audit_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/admin/monitoring", tags=["Admin - Monitoring"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/audit-logs")
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
user_id: Optional[str] = Query(None, description="用户ID筛选 (支持UUID)"),
|
||||
event_type: Optional[str] = Query(None, description="事件类型筛选"),
|
||||
days: int = Query(7, description="查询天数"),
|
||||
limit: int = Query(100, description="返回数量限制"),
|
||||
offset: int = Query(0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminGetAuditLogsAdapter(
|
||||
user_id=user_id,
|
||||
event_type=event_type,
|
||||
days=days,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/system-status")
|
||||
async def get_system_status(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminSystemStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/suspicious-activities")
|
||||
async def get_suspicious_activities(
|
||||
request: Request,
|
||||
hours: int = Query(24, description="时间范围(小时)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminSuspiciousActivitiesAdapter(hours=hours)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/user-behavior/{user_id}")
|
||||
async def analyze_user_behavior(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
days: int = Query(30, description="分析天数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminUserBehaviorAdapter(user_id=user_id, days=days)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/resilience-status")
|
||||
async def get_resilience_status(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminResilienceStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/resilience/error-stats")
|
||||
async def reset_error_stats(request: Request, db: Session = Depends(get_db)):
|
||||
"""Reset resilience error statistics"""
|
||||
adapter = AdminResetErrorStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/resilience/circuit-history")
|
||||
async def get_circuit_history(
|
||||
request: Request,
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminCircuitHistoryAdapter(limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetAuditLogsAdapter(AdminApiAdapter):
|
||||
user_id: Optional[str]
|
||||
event_type: Optional[str]
|
||||
days: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(days=self.days)
|
||||
|
||||
base_query = (
|
||||
db.query(AuditLog, DBUser)
|
||||
.outerjoin(DBUser, AuditLog.user_id == DBUser.id)
|
||||
.filter(AuditLog.created_at >= cutoff_time)
|
||||
)
|
||||
if self.user_id:
|
||||
base_query = base_query.filter(AuditLog.user_id == self.user_id)
|
||||
if self.event_type:
|
||||
base_query = base_query.filter(AuditLog.event_type == self.event_type)
|
||||
|
||||
ordered_query = base_query.order_by(AuditLog.created_at.desc())
|
||||
total, logs_with_users = paginate_query(ordered_query, self.limit, self.offset)
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": log.id,
|
||||
"event_type": log.event_type,
|
||||
"user_id": log.user_id,
|
||||
"user_email": user.email if user else None,
|
||||
"user_username": user.username if user else None,
|
||||
"description": log.description,
|
||||
"ip_address": log.ip_address,
|
||||
"status_code": log.status_code,
|
||||
"error_message": log.error_message,
|
||||
"metadata": log.event_metadata,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
}
|
||||
for log, user in logs_with_users
|
||||
]
|
||||
meta = PaginationMeta(
|
||||
total=total,
|
||||
limit=self.limit,
|
||||
offset=self.offset,
|
||||
count=len(items),
|
||||
)
|
||||
|
||||
payload = build_pagination_payload(
|
||||
items,
|
||||
meta,
|
||||
filters={
|
||||
"user_id": self.user_id,
|
||||
"event_type": self.event_type,
|
||||
"days": self.days,
|
||||
},
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="monitor_audit_logs",
|
||||
filter_user_id=self.user_id,
|
||||
filter_event_type=self.event_type,
|
||||
days=self.days,
|
||||
limit=self.limit,
|
||||
offset=self.offset,
|
||||
total=total,
|
||||
result_count=meta.count,
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
class AdminSystemStatusAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
total_users = db.query(func.count(DBUser.id)).scalar()
|
||||
active_users = db.query(func.count(DBUser.id)).filter(DBUser.is_active.is_(True)).scalar()
|
||||
|
||||
total_providers = db.query(func.count(Provider.id)).scalar()
|
||||
active_providers = (
|
||||
db.query(func.count(Provider.id)).filter(Provider.is_active.is_(True)).scalar()
|
||||
)
|
||||
|
||||
total_api_keys = db.query(func.count(ApiKey.id)).scalar()
|
||||
active_api_keys = (
|
||||
db.query(func.count(ApiKey.id)).filter(ApiKey.is_active.is_(True)).scalar()
|
||||
)
|
||||
|
||||
today_start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
today_requests = (
|
||||
db.query(func.count(Usage.id)).filter(Usage.created_at >= today_start).scalar()
|
||||
)
|
||||
today_tokens = (
|
||||
db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= today_start).scalar()
|
||||
or 0
|
||||
)
|
||||
today_cost = (
|
||||
db.query(func.sum(Usage.total_cost_usd))
|
||||
.filter(Usage.created_at >= today_start)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
recent_errors = (
|
||||
db.query(AuditLog)
|
||||
.filter(
|
||||
AuditLog.event_type.in_(
|
||||
[
|
||||
AuditEventType.REQUEST_FAILED.value,
|
||||
AuditEventType.SUSPICIOUS_ACTIVITY.value,
|
||||
]
|
||||
),
|
||||
AuditLog.created_at >= datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="system_status_snapshot",
|
||||
total_users=int(total_users or 0),
|
||||
active_users=int(active_users or 0),
|
||||
total_providers=int(total_providers or 0),
|
||||
active_providers=int(active_providers or 0),
|
||||
total_api_keys=int(total_api_keys or 0),
|
||||
active_api_keys=int(active_api_keys or 0),
|
||||
today_requests=int(today_requests or 0),
|
||||
today_tokens=int(today_tokens or 0),
|
||||
today_cost=float(today_cost or 0.0),
|
||||
recent_errors=int(recent_errors or 0),
|
||||
)
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"users": {"total": total_users, "active": active_users},
|
||||
"providers": {"total": total_providers, "active": active_providers},
|
||||
"api_keys": {"total": total_api_keys, "active": active_api_keys},
|
||||
"today_stats": {
|
||||
"requests": today_requests,
|
||||
"tokens": today_tokens,
|
||||
"cost_usd": f"${today_cost:.4f}",
|
||||
},
|
||||
"recent_errors": recent_errors,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminSuspiciousActivitiesAdapter(AdminApiAdapter):
|
||||
hours: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
activities = audit_service.get_suspicious_activities(db=db, hours=self.hours, limit=100)
|
||||
response = {
|
||||
"activities": [
|
||||
{
|
||||
"id": activity.id,
|
||||
"event_type": activity.event_type,
|
||||
"user_id": activity.user_id,
|
||||
"description": activity.description,
|
||||
"ip_address": activity.ip_address,
|
||||
"metadata": activity.event_metadata,
|
||||
"created_at": activity.created_at.isoformat() if activity.created_at else None,
|
||||
}
|
||||
for activity in activities
|
||||
],
|
||||
"count": len(activities),
|
||||
"time_range_hours": self.hours,
|
||||
}
|
||||
context.add_audit_metadata(
|
||||
action="monitor_suspicious_activity",
|
||||
hours=self.hours,
|
||||
result_count=len(activities),
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUserBehaviorAdapter(AdminApiAdapter):
|
||||
user_id: str
|
||||
days: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
result = audit_service.analyze_user_behavior(
|
||||
db=context.db,
|
||||
user_id=self.user_id,
|
||||
days=self.days,
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="monitor_user_behavior",
|
||||
target_user_id=self.user_id,
|
||||
days=self.days,
|
||||
contains_summary=bool(result),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class AdminResilienceStatusAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
try:
|
||||
from src.core.resilience import resilience_manager
|
||||
except ImportError as exc:
|
||||
raise HTTPException(status_code=503, detail="韧性管理系统未启用") from exc
|
||||
|
||||
error_stats = resilience_manager.get_error_stats()
|
||||
recent_errors = [
|
||||
{
|
||||
"error_id": info["error_id"],
|
||||
"error_type": info["error_type"],
|
||||
"operation": info["operation"],
|
||||
"timestamp": info["timestamp"].isoformat(),
|
||||
"context": info.get("context", {}),
|
||||
}
|
||||
for info in resilience_manager.last_errors[-10:]
|
||||
]
|
||||
|
||||
total_errors = error_stats.get("total_errors", 0)
|
||||
circuit_breakers = error_stats.get("circuit_breakers", {})
|
||||
circuit_breakers_open = sum(
|
||||
1 for status in circuit_breakers.values() if status.get("state") == "open"
|
||||
)
|
||||
health_score = max(0, 100 - (total_errors * 2) - (circuit_breakers_open * 20))
|
||||
|
||||
response = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"health_score": health_score,
|
||||
"status": (
|
||||
"healthy" if health_score > 80 else "degraded" if health_score > 50 else "critical"
|
||||
),
|
||||
"error_statistics": error_stats,
|
||||
"recent_errors": recent_errors,
|
||||
"recommendations": _get_health_recommendations(error_stats, health_score),
|
||||
}
|
||||
context.add_audit_metadata(
|
||||
action="resilience_status",
|
||||
health_score=health_score,
|
||||
error_total=error_stats.get("total_errors") if isinstance(error_stats, dict) else None,
|
||||
open_circuit_breakers=circuit_breakers_open,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class AdminResetErrorStatsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
try:
|
||||
from src.core.resilience import resilience_manager
|
||||
except ImportError as exc:
|
||||
raise HTTPException(status_code=503, detail="韧性管理系统未启用") from exc
|
||||
|
||||
old_stats = resilience_manager.get_error_stats()
|
||||
resilience_manager.error_stats.clear()
|
||||
resilience_manager.last_errors.clear()
|
||||
|
||||
logger.info(f"管理员 {context.user.email if context.user else 'unknown'} 重置了错误统计")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="reset_error_stats",
|
||||
previous_total_errors=(
|
||||
old_stats.get("total_errors") if isinstance(old_stats, dict) else None
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "错误统计已重置",
|
||||
"previous_stats": old_stats,
|
||||
"reset_by": context.user.email if context.user else None,
|
||||
"reset_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class AdminCircuitHistoryAdapter(AdminApiAdapter):
|
||||
def __init__(self, limit: int = 50):
|
||||
super().__init__()
|
||||
self.limit = limit
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
history = HealthMonitor.get_circuit_history(self.limit)
|
||||
context.add_audit_metadata(
|
||||
action="circuit_history",
|
||||
limit=self.limit,
|
||||
result_count=len(history),
|
||||
)
|
||||
return {"items": history, "count": len(history)}
|
||||
|
||||
|
||||
def _get_health_recommendations(error_stats: dict, health_score: int) -> List[str]:
|
||||
recommendations: List[str] = []
|
||||
if health_score < 50:
|
||||
recommendations.append("系统健康状况严重,请立即检查错误日志")
|
||||
if error_stats.get("total_errors", 0) > 100:
|
||||
recommendations.append("错误频率过高,建议检查系统配置和外部依赖")
|
||||
|
||||
circuit_breakers = error_stats.get("circuit_breakers", {})
|
||||
open_breakers = [k for k, v in circuit_breakers.items() if v.get("state") == "open"]
|
||||
if open_breakers:
|
||||
recommendations.append(f"以下服务熔断器已打开:{', '.join(open_breakers)}")
|
||||
|
||||
if health_score > 90:
|
||||
recommendations.append("系统运行良好")
|
||||
return recommendations
|
||||
871
src/api/admin/monitoring/cache.py
Normal file
871
src/api/admin/monitoring/cache.py
Normal file
@@ -0,0 +1,871 @@
|
||||
"""
|
||||
缓存监控端点
|
||||
|
||||
提供缓存亲和性统计、管理和监控功能
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.clients.redis_client import get_redis_client_sync
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, User
|
||||
from src.services.cache.affinity_manager import get_affinity_manager
|
||||
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
|
||||
|
||||
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
def mask_api_key(api_key: Optional[str], prefix_len: int = 8, suffix_len: int = 4) -> Optional[str]:
|
||||
"""
|
||||
脱敏 API Key,显示前缀 + 星号 + 后缀
|
||||
例如: sk-jhiId-xxxxxxxxxxxAABB -> sk-jhiId-********AABB
|
||||
|
||||
Args:
|
||||
api_key: 原始 API Key
|
||||
prefix_len: 显示的前缀长度,默认 8
|
||||
suffix_len: 显示的后缀长度,默认 4
|
||||
"""
|
||||
if not api_key:
|
||||
return None
|
||||
total_visible = prefix_len + suffix_len
|
||||
if len(api_key) <= total_visible:
|
||||
# Key 太短,直接返回部分内容 + 星号
|
||||
return api_key[:prefix_len] + "********"
|
||||
return f"{api_key[:prefix_len]}********{api_key[-suffix_len:]}"
|
||||
|
||||
|
||||
def decrypt_and_mask(encrypted_key: Optional[str], prefix_len: int = 8) -> Optional[str]:
|
||||
"""
|
||||
解密 API Key 后脱敏显示
|
||||
|
||||
Args:
|
||||
encrypted_key: 加密后的 API Key
|
||||
prefix_len: 显示的前缀长度
|
||||
"""
|
||||
if not encrypted_key:
|
||||
return None
|
||||
try:
|
||||
decrypted = crypto_service.decrypt(encrypted_key)
|
||||
return mask_api_key(decrypted, prefix_len)
|
||||
except Exception:
|
||||
# 解密失败时返回 None
|
||||
return None
|
||||
|
||||
|
||||
def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
|
||||
"""
|
||||
将用户标识符(username/email/user_id/api_key_id)解析为 user_id
|
||||
|
||||
支持的输入格式:
|
||||
1. User UUID (36位,带横杠)
|
||||
2. Username (用户名)
|
||||
3. Email (邮箱)
|
||||
4. API Key ID (36位UUID)
|
||||
|
||||
返回:
|
||||
- user_id (UUID字符串) 或 None
|
||||
"""
|
||||
identifier = identifier.strip()
|
||||
|
||||
# 1. 先尝试作为 User UUID 查询
|
||||
user = db.query(User).filter(User.id == identifier).first()
|
||||
if user:
|
||||
logger.debug(f"通过User ID解析: {identifier[:8]}... -> {user.username}")
|
||||
return user.id
|
||||
|
||||
# 2. 尝试作为 Username 查询
|
||||
user = db.query(User).filter(User.username == identifier).first()
|
||||
if user:
|
||||
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...")
|
||||
return user.id
|
||||
|
||||
# 3. 尝试作为 Email 查询
|
||||
user = db.query(User).filter(User.email == identifier).first()
|
||||
if user:
|
||||
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...")
|
||||
return user.id
|
||||
|
||||
# 4. 尝试作为 API Key ID 查询
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
|
||||
if api_key:
|
||||
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...")
|
||||
return api_key.user_id
|
||||
|
||||
# 无法识别
|
||||
logger.debug(f"无法识别的用户标识符: {identifier}")
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_cache_stats(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取缓存亲和性统计信息
|
||||
|
||||
返回:
|
||||
- 缓存命中率
|
||||
- 缓存用户数
|
||||
- Provider切换次数
|
||||
- Key切换次数
|
||||
- 缓存预留配置
|
||||
"""
|
||||
adapter = AdminCacheStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/affinity/{user_identifier}")
|
||||
async def get_user_affinity(
|
||||
user_identifier: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
查询指定用户的所有缓存亲和性
|
||||
|
||||
参数:
|
||||
- user_identifier: 用户标识符,支持以下格式:
|
||||
* 用户名 (username),如: yuanhonghu
|
||||
* 邮箱 (email),如: user@example.com
|
||||
* 用户UUID (user_id),如: 550e8400-e29b-41d4-a716-446655440000
|
||||
* API Key ID,如: 660e8400-e29b-41d4-a716-446655440000
|
||||
|
||||
返回:
|
||||
- 用户信息
|
||||
- 所有端点的缓存亲和性列表(每个端点一条记录)
|
||||
"""
|
||||
adapter = AdminGetUserAffinityAdapter(user_identifier=user_identifier)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/affinities")
|
||||
async def list_affinities(
|
||||
request: Request,
|
||||
keyword: Optional[str] = None,
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取所有缓存亲和性列表,可选按关键词过滤
|
||||
|
||||
参数:
|
||||
- keyword: 可选,支持用户名/邮箱/User ID/API Key ID 或模糊匹配
|
||||
"""
|
||||
adapter = AdminListAffinitiesAdapter(keyword=keyword, limit=limit, offset=offset)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/users/{user_identifier}")
|
||||
async def clear_user_cache(
|
||||
user_identifier: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Clear cache affinity for a specific user
|
||||
|
||||
Parameters:
|
||||
- user_identifier: User identifier (username, email, user_id, or API Key ID)
|
||||
"""
|
||||
adapter = AdminClearUserCacheAdapter(user_identifier=user_identifier)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("")
|
||||
async def clear_all_cache(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Clear all cache affinities
|
||||
|
||||
Warning: This affects all users, use with caution
|
||||
"""
|
||||
adapter = AdminClearAllCacheAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id}")
|
||||
async def clear_provider_cache(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Clear cache affinities for a specific provider
|
||||
|
||||
Parameters:
|
||||
- provider_id: Provider ID
|
||||
"""
|
||||
adapter = AdminClearProviderCacheAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_cache_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取缓存相关配置
|
||||
|
||||
返回:
|
||||
- 缓存TTL
|
||||
- 缓存预留比例
|
||||
"""
|
||||
adapter = AdminCacheConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/metrics", response_class=PlainTextResponse)
|
||||
async def get_cache_metrics(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
|
||||
"""
|
||||
adapter = AdminCacheMetricsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- 缓存监控适配器 --------
|
||||
|
||||
|
||||
class AdminCacheStatsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
||||
stats = await scheduler.get_stats()
|
||||
logger.info("缓存统计信息查询成功")
|
||||
context.add_audit_metadata(
|
||||
action="cache_stats",
|
||||
scheduler=stats.get("scheduler"),
|
||||
total_affinities=stats.get("total_affinities"),
|
||||
cache_hit_rate=stats.get("cache_hit_rate"),
|
||||
provider_switches=stats.get("provider_switches"),
|
||||
)
|
||||
return {"status": "ok", "data": stats}
|
||||
except Exception as exc:
|
||||
logger.exception(f"获取缓存统计信息失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"获取缓存统计失败: {exc}")
|
||||
|
||||
|
||||
class AdminCacheMetricsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
||||
stats = await scheduler.get_stats()
|
||||
payload = self._format_prometheus(stats)
|
||||
context.add_audit_metadata(
|
||||
action="cache_metrics_export",
|
||||
scheduler=stats.get("scheduler"),
|
||||
metrics_lines=payload.count("\n"),
|
||||
)
|
||||
return PlainTextResponse(payload)
|
||||
except Exception as exc:
|
||||
logger.exception(f"导出缓存指标失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"导出缓存指标失败: {exc}")
|
||||
|
||||
def _format_prometheus(self, stats: Dict[str, Any]) -> str:
|
||||
"""
|
||||
将 scheduler/affinity 指标转换为 Prometheus 文本格式。
|
||||
"""
|
||||
scheduler_metrics = stats.get("scheduler_metrics", {})
|
||||
affinity_stats = stats.get("affinity_stats", {})
|
||||
|
||||
metric_map: List[Tuple[str, str, float]] = [
|
||||
(
|
||||
"cache_scheduler_total_batches",
|
||||
"Total batches pulled from provider list",
|
||||
float(scheduler_metrics.get("total_batches", 0)),
|
||||
),
|
||||
(
|
||||
"cache_scheduler_last_batch_size",
|
||||
"Size of the latest candidate batch",
|
||||
float(scheduler_metrics.get("last_batch_size", 0)),
|
||||
),
|
||||
(
|
||||
"cache_scheduler_total_candidates",
|
||||
"Total candidates enumerated by scheduler",
|
||||
float(scheduler_metrics.get("total_candidates", 0)),
|
||||
),
|
||||
(
|
||||
"cache_scheduler_last_candidate_count",
|
||||
"Number of candidates in the most recent batch",
|
||||
float(scheduler_metrics.get("last_candidate_count", 0)),
|
||||
),
|
||||
(
|
||||
"cache_scheduler_cache_hits",
|
||||
"Cache hits counted during scheduling",
|
||||
float(scheduler_metrics.get("cache_hits", 0)),
|
||||
),
|
||||
(
|
||||
"cache_scheduler_cache_misses",
|
||||
"Cache misses counted during scheduling",
|
||||
float(scheduler_metrics.get("cache_misses", 0)),
|
||||
),
|
||||
(
|
||||
"cache_scheduler_cache_hit_rate",
|
||||
"Cache hit rate during scheduling",
|
||||
float(scheduler_metrics.get("cache_hit_rate", 0.0)),
|
||||
),
|
||||
(
|
||||
"cache_scheduler_concurrency_denied",
|
||||
"Times candidate rejected due to concurrency limits",
|
||||
float(scheduler_metrics.get("concurrency_denied", 0)),
|
||||
),
|
||||
(
|
||||
"cache_scheduler_avg_candidates_per_batch",
|
||||
"Average candidates per batch",
|
||||
float(scheduler_metrics.get("avg_candidates_per_batch", 0.0)),
|
||||
),
|
||||
]
|
||||
|
||||
affinity_map: List[Tuple[str, str, float]] = [
|
||||
(
|
||||
"cache_affinity_total",
|
||||
"Total cache affinities stored",
|
||||
float(affinity_stats.get("total_affinities", 0)),
|
||||
),
|
||||
(
|
||||
"cache_affinity_hits",
|
||||
"Affinity cache hits",
|
||||
float(affinity_stats.get("cache_hits", 0)),
|
||||
),
|
||||
(
|
||||
"cache_affinity_misses",
|
||||
"Affinity cache misses",
|
||||
float(affinity_stats.get("cache_misses", 0)),
|
||||
),
|
||||
(
|
||||
"cache_affinity_hit_rate",
|
||||
"Affinity cache hit rate",
|
||||
float(affinity_stats.get("cache_hit_rate", 0.0)),
|
||||
),
|
||||
(
|
||||
"cache_affinity_invalidations",
|
||||
"Affinity invalidations",
|
||||
float(affinity_stats.get("cache_invalidations", 0)),
|
||||
),
|
||||
(
|
||||
"cache_affinity_provider_switches",
|
||||
"Affinity provider switches",
|
||||
float(affinity_stats.get("provider_switches", 0)),
|
||||
),
|
||||
(
|
||||
"cache_affinity_key_switches",
|
||||
"Affinity key switches",
|
||||
float(affinity_stats.get("key_switches", 0)),
|
||||
),
|
||||
]
|
||||
|
||||
lines = []
|
||||
for name, help_text, value in metric_map + affinity_map:
|
||||
lines.append(f"# HELP {name} {help_text}")
|
||||
lines.append(f"# TYPE {name} gauge")
|
||||
lines.append(f"{name} {value}")
|
||||
|
||||
scheduler_name = stats.get("scheduler", "cache_aware")
|
||||
lines.append(f'cache_scheduler_info{{scheduler="{scheduler_name}"}} 1')
|
||||
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetUserAffinityAdapter(AdminApiAdapter):
|
||||
user_identifier: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
try:
|
||||
user_id = resolve_user_identifier(db, self.user_identifier)
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"无法识别的用户标识符: {self.user_identifier}。支持用户名、邮箱、User ID或API Key ID",
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
redis_client = get_redis_client_sync()
|
||||
affinity_mgr = await get_affinity_manager(redis_client)
|
||||
|
||||
# 获取该用户的所有缓存亲和性
|
||||
all_affinities = await affinity_mgr.list_affinities()
|
||||
user_affinities = [aff for aff in all_affinities if aff.get("user_id") == user_id]
|
||||
|
||||
if not user_affinities:
|
||||
response = {
|
||||
"status": "not_found",
|
||||
"message": f"用户 {user.username} ({user.email}) 没有缓存亲和性",
|
||||
"user_info": {
|
||||
"user_id": user_id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
},
|
||||
"affinities": [],
|
||||
}
|
||||
context.add_audit_metadata(
|
||||
action="cache_user_affinity",
|
||||
user_identifier=self.user_identifier,
|
||||
resolved_user_id=user_id,
|
||||
affinity_count=0,
|
||||
status="not_found",
|
||||
)
|
||||
return response
|
||||
|
||||
response = {
|
||||
"status": "ok",
|
||||
"user_info": {
|
||||
"user_id": user_id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
},
|
||||
"affinities": [
|
||||
{
|
||||
"provider_id": aff["provider_id"],
|
||||
"endpoint_id": aff["endpoint_id"],
|
||||
"key_id": aff["key_id"],
|
||||
"api_format": aff.get("api_format"),
|
||||
"model_name": aff.get("model_name"),
|
||||
"created_at": aff["created_at"],
|
||||
"expire_at": aff["expire_at"],
|
||||
"request_count": aff["request_count"],
|
||||
}
|
||||
for aff in user_affinities
|
||||
],
|
||||
"total_endpoints": len(user_affinities),
|
||||
}
|
||||
context.add_audit_metadata(
|
||||
action="cache_user_affinity",
|
||||
user_identifier=self.user_identifier,
|
||||
resolved_user_id=user_id,
|
||||
affinity_count=len(user_affinities),
|
||||
status="ok",
|
||||
)
|
||||
return response
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception(f"查询用户缓存亲和性失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {exc}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListAffinitiesAdapter(AdminApiAdapter):
|
||||
keyword: Optional[str]
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
redis_client = get_redis_client_sync()
|
||||
if not redis_client:
|
||||
raise HTTPException(status_code=503, detail="Redis未初始化,无法获取缓存亲和性")
|
||||
|
||||
affinity_mgr = await get_affinity_manager(redis_client)
|
||||
matched_user_id = None
|
||||
matched_api_key_id = None
|
||||
raw_affinities: List[Dict[str, Any]] = []
|
||||
|
||||
if self.keyword:
|
||||
# 首先检查是否是 API Key ID(affinity_key)
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == self.keyword).first()
|
||||
if api_key:
|
||||
# 直接通过 affinity_key 过滤
|
||||
matched_api_key_id = str(api_key.id)
|
||||
matched_user_id = str(api_key.user_id)
|
||||
all_affinities = await affinity_mgr.list_affinities()
|
||||
raw_affinities = [
|
||||
aff for aff in all_affinities if aff.get("affinity_key") == matched_api_key_id
|
||||
]
|
||||
else:
|
||||
# 尝试解析为用户标识
|
||||
user_id = resolve_user_identifier(db, self.keyword)
|
||||
if user_id:
|
||||
matched_user_id = user_id
|
||||
# 获取该用户所有的 API Key ID
|
||||
user_api_keys = db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
user_api_key_ids = {str(k.id) for k in user_api_keys}
|
||||
# 过滤出该用户所有 API Key 的亲和性
|
||||
all_affinities = await affinity_mgr.list_affinities()
|
||||
raw_affinities = [
|
||||
aff for aff in all_affinities if aff.get("affinity_key") in user_api_key_ids
|
||||
]
|
||||
else:
|
||||
# 关键词不是有效标识,返回所有亲和性(后续会进行模糊匹配)
|
||||
raw_affinities = await affinity_mgr.list_affinities()
|
||||
else:
|
||||
raw_affinities = await affinity_mgr.list_affinities()
|
||||
|
||||
# 收集所有 affinity_key (API Key ID)
|
||||
affinity_keys = {
|
||||
item.get("affinity_key") for item in raw_affinities if item.get("affinity_key")
|
||||
}
|
||||
|
||||
# 批量查询用户 API Key 信息
|
||||
user_api_key_map: Dict[str, ApiKey] = {}
|
||||
if affinity_keys:
|
||||
user_api_keys = db.query(ApiKey).filter(ApiKey.id.in_(list(affinity_keys))).all()
|
||||
user_api_key_map = {str(k.id): k for k in user_api_keys}
|
||||
|
||||
# 收集所有 user_id
|
||||
user_ids = {str(k.user_id) for k in user_api_key_map.values()}
|
||||
user_map: Dict[str, User] = {}
|
||||
if user_ids:
|
||||
users = db.query(User).filter(User.id.in_(list(user_ids))).all()
|
||||
user_map = {str(user.id): user for user in users}
|
||||
|
||||
# 收集所有provider_id、endpoint_id、key_id
|
||||
provider_ids = {
|
||||
item.get("provider_id") for item in raw_affinities if item.get("provider_id")
|
||||
}
|
||||
endpoint_ids = {
|
||||
item.get("endpoint_id") for item in raw_affinities if item.get("endpoint_id")
|
||||
}
|
||||
key_ids = {item.get("key_id") for item in raw_affinities if item.get("key_id")}
|
||||
|
||||
# 批量查询Provider、Endpoint、Key信息
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
|
||||
provider_map = {}
|
||||
if provider_ids:
|
||||
providers = db.query(Provider).filter(Provider.id.in_(list(provider_ids))).all()
|
||||
provider_map = {p.id: p for p in providers}
|
||||
|
||||
endpoint_map = {}
|
||||
if endpoint_ids:
|
||||
endpoints = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(list(endpoint_ids))).all()
|
||||
)
|
||||
endpoint_map = {e.id: e for e in endpoints}
|
||||
|
||||
key_map = {}
|
||||
if key_ids:
|
||||
keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(list(key_ids))).all()
|
||||
key_map = {k.id: k for k in keys}
|
||||
|
||||
# 收集所有 model_name(实际存储的是 global_model_id)并批量查询 GlobalModel
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
global_model_ids = {
|
||||
item.get("model_name") for item in raw_affinities if item.get("model_name")
|
||||
}
|
||||
global_model_map: Dict[str, GlobalModel] = {}
|
||||
if global_model_ids:
|
||||
# model_name 可能是 UUID 格式的 global_model_id,也可能是原始模型名称
|
||||
global_models = db.query(GlobalModel).filter(
|
||||
GlobalModel.id.in_(list(global_model_ids))
|
||||
).all()
|
||||
global_model_map = {str(gm.id): gm for gm in global_models}
|
||||
|
||||
keyword_lower = self.keyword.lower() if self.keyword else None
|
||||
items = []
|
||||
for affinity in raw_affinities:
|
||||
affinity_key = affinity.get("affinity_key")
|
||||
if not affinity_key:
|
||||
continue
|
||||
|
||||
# 通过 affinity_key(API Key ID)找到用户 API Key 和用户
|
||||
user_api_key = user_api_key_map.get(affinity_key)
|
||||
user = user_map.get(str(user_api_key.user_id)) if user_api_key else None
|
||||
user_id = str(user_api_key.user_id) if user_api_key else None
|
||||
|
||||
provider_id = affinity.get("provider_id")
|
||||
endpoint_id = affinity.get("endpoint_id")
|
||||
key_id = affinity.get("key_id")
|
||||
|
||||
provider = provider_map.get(provider_id)
|
||||
endpoint = endpoint_map.get(endpoint_id)
|
||||
key = key_map.get(key_id)
|
||||
|
||||
# 用户 API Key 脱敏显示(解密 key_encrypted 后脱敏)
|
||||
user_api_key_masked = None
|
||||
if user_api_key and user_api_key.key_encrypted:
|
||||
user_api_key_masked = decrypt_and_mask(user_api_key.key_encrypted)
|
||||
|
||||
# Provider Key 脱敏显示(解密 api_key 后脱敏)
|
||||
provider_key_masked = None
|
||||
if key and key.api_key:
|
||||
provider_key_masked = decrypt_and_mask(key.api_key)
|
||||
|
||||
item = {
|
||||
"affinity_key": affinity_key,
|
||||
"user_api_key_name": user_api_key.name if user_api_key else None,
|
||||
"user_api_key_prefix": user_api_key_masked,
|
||||
"is_standalone": user_api_key.is_standalone if user_api_key else False,
|
||||
"user_id": user_id,
|
||||
"username": user.username if user else None,
|
||||
"email": user.email if user else None,
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.display_name if provider else None,
|
||||
"endpoint_id": endpoint_id,
|
||||
"endpoint_api_format": (
|
||||
endpoint.api_format if endpoint and endpoint.api_format else None
|
||||
),
|
||||
"endpoint_url": endpoint.base_url if endpoint else None,
|
||||
"key_id": key_id,
|
||||
"key_name": key.name if key else None,
|
||||
"key_prefix": provider_key_masked,
|
||||
"rate_multiplier": key.rate_multiplier if key else 1.0,
|
||||
"model_name": (
|
||||
global_model_map.get(affinity.get("model_name")).name
|
||||
if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
|
||||
else affinity.get("model_name") # 如果找不到 GlobalModel,显示原始值
|
||||
),
|
||||
"model_display_name": (
|
||||
global_model_map.get(affinity.get("model_name")).display_name
|
||||
if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
|
||||
else None
|
||||
),
|
||||
"api_format": affinity.get("api_format"),
|
||||
"created_at": affinity.get("created_at"),
|
||||
"expire_at": affinity.get("expire_at"),
|
||||
"request_count": affinity.get("request_count", 0),
|
||||
}
|
||||
|
||||
if keyword_lower and not matched_user_id and not matched_api_key_id:
|
||||
searchable = [
|
||||
item["affinity_key"],
|
||||
item["user_api_key_name"] or "",
|
||||
item["user_id"] or "",
|
||||
item["username"] or "",
|
||||
item["email"] or "",
|
||||
item["provider_id"] or "",
|
||||
item["key_id"] or "",
|
||||
]
|
||||
if not any(keyword_lower in str(value).lower() for value in searchable if value):
|
||||
continue
|
||||
|
||||
items.append(item)
|
||||
|
||||
items.sort(key=lambda x: x.get("expire_at") or 0, reverse=True)
|
||||
paged_items, meta = paginate_sequence(items, self.limit, self.offset)
|
||||
payload = build_pagination_payload(
|
||||
paged_items,
|
||||
meta,
|
||||
matched_user_id=matched_user_id,
|
||||
)
|
||||
response = {
|
||||
"status": "ok",
|
||||
"data": payload,
|
||||
}
|
||||
result_count = meta.count if hasattr(meta, "count") else len(paged_items)
|
||||
context.add_audit_metadata(
|
||||
action="cache_affinity_list",
|
||||
keyword=self.keyword,
|
||||
matched_user_id=matched_user_id,
|
||||
matched_api_key_id=matched_api_key_id,
|
||||
limit=self.limit,
|
||||
offset=self.offset,
|
||||
result_count=result_count,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminClearUserCacheAdapter(AdminApiAdapter):
|
||||
user_identifier: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
affinity_mgr = await get_affinity_manager(redis_client)
|
||||
|
||||
# 首先检查是否直接是 API Key ID (affinity_key)
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == self.user_identifier).first()
|
||||
if api_key:
|
||||
# 直接按 affinity_key 清除
|
||||
affinity_key = str(api_key.id)
|
||||
user = db.query(User).filter(User.id == api_key.user_id).first()
|
||||
|
||||
all_affinities = await affinity_mgr.list_affinities()
|
||||
target_affinities = [
|
||||
aff for aff in all_affinities if aff.get("affinity_key") == affinity_key
|
||||
]
|
||||
|
||||
count = 0
|
||||
for aff in target_affinities:
|
||||
api_format = aff.get("api_format")
|
||||
model_name = aff.get("model_name")
|
||||
endpoint_id = aff.get("endpoint_id")
|
||||
if api_format and model_name:
|
||||
await affinity_mgr.invalidate_affinity(
|
||||
affinity_key, api_format, model_name, endpoint_id=endpoint_id
|
||||
)
|
||||
count += 1
|
||||
|
||||
logger.info(f"已清除API Key缓存亲和性: api_key_name={api_key.name}, affinity_key={affinity_key[:8]}..., 清除数量={count}")
|
||||
|
||||
response = {
|
||||
"status": "ok",
|
||||
"message": f"已清除 API Key {api_key.name} 的缓存亲和性",
|
||||
"user_info": {
|
||||
"user_id": str(api_key.user_id),
|
||||
"username": user.username if user else None,
|
||||
"email": user.email if user else None,
|
||||
"api_key_id": affinity_key,
|
||||
"api_key_name": api_key.name,
|
||||
},
|
||||
}
|
||||
context.add_audit_metadata(
|
||||
action="cache_clear_api_key",
|
||||
user_identifier=self.user_identifier,
|
||||
resolved_api_key_id=affinity_key,
|
||||
cleared_count=count,
|
||||
)
|
||||
return response
|
||||
|
||||
# 如果不是 API Key ID,尝试解析为用户标识
|
||||
user_id = resolve_user_identifier(db, self.user_identifier)
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"无法识别的标识符: {self.user_identifier}。支持用户名、邮箱、User ID或API Key ID",
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
# 获取该用户所有的 API Key
|
||||
user_api_keys = db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
user_api_key_ids = {str(k.id) for k in user_api_keys}
|
||||
|
||||
# 获取该用户所有 API Key 的缓存亲和性并逐个失效
|
||||
all_affinities = await affinity_mgr.list_affinities()
|
||||
user_affinities = [
|
||||
aff for aff in all_affinities if aff.get("affinity_key") in user_api_key_ids
|
||||
]
|
||||
|
||||
count = 0
|
||||
for aff in user_affinities:
|
||||
affinity_key = aff.get("affinity_key")
|
||||
api_format = aff.get("api_format")
|
||||
model_name = aff.get("model_name")
|
||||
endpoint_id = aff.get("endpoint_id")
|
||||
if affinity_key and api_format and model_name:
|
||||
await affinity_mgr.invalidate_affinity(
|
||||
affinity_key, api_format, model_name, endpoint_id=endpoint_id
|
||||
)
|
||||
count += 1
|
||||
|
||||
logger.info(f"已清除用户缓存亲和性: username={user.username}, user_id={user_id[:8]}..., 清除数量={count}")
|
||||
|
||||
response = {
|
||||
"status": "ok",
|
||||
"message": f"已清除用户 {user.username} 的所有缓存亲和性",
|
||||
"user_info": {"user_id": user_id, "username": user.username, "email": user.email},
|
||||
}
|
||||
context.add_audit_metadata(
|
||||
action="cache_clear_user",
|
||||
user_identifier=self.user_identifier,
|
||||
resolved_user_id=user_id,
|
||||
cleared_count=count,
|
||||
)
|
||||
return response
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception(f"清除用户缓存亲和性失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
class AdminClearAllCacheAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
affinity_mgr = await get_affinity_manager(redis_client)
|
||||
count = await affinity_mgr.clear_all()
|
||||
logger.warning(f"已清除所有缓存亲和性(管理员操作): {count} 个")
|
||||
context.add_audit_metadata(
|
||||
action="cache_clear_all",
|
||||
cleared_count=count,
|
||||
)
|
||||
return {"status": "ok", "message": "已清除所有缓存亲和性", "count": count}
|
||||
except Exception as exc:
|
||||
logger.exception(f"清除所有缓存亲和性失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminClearProviderCacheAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
affinity_mgr = await get_affinity_manager(redis_client)
|
||||
count = await affinity_mgr.invalidate_all_for_provider(self.provider_id)
|
||||
logger.info(f"已清除Provider缓存亲和性: provider_id={self.provider_id[:8]}..., count={count}")
|
||||
context.add_audit_metadata(
|
||||
action="cache_clear_provider",
|
||||
provider_id=self.provider_id,
|
||||
cleared_count=count,
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": "已清除Provider的缓存亲和性",
|
||||
"provider_id": self.provider_id,
|
||||
"count": count,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.exception(f"清除Provider缓存亲和性失败: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.services.cache.affinity_manager import CacheAffinityManager
|
||||
from src.services.cache.aware_scheduler import CacheAwareScheduler
|
||||
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
||||
|
||||
# 获取动态预留管理器的配置
|
||||
reservation_manager = get_adaptive_reservation_manager()
|
||||
reservation_stats = reservation_manager.get_stats()
|
||||
|
||||
response = {
|
||||
"status": "ok",
|
||||
"data": {
|
||||
"cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
|
||||
"cache_reservation_ratio": CacheAwareScheduler.CACHE_RESERVATION_RATIO,
|
||||
"dynamic_reservation": {
|
||||
"enabled": True,
|
||||
"config": reservation_stats["config"],
|
||||
"description": {
|
||||
"probe_phase_requests": "探测阶段请求数阈值",
|
||||
"probe_reservation": "探测阶段预留比例",
|
||||
"stable_min_reservation": "稳定阶段最小预留比例",
|
||||
"stable_max_reservation": "稳定阶段最大预留比例",
|
||||
"low_load_threshold": "低负载阈值(低于此值使用最小预留)",
|
||||
"high_load_threshold": "高负载阈值(高于此值根据置信度使用较高预留)",
|
||||
},
|
||||
},
|
||||
"description": {
|
||||
"cache_ttl": "缓存亲和性有效期(秒)",
|
||||
"cache_reservation_ratio": "静态预留比例(已被动态预留替代)",
|
||||
"dynamic_reservation": "动态预留机制配置",
|
||||
},
|
||||
},
|
||||
}
|
||||
context.add_audit_metadata(
|
||||
action="cache_config",
|
||||
cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
|
||||
cache_reservation_ratio=CacheAwareScheduler.CACHE_RESERVATION_RATIO,
|
||||
dynamic_reservation_enabled=True,
|
||||
)
|
||||
return response
|
||||
280
src/api/admin/monitoring/trace.py
Normal file
280
src/api/admin/monitoring/trace.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
请求链路追踪 API 端点
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderEndpoint, ProviderAPIKey
|
||||
from src.services.request.candidate import RequestCandidateService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/monitoring/trace", tags=["Admin - Monitoring: Trace"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
class CandidateResponse(BaseModel):
|
||||
"""候选记录响应"""
|
||||
|
||||
id: str
|
||||
request_id: str
|
||||
candidate_index: int
|
||||
retry_index: int = 0 # 重试序号(从0开始)
|
||||
provider_id: Optional[str] = None
|
||||
provider_name: Optional[str] = None
|
||||
provider_website: Optional[str] = None # Provider 官网
|
||||
endpoint_id: Optional[str] = None
|
||||
endpoint_name: Optional[str] = None # 端点显示名称(api_format)
|
||||
key_id: Optional[str] = None
|
||||
key_name: Optional[str] = None # 密钥名称
|
||||
key_preview: Optional[str] = None # 密钥脱敏预览(如 sk-***abc)
|
||||
key_capabilities: Optional[dict] = None # Key 支持的能力
|
||||
required_capabilities: Optional[dict] = None # 请求实际需要的能力标签
|
||||
status: str # 'pending', 'success', 'failed', 'skipped'
|
||||
skip_reason: Optional[str] = None
|
||||
is_cached: bool = False
|
||||
# 执行结果字段
|
||||
status_code: Optional[int] = None
|
||||
error_type: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
latency_ms: Optional[int] = None
|
||||
concurrent_requests: Optional[int] = None
|
||||
extra_data: Optional[dict] = None
|
||||
created_at: datetime
|
||||
started_at: Optional[datetime] = None
|
||||
finished_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class RequestTraceResponse(BaseModel):
|
||||
"""请求追踪完整响应"""
|
||||
|
||||
request_id: str
|
||||
total_candidates: int
|
||||
final_status: str # 'success', 'failed', 'streaming', 'pending'
|
||||
total_latency_ms: int
|
||||
candidates: List[CandidateResponse]
|
||||
|
||||
|
||||
@router.get("/{request_id}", response_model=RequestTraceResponse)
|
||||
async def get_request_trace(
|
||||
request_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取特定请求的完整追踪信息"""
|
||||
|
||||
adapter = AdminGetRequestTraceAdapter(request_id=request_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/stats/provider/{provider_id}")
|
||||
async def get_provider_failure_rate(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(100, ge=1, le=1000, description="统计最近的尝试数量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取某个 Provider 的失败率统计
|
||||
|
||||
需要管理员权限
|
||||
"""
|
||||
adapter = AdminProviderFailureRateAdapter(provider_id=provider_id, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- 请求追踪适配器 --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetRequestTraceAdapter(AdminApiAdapter):
|
||||
request_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 只查询 candidates
|
||||
candidates = RequestCandidateService.get_candidates_by_request_id(db, self.request_id)
|
||||
|
||||
# 如果没有数据,返回 404
|
||||
if not candidates:
|
||||
raise HTTPException(status_code=404, detail="Request not found")
|
||||
|
||||
# 计算总延迟(只统计已完成的候选:success 或 failed)
|
||||
# 使用显式的 is not None 检查,避免过滤掉 0ms 的快速响应
|
||||
total_latency = sum(
|
||||
c.latency_ms
|
||||
for c in candidates
|
||||
if c.status in ("success", "failed") and c.latency_ms is not None
|
||||
)
|
||||
|
||||
# 判断最终状态:
|
||||
# 1. status="success" 即视为成功(无论 status_code 是什么)
|
||||
# - 流式请求即使客户端断开(499),只要 Provider 成功返回数据,也算成功
|
||||
# 2. 同时检查 status_code 在 200-299 范围,作为额外的成功判断条件
|
||||
# - 用于兼容非流式请求或未正确设置 status 的旧数据
|
||||
# 3. status="streaming" 表示流式请求正在进行中
|
||||
# 4. status="pending" 表示请求尚未开始执行
|
||||
has_success = any(
|
||||
c.status == "success"
|
||||
or (c.status_code is not None and 200 <= c.status_code < 300)
|
||||
for c in candidates
|
||||
)
|
||||
has_streaming = any(c.status == "streaming" for c in candidates)
|
||||
has_pending = any(c.status == "pending" for c in candidates)
|
||||
|
||||
if has_success:
|
||||
final_status = "success"
|
||||
elif has_streaming:
|
||||
# 有候选正在流式传输中
|
||||
final_status = "streaming"
|
||||
elif has_pending:
|
||||
# 有候选正在等待执行
|
||||
final_status = "pending"
|
||||
else:
|
||||
final_status = "failed"
|
||||
|
||||
# 批量加载 provider 信息,避免 N+1 查询
|
||||
provider_ids = {c.provider_id for c in candidates if c.provider_id}
|
||||
provider_map = {}
|
||||
provider_website_map = {}
|
||||
if provider_ids:
|
||||
providers = db.query(Provider).filter(Provider.id.in_(provider_ids)).all()
|
||||
for p in providers:
|
||||
provider_map[p.id] = p.name
|
||||
provider_website_map[p.id] = p.website
|
||||
|
||||
# 批量加载 endpoint 信息
|
||||
endpoint_ids = {c.endpoint_id for c in candidates if c.endpoint_id}
|
||||
endpoint_map = {}
|
||||
if endpoint_ids:
|
||||
endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(endpoint_ids)).all()
|
||||
endpoint_map = {e.id: e.api_format for e in endpoints}
|
||||
|
||||
# 批量加载 key 信息
|
||||
key_ids = {c.key_id for c in candidates if c.key_id}
|
||||
key_map = {}
|
||||
key_preview_map = {}
|
||||
key_capabilities_map = {}
|
||||
if key_ids:
|
||||
keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(key_ids)).all()
|
||||
for k in keys:
|
||||
key_map[k.id] = k.name
|
||||
key_capabilities_map[k.id] = k.capabilities
|
||||
# 生成脱敏预览:保留前缀和最后4位
|
||||
api_key = k.api_key or ""
|
||||
if len(api_key) > 8:
|
||||
# 检测常见前缀模式
|
||||
prefix_end = 0
|
||||
for prefix in ["sk-", "key-", "api-", "ak-"]:
|
||||
if api_key.lower().startswith(prefix):
|
||||
prefix_end = len(prefix)
|
||||
break
|
||||
if prefix_end > 0:
|
||||
key_preview_map[k.id] = f"{api_key[:prefix_end]}***{api_key[-4:]}"
|
||||
else:
|
||||
key_preview_map[k.id] = f"{api_key[:3]}***{api_key[-4:]}"
|
||||
elif len(api_key) > 4:
|
||||
key_preview_map[k.id] = f"***{api_key[-4:]}"
|
||||
else:
|
||||
key_preview_map[k.id] = "***"
|
||||
|
||||
# 构建 candidate 响应列表
|
||||
candidate_responses: List[CandidateResponse] = []
|
||||
for candidate in candidates:
|
||||
provider_name = (
|
||||
provider_map.get(candidate.provider_id) if candidate.provider_id else None
|
||||
)
|
||||
provider_website = (
|
||||
provider_website_map.get(candidate.provider_id) if candidate.provider_id else None
|
||||
)
|
||||
endpoint_name = (
|
||||
endpoint_map.get(candidate.endpoint_id) if candidate.endpoint_id else None
|
||||
)
|
||||
key_name = (
|
||||
key_map.get(candidate.key_id) if candidate.key_id else None
|
||||
)
|
||||
key_preview = (
|
||||
key_preview_map.get(candidate.key_id) if candidate.key_id else None
|
||||
)
|
||||
key_capabilities = (
|
||||
key_capabilities_map.get(candidate.key_id) if candidate.key_id else None
|
||||
)
|
||||
|
||||
candidate_responses.append(
|
||||
CandidateResponse(
|
||||
id=candidate.id,
|
||||
request_id=candidate.request_id,
|
||||
candidate_index=candidate.candidate_index,
|
||||
retry_index=candidate.retry_index,
|
||||
provider_id=candidate.provider_id,
|
||||
provider_name=provider_name,
|
||||
provider_website=provider_website,
|
||||
endpoint_id=candidate.endpoint_id,
|
||||
endpoint_name=endpoint_name,
|
||||
key_id=candidate.key_id,
|
||||
key_name=key_name,
|
||||
key_preview=key_preview,
|
||||
key_capabilities=key_capabilities,
|
||||
required_capabilities=candidate.required_capabilities,
|
||||
status=candidate.status,
|
||||
skip_reason=candidate.skip_reason,
|
||||
is_cached=candidate.is_cached,
|
||||
status_code=candidate.status_code,
|
||||
error_type=candidate.error_type,
|
||||
error_message=candidate.error_message,
|
||||
latency_ms=candidate.latency_ms,
|
||||
concurrent_requests=candidate.concurrent_requests,
|
||||
extra_data=candidate.extra_data,
|
||||
created_at=candidate.created_at,
|
||||
started_at=candidate.started_at,
|
||||
finished_at=candidate.finished_at,
|
||||
)
|
||||
)
|
||||
|
||||
response = RequestTraceResponse(
|
||||
request_id=self.request_id,
|
||||
total_candidates=len(candidates),
|
||||
final_status=final_status,
|
||||
total_latency_ms=total_latency,
|
||||
candidates=candidate_responses,
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="trace_request_detail",
|
||||
request_id=self.request_id,
|
||||
total_candidates=len(candidates),
|
||||
final_status=final_status,
|
||||
total_latency_ms=total_latency,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminProviderFailureRateAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
result = RequestCandidateService.get_candidate_stats_by_provider(
|
||||
db=context.db,
|
||||
provider_id=self.provider_id,
|
||||
limit=self.limit,
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="trace_provider_failure_rate",
|
||||
provider_id=self.provider_id,
|
||||
limit=self.limit,
|
||||
total_attempts=result.get("total_attempts"),
|
||||
failure_rate=result.get("failure_rate"),
|
||||
)
|
||||
return result
|
||||
410
src/api/admin/provider_query.py
Normal file
410
src/api/admin/provider_query.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Provider Query API 端点
|
||||
用于查询提供商的余额、使用记录等信息
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
||||
|
||||
# 初始化适配器注册
|
||||
from src.plugins.provider_query import init # noqa
|
||||
from src.plugins.provider_query import get_query_registry
|
||||
from src.plugins.provider_query.base import QueryCapability
|
||||
from src.utils.auth_utils import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
|
||||
|
||||
|
||||
# ============ Request/Response Models ============
|
||||
|
||||
|
||||
class BalanceQueryRequest(BaseModel):
|
||||
"""余额查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
|
||||
|
||||
|
||||
class UsageSummaryQueryRequest(BaseModel):
|
||||
"""使用汇总查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None
|
||||
period: str = "month" # day, week, month, year
|
||||
|
||||
|
||||
class ModelsQueryRequest(BaseModel):
|
||||
"""模型列表查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None
|
||||
|
||||
|
||||
# ============ API Endpoints ============
|
||||
|
||||
|
||||
@router.get("/adapters")
|
||||
async def list_adapters(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取所有可用的查询适配器
|
||||
|
||||
Returns:
|
||||
适配器列表
|
||||
"""
|
||||
registry = get_query_registry()
|
||||
adapters = registry.list_adapters()
|
||||
|
||||
return {"success": True, "data": adapters}
|
||||
|
||||
|
||||
@router.get("/capabilities/{provider_id}")
|
||||
async def get_provider_capabilities(
|
||||
provider_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取提供商支持的查询能力
|
||||
|
||||
Args:
|
||||
provider_id: 提供商 ID
|
||||
|
||||
Returns:
|
||||
支持的查询能力列表
|
||||
"""
|
||||
# 获取提供商
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
registry = get_query_registry()
|
||||
capabilities = registry.get_capabilities_for_provider(provider.name)
|
||||
|
||||
if capabilities is None:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"capabilities": [],
|
||||
"has_adapter": False,
|
||||
"message": "No query adapter available for this provider",
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"capabilities": [c.name for c in capabilities],
|
||||
"has_adapter": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/balance")
|
||||
async def query_balance(
|
||||
request: BalanceQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商余额
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
Returns:
|
||||
余额信息
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
|
||||
if request.api_key_id:
|
||||
# 查找指定的 API Key
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
# 使用第一个可用的 API Key
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询余额
|
||||
registry = get_query_registry()
|
||||
query_result = await registry.query_provider_balance(
|
||||
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
|
||||
)
|
||||
|
||||
if not query_result.success:
|
||||
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/usage-summary")
|
||||
async def query_usage_summary(
|
||||
request: UsageSummaryQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商使用汇总
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
Returns:
|
||||
使用汇总信息
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key(逻辑同上)
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
|
||||
if request.api_key_id:
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询使用汇总
|
||||
registry = get_query_registry()
|
||||
query_result = await registry.query_provider_usage(
|
||||
provider_type=provider.name,
|
||||
api_key=api_key_value,
|
||||
period=request.period,
|
||||
endpoint_config=endpoint_config,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/models")
|
||||
async def query_available_models(
|
||||
request: ModelsQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商可用模型
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
|
||||
if request.api_key_id:
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询模型
|
||||
registry = get_query_registry()
|
||||
adapter = registry.get_adapter_for_provider(provider.name)
|
||||
|
||||
if not adapter:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
|
||||
)
|
||||
|
||||
query_result = await adapter.query_available_models(
|
||||
api_key=api_key_value, endpoint_config=endpoint_config
|
||||
)
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/cache/{provider_id}")
|
||||
async def clear_query_cache(
|
||||
provider_id: str,
|
||||
api_key_id: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
清除查询缓存
|
||||
|
||||
Args:
|
||||
provider_id: 提供商 ID
|
||||
api_key_id: 可选,指定清除某个 API Key 的缓存
|
||||
|
||||
Returns:
|
||||
清除结果
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
# 获取提供商
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
registry = get_query_registry()
|
||||
adapter = registry.get_adapter_for_provider(provider.name)
|
||||
|
||||
if adapter:
|
||||
if api_key_id:
|
||||
# 获取 API Key 值来清除缓存
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
|
||||
api_key = result.scalar_one_or_none()
|
||||
if api_key:
|
||||
adapter.clear_cache(api_key.api_key)
|
||||
else:
|
||||
adapter.clear_cache()
|
||||
|
||||
return {"success": True, "message": "Cache cleared successfully"}
|
||||
272
src/api/admin/provider_strategy.py
Normal file
272
src/api/admin/provider_strategy.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
提供商策略管理 API 端点
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider
|
||||
from src.models.database_extensions import ProviderUsageTracking
|
||||
|
||||
router = APIRouter(prefix="/api/admin/provider-strategy", tags=["Provider Strategy"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
class ProviderBillingUpdate(BaseModel):
|
||||
billing_type: ProviderBillingType
|
||||
monthly_quota_usd: Optional[float] = None
|
||||
quota_reset_day: int = Field(default=30, ge=1, le=365) # 重置周期(天数)
|
||||
quota_last_reset_at: Optional[str] = None # 当前周期开始时间
|
||||
quota_expires_at: Optional[str] = None
|
||||
rpm_limit: Optional[int] = Field(default=None, ge=0)
|
||||
provider_priority: int = Field(default=100, ge=0, le=200)
|
||||
|
||||
|
||||
@router.put("/providers/{provider_id}/billing")
|
||||
async def update_provider_billing(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminProviderBillingAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/stats")
|
||||
async def get_provider_stats(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
hours: int = 24,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id}/quota")
|
||||
async def reset_provider_quota(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Reset provider quota usage to zero"""
|
||||
adapter = AdminProviderResetQuotaAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/strategies")
|
||||
async def list_available_strategies(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminListStrategiesAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AdminProviderBillingAdapter(AdminApiAdapter):
|
||||
def __init__(self, provider_id: str):
|
||||
self.provider_id = provider_id
|
||||
|
||||
async def handle(self, context):
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
config = ProviderBillingUpdate.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
provider.billing_type = config.billing_type
|
||||
provider.monthly_quota_usd = config.monthly_quota_usd
|
||||
provider.quota_reset_day = config.quota_reset_day
|
||||
provider.rpm_limit = config.rpm_limit
|
||||
provider.provider_priority = config.provider_priority
|
||||
|
||||
from dateutil import parser
|
||||
from sqlalchemy import func
|
||||
|
||||
from src.models.database import Usage
|
||||
|
||||
if config.quota_last_reset_at:
|
||||
new_reset_at = parser.parse(config.quota_last_reset_at)
|
||||
provider.quota_last_reset_at = new_reset_at
|
||||
|
||||
# 自动同步该周期内的历史使用量
|
||||
period_usage = (
|
||||
db.query(func.coalesce(func.sum(Usage.total_cost_usd), 0))
|
||||
.filter(
|
||||
Usage.provider_id == self.provider_id,
|
||||
Usage.created_at >= new_reset_at,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
provider.monthly_used_usd = float(period_usage or 0)
|
||||
logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
|
||||
|
||||
if config.quota_expires_at:
|
||||
provider.quota_expires_at = parser.parse(config.quota_expires_at)
|
||||
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
|
||||
logger.info(f"Updated billing config for provider {provider.name}")
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"message": "Provider billing config updated successfully",
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"billing_type": provider.billing_type.value,
|
||||
"provider_priority": provider.provider_priority,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AdminProviderStatsAdapter(AdminApiAdapter):
|
||||
def __init__(self, provider_id: str, hours: int):
|
||||
self.provider_id = provider_id
|
||||
self.hours = hours
|
||||
|
||||
async def handle(self, context):
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
since = datetime.now() - timedelta(hours=self.hours)
|
||||
stats = (
|
||||
db.query(ProviderUsageTracking)
|
||||
.filter(
|
||||
ProviderUsageTracking.provider_id == self.provider_id,
|
||||
ProviderUsageTracking.window_start >= since,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
total_requests = sum(s.total_requests for s in stats)
|
||||
total_success = sum(s.successful_requests for s in stats)
|
||||
total_failures = sum(s.failed_requests for s in stats)
|
||||
avg_response_time = sum(s.avg_response_time_ms for s in stats) / len(stats) if stats else 0
|
||||
total_cost = sum(s.total_cost_usd for s in stats)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"provider_id": self.provider_id,
|
||||
"provider_name": provider.name,
|
||||
"period_hours": self.hours,
|
||||
"billing_info": {
|
||||
"billing_type": provider.billing_type.value,
|
||||
"monthly_quota_usd": provider.monthly_quota_usd,
|
||||
"monthly_used_usd": provider.monthly_used_usd,
|
||||
"quota_remaining_usd": (
|
||||
provider.monthly_quota_usd - provider.monthly_used_usd
|
||||
if provider.monthly_quota_usd is not None
|
||||
else None
|
||||
),
|
||||
"quota_expires_at": (
|
||||
provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
|
||||
),
|
||||
},
|
||||
"rpm_info": {
|
||||
"rpm_limit": provider.rpm_limit,
|
||||
"rpm_used": provider.rpm_used,
|
||||
"rpm_reset_at": (
|
||||
provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
|
||||
),
|
||||
},
|
||||
"usage_stats": {
|
||||
"total_requests": total_requests,
|
||||
"successful_requests": total_success,
|
||||
"failed_requests": total_failures,
|
||||
"success_rate": total_success / total_requests if total_requests > 0 else 0,
|
||||
"avg_response_time_ms": round(avg_response_time, 2),
|
||||
"total_cost_usd": round(total_cost, 4),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AdminProviderResetQuotaAdapter(AdminApiAdapter):
|
||||
def __init__(self, provider_id: str):
|
||||
self.provider_id = provider_id
|
||||
|
||||
async def handle(self, context):
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
if provider.billing_type != ProviderBillingType.MONTHLY_QUOTA:
|
||||
raise HTTPException(status_code=400, detail="Only monthly quota providers can be reset")
|
||||
|
||||
old_used = provider.monthly_used_usd
|
||||
provider.monthly_used_usd = 0.0
|
||||
provider.rpm_used = 0
|
||||
provider.rpm_reset_at = None
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Manually reset quota for provider {provider.name}")
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"message": "Provider quota reset successfully",
|
||||
"provider_name": provider.name,
|
||||
"previous_used": old_used,
|
||||
"current_used": 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AdminListStrategiesAdapter(AdminApiAdapter):
|
||||
async def handle(self, context):
|
||||
from src.plugins.manager import get_plugin_manager
|
||||
|
||||
plugin_manager = get_plugin_manager()
|
||||
lb_plugins = plugin_manager.plugins.get("load_balancer", {})
|
||||
|
||||
strategies = []
|
||||
for name, plugin in lb_plugins.items():
|
||||
try:
|
||||
strategies.append(
|
||||
{
|
||||
"name": getattr(plugin, "name", name),
|
||||
"priority": getattr(plugin, "priority", 0),
|
||||
"version": (
|
||||
getattr(plugin.metadata, "version", "1.0.0")
|
||||
if hasattr(plugin, "metadata")
|
||||
else "1.0.0"
|
||||
),
|
||||
"description": (
|
||||
getattr(plugin.metadata, "description", "")
|
||||
if hasattr(plugin, "metadata")
|
||||
else ""
|
||||
),
|
||||
"author": (
|
||||
getattr(plugin.metadata, "author", "Unknown")
|
||||
if hasattr(plugin, "metadata")
|
||||
else "Unknown"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.error(f"Error accessing plugin {name}: {exc}")
|
||||
continue
|
||||
|
||||
strategies.sort(key=lambda x: x["priority"], reverse=True)
|
||||
return JSONResponse({"strategies": strategies, "total": len(strategies)})
|
||||
20
src/api/admin/providers/__init__.py
Normal file
20
src/api/admin/providers/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Provider admin routes export."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .models import router as models_router
|
||||
from .routes import router as routes_router
|
||||
from .summary import router as summary_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/providers", tags=["Admin - Providers"])
|
||||
|
||||
# Provider CRUD
|
||||
router.include_router(routes_router)
|
||||
|
||||
# Provider summary & health monitor
|
||||
router.include_router(summary_router)
|
||||
|
||||
# Provider models management
|
||||
router.include_router(models_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
443
src/api/admin/providers/models.py
Normal file
443
src/api/admin/providers/models.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Provider 模型管理 API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ModelCreate,
|
||||
ModelResponse,
|
||||
ModelUpdate,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignModelsToProviderRequest,
|
||||
BatchAssignModelsToProviderResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
ProviderAvailableSourceModel,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
)
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(tags=["Model Management"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/{provider_id}/models", response_model=List[ModelResponse])
|
||||
async def list_provider_models(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
is_active: Optional[bool] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ModelResponse]:
|
||||
"""获取提供商的所有模型(管理员)"""
|
||||
adapter = AdminListProviderModelsAdapter(
|
||||
provider_id=provider_id,
|
||||
is_active=is_active,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{provider_id}/models", response_model=ModelResponse)
|
||||
async def create_provider_model(
|
||||
provider_id: str,
|
||||
model_data: ModelCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""创建模型(管理员)"""
|
||||
adapter = AdminCreateProviderModelAdapter(provider_id=provider_id, model_data=model_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/models/{model_id}", response_model=ModelResponse)
|
||||
async def get_provider_model(
|
||||
provider_id: str,
|
||||
model_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""获取模型详情(管理员)"""
|
||||
adapter = AdminGetProviderModelAdapter(provider_id=provider_id, model_id=model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{provider_id}/models/{model_id}", response_model=ModelResponse)
|
||||
async def update_provider_model(
|
||||
provider_id: str,
|
||||
model_id: str,
|
||||
model_data: ModelUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""更新模型(管理员)"""
|
||||
adapter = AdminUpdateProviderModelAdapter(
|
||||
provider_id=provider_id,
|
||||
model_id=model_id,
|
||||
model_data=model_data,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{provider_id}/models/{model_id}")
|
||||
async def delete_provider_model(
|
||||
provider_id: str,
|
||||
model_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除模型(管理员)"""
|
||||
adapter = AdminDeleteProviderModelAdapter(provider_id=provider_id, model_id=model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{provider_id}/models/batch", response_model=List[ModelResponse])
|
||||
async def batch_create_provider_models(
|
||||
provider_id: str,
|
||||
models_data: List[ModelCreate],
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ModelResponse]:
|
||||
"""批量创建模型(管理员)"""
|
||||
adapter = AdminBatchCreateModelsAdapter(provider_id=provider_id, models_data=models_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{provider_id}/available-source-models",
|
||||
response_model=ProviderAvailableSourceModelsResponse,
|
||||
)
|
||||
async def get_provider_available_source_models(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取该 Provider 支持的所有统一模型名(source_model)
|
||||
|
||||
包括:
|
||||
1. 通过 ModelMapping 映射的模型
|
||||
2. 直连模型(Model.provider_model_name 直接作为统一模型名)
|
||||
"""
|
||||
adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider_id}/assign-global-models",
|
||||
response_model=BatchAssignModelsToProviderResponse,
|
||||
)
|
||||
async def batch_assign_global_models_to_provider(
|
||||
provider_id: str,
|
||||
payload: BatchAssignModelsToProviderRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignModelsToProviderResponse:
|
||||
"""批量为 Provider 关联 GlobalModels(自动继承价格和能力配置)"""
|
||||
adapter = AdminBatchAssignModelsToProviderAdapter(
|
||||
provider_id=provider_id, payload=payload
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListProviderModelsAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
is_active: Optional[bool]
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found", "provider")
|
||||
|
||||
models = ModelService.get_models_by_provider(
|
||||
db, self.provider_id, self.skip, self.limit, self.is_active
|
||||
)
|
||||
return [ModelService.convert_to_response(model) for model in models]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminCreateProviderModelAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
model_data: ModelCreate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found", "provider")
|
||||
|
||||
try:
|
||||
model = ModelService.create_model(db, self.provider_id, self.model_data)
|
||||
logger.info(f"Model created: {model.provider_model_name} for provider {provider.name} by {context.user.username}")
|
||||
return ModelService.convert_to_response(model)
|
||||
except Exception as exc:
|
||||
raise InvalidRequestException(str(exc))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetProviderModelAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
model_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
model = (
|
||||
db.query(Model)
|
||||
.filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
raise NotFoundException("Model not found", "model")
|
||||
|
||||
return ModelService.convert_to_response(model)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateProviderModelAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
model_id: str
|
||||
model_data: ModelUpdate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
model = (
|
||||
db.query(Model)
|
||||
.filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
raise NotFoundException("Model not found", "model")
|
||||
|
||||
try:
|
||||
updated_model = ModelService.update_model(db, self.model_id, self.model_data)
|
||||
logger.info(f"Model updated: {updated_model.provider_model_name} by {context.user.username}")
|
||||
return ModelService.convert_to_response(updated_model)
|
||||
except Exception as exc:
|
||||
raise InvalidRequestException(str(exc))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteProviderModelAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
model_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
model = (
|
||||
db.query(Model)
|
||||
.filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
raise NotFoundException("Model not found", "model")
|
||||
|
||||
model_name = model.provider_model_name
|
||||
try:
|
||||
ModelService.delete_model(db, self.model_id)
|
||||
logger.info(f"Model deleted: {model_name} by {context.user.username}")
|
||||
return {"message": f"Model '{model_name}' deleted successfully"}
|
||||
except Exception as exc:
|
||||
raise InvalidRequestException(str(exc))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchCreateModelsAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
models_data: List[ModelCreate]
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found", "provider")
|
||||
|
||||
try:
|
||||
models = ModelService.batch_create_models(db, self.provider_id, self.models_data)
|
||||
logger.info(f"Batch created {len(models)} models for provider {provider.name} by {context.user.username}")
|
||||
return [ModelService.convert_to_response(model) for model in models]
|
||||
except Exception as exc:
|
||||
raise InvalidRequestException(str(exc))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""
|
||||
返回 Provider 支持的所有 GlobalModel
|
||||
|
||||
方案 A 逻辑:
|
||||
1. 查询该 Provider 的所有 Model
|
||||
2. 通过 Model.global_model_id 获取 GlobalModel
|
||||
3. 查询所有指向该 GlobalModel 的别名(ModelMapping.alias)
|
||||
"""
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found", "provider")
|
||||
|
||||
# 1. 查询该 Provider 的所有活跃 Model(预加载 GlobalModel)
|
||||
models = (
|
||||
db.query(Model)
|
||||
.options(joinedload(Model.global_model))
|
||||
.filter(Model.provider_id == self.provider_id, Model.is_active == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 2. 构建以 GlobalModel 为主键的字典
|
||||
global_models_dict: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for model in models:
|
||||
global_model = model.global_model
|
||||
if not global_model or not global_model.is_active:
|
||||
continue
|
||||
|
||||
global_model_name = global_model.name
|
||||
|
||||
# 如果该 GlobalModel 还未处理,初始化
|
||||
if global_model_name not in global_models_dict:
|
||||
# 查询指向该 GlobalModel 的所有别名/映射
|
||||
alias_rows = (
|
||||
db.query(ModelMapping.source_model)
|
||||
.filter(
|
||||
ModelMapping.target_global_model_id == global_model.id,
|
||||
ModelMapping.is_active == True,
|
||||
or_(
|
||||
ModelMapping.provider_id == self.provider_id,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
alias_list = [alias[0] for alias in alias_rows]
|
||||
|
||||
global_models_dict[global_model_name] = {
|
||||
"global_model_name": global_model_name,
|
||||
"display_name": global_model.display_name,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"has_alias": len(alias_list) > 0,
|
||||
"aliases": alias_list,
|
||||
"model_id": model.id,
|
||||
"price": {
|
||||
"input_price_per_1m": model.get_effective_input_price(),
|
||||
"output_price_per_1m": model.get_effective_output_price(),
|
||||
"cache_creation_price_per_1m": model.get_effective_cache_creation_price(),
|
||||
"cache_read_price_per_1m": model.get_effective_cache_read_price(),
|
||||
"price_per_request": model.get_effective_price_per_request(),
|
||||
},
|
||||
"capabilities": {
|
||||
"supports_vision": bool(model.supports_vision),
|
||||
"supports_function_calling": bool(model.supports_function_calling),
|
||||
"supports_streaming": bool(model.supports_streaming),
|
||||
},
|
||||
"is_active": bool(model.is_active),
|
||||
}
|
||||
|
||||
models_list = [
|
||||
ProviderAvailableSourceModel(**global_models_dict[name])
|
||||
for name in sorted(global_models_dict.keys())
|
||||
]
|
||||
|
||||
return ProviderAvailableSourceModelsResponse(models=models_list, total=len(models_list))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
|
||||
"""批量为 Provider 关联 GlobalModels"""
|
||||
|
||||
provider_id: str
|
||||
payload: BatchAssignModelsToProviderRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found", "provider")
|
||||
|
||||
success = []
|
||||
errors = []
|
||||
|
||||
for global_model_id in self.payload.global_model_ids:
|
||||
try:
|
||||
global_model = (
|
||||
db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()
|
||||
)
|
||||
if not global_model:
|
||||
errors.append(
|
||||
{"global_model_id": global_model_id, "error": "GlobalModel not found"}
|
||||
)
|
||||
continue
|
||||
|
||||
# 检查是否已存在关联
|
||||
existing = (
|
||||
db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == self.provider_id,
|
||||
Model.global_model_id == global_model_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
errors.append(
|
||||
{
|
||||
"global_model_id": global_model_id,
|
||||
"global_model_name": global_model.name,
|
||||
"error": "Already associated",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 创建新的 Model 记录,继承 GlobalModel 的配置
|
||||
new_model = Model(
|
||||
provider_id=self.provider_id,
|
||||
global_model_id=global_model_id,
|
||||
provider_model_name=global_model.name,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(new_model)
|
||||
db.flush()
|
||||
|
||||
success.append(
|
||||
{
|
||||
"global_model_id": global_model_id,
|
||||
"global_model_name": global_model.name,
|
||||
"model_id": new_model.id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append({"global_model_id": global_model_id, "error": str(e)})
|
||||
|
||||
db.commit()
|
||||
logger.info(
|
||||
f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
|
||||
)
|
||||
|
||||
return BatchAssignModelsToProviderResponse(success=success, errors=errors)
|
||||
249
src/api/admin/providers/routes.py
Normal file
249
src/api/admin/providers/routes.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""管理员 Provider 管理路由。"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
|
||||
from src.models.database import Provider
|
||||
|
||||
router = APIRouter(tags=["Provider CRUD"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_providers(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
is_active: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_provider(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminCreateProviderAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{provider_id}")
|
||||
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{provider_id}")
|
||||
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AdminListProvidersAdapter(AdminApiAdapter):
|
||||
def __init__(self, skip: int, limit: int, is_active: Optional[bool]):
|
||||
self.skip = skip
|
||||
self.limit = limit
|
||||
self.is_active = is_active
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
query = db.query(Provider)
|
||||
if self.is_active is not None:
|
||||
query = query.filter(Provider.is_active == self.is_active)
|
||||
providers = query.offset(self.skip).limit(self.limit).all()
|
||||
|
||||
data = []
|
||||
for provider in providers:
|
||||
api_format = getattr(provider, "api_format", None)
|
||||
base_url = getattr(provider, "base_url", None)
|
||||
api_key = getattr(provider, "api_key", None)
|
||||
priority = getattr(provider, "priority", provider.provider_priority)
|
||||
|
||||
data.append(
|
||||
{
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"api_format": api_format.value if api_format else None,
|
||||
"base_url": base_url,
|
||||
"api_key": "***" if api_key else None,
|
||||
"priority": priority,
|
||||
"is_active": provider.is_active,
|
||||
"created_at": provider.created_at.isoformat(),
|
||||
"updated_at": provider.updated_at.isoformat() if provider.updated_at else None,
|
||||
}
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="list_providers",
|
||||
filter_is_active=self.is_active,
|
||||
limit=self.limit,
|
||||
skip=self.skip,
|
||||
result_count=len(data),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class AdminCreateProviderAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
|
||||
validated_data = CreateProviderRequest.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
# 将 Pydantic 验证错误转换为友好的错误信息
|
||||
errors = []
|
||||
for error in exc.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
errors.append(f"{field}: {error['msg']}")
|
||||
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
|
||||
|
||||
try:
|
||||
# 检查名称是否已存在
|
||||
existing = db.query(Provider).filter(Provider.name == validated_data.name).first()
|
||||
if existing:
|
||||
raise InvalidRequestException(f"提供商名称 '{validated_data.name}' 已存在")
|
||||
|
||||
# 将验证后的数据转换为枚举类型
|
||||
billing_type = (
|
||||
ProviderBillingType(validated_data.billing_type)
|
||||
if validated_data.billing_type
|
||||
else ProviderBillingType.PAY_AS_YOU_GO
|
||||
)
|
||||
|
||||
# 创建 Provider 对象
|
||||
provider = Provider(
|
||||
name=validated_data.name,
|
||||
display_name=validated_data.display_name,
|
||||
description=validated_data.description,
|
||||
website=validated_data.website,
|
||||
billing_type=billing_type,
|
||||
monthly_quota_usd=validated_data.monthly_quota_usd,
|
||||
quota_reset_day=validated_data.quota_reset_day,
|
||||
quota_last_reset_at=validated_data.quota_last_reset_at,
|
||||
quota_expires_at=validated_data.quota_expires_at,
|
||||
rpm_limit=validated_data.rpm_limit,
|
||||
provider_priority=validated_data.provider_priority,
|
||||
is_active=validated_data.is_active,
|
||||
rate_limit=validated_data.rate_limit,
|
||||
concurrent_limit=validated_data.concurrent_limit,
|
||||
config=validated_data.config,
|
||||
)
|
||||
|
||||
db.add(provider)
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="create_provider",
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.name,
|
||||
billing_type=provider.billing_type.value if provider.billing_type else None,
|
||||
is_active=provider.is_active,
|
||||
provider_priority=provider.provider_priority,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"message": "提供商创建成功",
|
||||
}
|
||||
except InvalidRequestException:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
class AdminUpdateProviderAdapter(AdminApiAdapter):
|
||||
def __init__(self, provider_id: str):
|
||||
self.provider_id = provider_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
# 查找 Provider
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("提供商不存在", "provider")
|
||||
|
||||
try:
|
||||
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
|
||||
validated_data = UpdateProviderRequest.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
# 将 Pydantic 验证错误转换为友好的错误信息
|
||||
errors = []
|
||||
for error in exc.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
errors.append(f"{field}: {error['msg']}")
|
||||
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
|
||||
|
||||
try:
|
||||
# 更新字段(只更新非 None 的字段)
|
||||
update_data = validated_data.model_dump(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field == "billing_type" and value is not None:
|
||||
# billing_type 需要转换为枚举
|
||||
setattr(provider, field, ProviderBillingType(value))
|
||||
else:
|
||||
setattr(provider, field, value)
|
||||
|
||||
provider.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="update_provider",
|
||||
provider_id=provider.id,
|
||||
changed_fields=list(update_data.keys()),
|
||||
is_active=provider.is_active,
|
||||
provider_priority=provider.provider_priority,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"is_active": provider.is_active,
|
||||
"message": "提供商更新成功",
|
||||
}
|
||||
except InvalidRequestException:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
class AdminDeleteProviderAdapter(AdminApiAdapter):
|
||||
def __init__(self, provider_id: str):
|
||||
self.provider_id = provider_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("提供商不存在", "provider")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="delete_provider",
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.name,
|
||||
)
|
||||
db.delete(provider)
|
||||
db.commit()
|
||||
return {"message": "提供商已删除"}
|
||||
348
src/api/admin/providers/summary.py
Normal file
348
src/api/admin/providers/summary.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
Provider 摘要与健康监控 API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import case, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import (
|
||||
Model,
|
||||
Provider,
|
||||
ProviderAPIKey,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
)
|
||||
from src.models.endpoint_models import (
|
||||
EndpointHealthEvent,
|
||||
EndpointHealthMonitor,
|
||||
ProviderEndpointHealthMonitorResponse,
|
||||
ProviderUpdateRequest,
|
||||
ProviderWithEndpointsSummary,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Provider Summary"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/summary", response_model=List[ProviderWithEndpointsSummary])
|
||||
async def get_providers_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ProviderWithEndpointsSummary]:
|
||||
"""获取所有 Providers 的摘要信息(包含 Endpoints 和 Keys 统计)"""
|
||||
adapter = AdminProviderSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/summary", response_model=ProviderWithEndpointsSummary)
|
||||
async def get_provider_summary(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderWithEndpointsSummary:
|
||||
"""获取单个 Provider 的摘要信息(包含 Endpoints 和 Keys 统计)"""
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {provider_id} not found")
|
||||
|
||||
return _build_provider_summary(db, provider)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/health-monitor", response_model=ProviderEndpointHealthMonitorResponse)
|
||||
async def get_provider_health_monitor(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
|
||||
per_endpoint_limit: int = Query(48, ge=10, le=200, description="每个端点的事件数量"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointHealthMonitorResponse:
|
||||
"""获取 Provider 下所有端点的健康监控时间线"""
|
||||
|
||||
adapter = AdminProviderHealthMonitorAdapter(
|
||||
provider_id=provider_id,
|
||||
lookback_hours=lookback_hours,
|
||||
per_endpoint_limit=per_endpoint_limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{provider_id}", response_model=ProviderWithEndpointsSummary)
|
||||
async def update_provider_settings(
|
||||
provider_id: str,
|
||||
update_data: ProviderUpdateRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderWithEndpointsSummary:
|
||||
"""更新 Provider 基础配置(display_name, description, priority, weight 等)"""
|
||||
|
||||
adapter = AdminUpdateProviderSettingsAdapter(provider_id=provider_id, update_data=update_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndpointsSummary:
|
||||
endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.provider_id == provider.id).all()
|
||||
|
||||
total_endpoints = len(endpoints)
|
||||
active_endpoints = sum(1 for e in endpoints if e.is_active)
|
||||
endpoint_ids = [e.id for e in endpoints]
|
||||
|
||||
# Key 统计(合并为单个查询)
|
||||
total_keys = 0
|
||||
active_keys = 0
|
||||
if endpoint_ids:
|
||||
key_stats = db.query(
|
||||
func.count(ProviderAPIKey.id).label("total"),
|
||||
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
|
||||
).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).first()
|
||||
total_keys = key_stats.total or 0
|
||||
active_keys = int(key_stats.active or 0)
|
||||
|
||||
# Model 统计(合并为单个查询)
|
||||
model_stats = db.query(
|
||||
func.count(Model.id).label("total"),
|
||||
func.sum(case((Model.is_active == True, 1), else_=0)).label("active"),
|
||||
).filter(Model.provider_id == provider.id).first()
|
||||
total_models = model_stats.total or 0
|
||||
active_models = int(model_stats.active or 0)
|
||||
|
||||
api_formats = [e.api_format for e in endpoints]
|
||||
|
||||
# 优化: 一次性加载所有 endpoint 的 keys,避免 N+1 查询
|
||||
all_keys = []
|
||||
if endpoint_ids:
|
||||
all_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
|
||||
)
|
||||
|
||||
# 按 endpoint_id 分组 keys
|
||||
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {}
|
||||
for key in all_keys:
|
||||
if key.endpoint_id not in keys_by_endpoint:
|
||||
keys_by_endpoint[key.endpoint_id] = []
|
||||
keys_by_endpoint[key.endpoint_id].append(key)
|
||||
|
||||
endpoint_health_map: dict[str, float] = {}
|
||||
for endpoint in endpoints:
|
||||
keys = keys_by_endpoint.get(endpoint.id, [])
|
||||
if keys:
|
||||
health_scores = [k.health_score for k in keys if k.health_score is not None]
|
||||
avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0
|
||||
endpoint_health_map[endpoint.id] = avg_health
|
||||
else:
|
||||
endpoint_health_map[endpoint.id] = 1.0
|
||||
|
||||
all_health_scores = list(endpoint_health_map.values())
|
||||
avg_health_score = sum(all_health_scores) / len(all_health_scores) if all_health_scores else 1.0
|
||||
unhealthy_endpoints = sum(1 for score in all_health_scores if score < 0.5)
|
||||
|
||||
# 计算每个端点的活跃密钥数量
|
||||
active_keys_by_endpoint: dict[str, int] = {}
|
||||
for endpoint_id, keys in keys_by_endpoint.items():
|
||||
active_keys_by_endpoint[endpoint_id] = sum(1 for k in keys if k.is_active)
|
||||
|
||||
endpoint_health_details = [
|
||||
{
|
||||
"api_format": e.api_format,
|
||||
"health_score": endpoint_health_map.get(e.id, 1.0),
|
||||
"is_active": e.is_active,
|
||||
"active_keys": active_keys_by_endpoint.get(e.id, 0),
|
||||
}
|
||||
for e in endpoints
|
||||
]
|
||||
|
||||
return ProviderWithEndpointsSummary(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
display_name=provider.display_name,
|
||||
description=provider.description,
|
||||
website=provider.website,
|
||||
provider_priority=provider.provider_priority,
|
||||
is_active=provider.is_active,
|
||||
billing_type=provider.billing_type.value if provider.billing_type else None,
|
||||
monthly_quota_usd=provider.monthly_quota_usd,
|
||||
monthly_used_usd=provider.monthly_used_usd,
|
||||
quota_reset_day=provider.quota_reset_day,
|
||||
quota_last_reset_at=provider.quota_last_reset_at,
|
||||
quota_expires_at=provider.quota_expires_at,
|
||||
rpm_limit=provider.rpm_limit,
|
||||
rpm_used=provider.rpm_used,
|
||||
rpm_reset_at=provider.rpm_reset_at,
|
||||
total_endpoints=total_endpoints,
|
||||
active_endpoints=active_endpoints,
|
||||
total_keys=total_keys,
|
||||
active_keys=active_keys,
|
||||
total_models=total_models,
|
||||
active_models=active_models,
|
||||
avg_health_score=avg_health_score,
|
||||
unhealthy_endpoints=unhealthy_endpoints,
|
||||
api_formats=api_formats,
|
||||
endpoint_health_details=endpoint_health_details,
|
||||
created_at=provider.created_at,
|
||||
updated_at=provider.updated_at,
|
||||
)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
lookback_hours: int
|
||||
per_endpoint_limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||
|
||||
endpoints = (
|
||||
db.query(ProviderEndpoint)
|
||||
.filter(ProviderEndpoint.provider_id == self.provider_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
since = now - timedelta(hours=self.lookback_hours)
|
||||
|
||||
endpoint_ids = [endpoint.id for endpoint in endpoints]
|
||||
if not endpoint_ids:
|
||||
response = ProviderEndpointHealthMonitorResponse(
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.display_name or provider.name,
|
||||
generated_at=now,
|
||||
endpoints=[],
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="provider_health_monitor",
|
||||
provider_id=self.provider_id,
|
||||
endpoint_count=0,
|
||||
lookback_hours=self.lookback_hours,
|
||||
)
|
||||
return response
|
||||
|
||||
limit_rows = max(200, self.per_endpoint_limit * max(1, len(endpoint_ids)) * 2)
|
||||
attempts_query = (
|
||||
db.query(RequestCandidate)
|
||||
.filter(
|
||||
RequestCandidate.endpoint_id.in_(endpoint_ids),
|
||||
RequestCandidate.created_at >= since,
|
||||
)
|
||||
.order_by(RequestCandidate.created_at.desc())
|
||||
)
|
||||
attempts = attempts_query.limit(limit_rows).all()
|
||||
|
||||
buffered_attempts: Dict[str, List[RequestCandidate]] = {eid: [] for eid in endpoint_ids}
|
||||
counters: Dict[str, int] = {eid: 0 for eid in endpoint_ids}
|
||||
|
||||
for attempt in attempts:
|
||||
if not attempt.endpoint_id or attempt.endpoint_id not in buffered_attempts:
|
||||
continue
|
||||
if counters[attempt.endpoint_id] >= self.per_endpoint_limit:
|
||||
continue
|
||||
buffered_attempts[attempt.endpoint_id].append(attempt)
|
||||
counters[attempt.endpoint_id] += 1
|
||||
|
||||
endpoint_monitors: List[EndpointHealthMonitor] = []
|
||||
for endpoint in endpoints:
|
||||
attempt_list = list(reversed(buffered_attempts.get(endpoint.id, [])))
|
||||
events: List[EndpointHealthEvent] = []
|
||||
for attempt in attempt_list:
|
||||
event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
|
||||
events.append(
|
||||
EndpointHealthEvent(
|
||||
timestamp=event_timestamp,
|
||||
status=attempt.status,
|
||||
status_code=attempt.status_code,
|
||||
latency_ms=attempt.latency_ms,
|
||||
error_type=attempt.error_type,
|
||||
error_message=attempt.error_message,
|
||||
)
|
||||
)
|
||||
|
||||
success_count = sum(1 for event in events if event.status == "success")
|
||||
failed_count = sum(1 for event in events if event.status == "failed")
|
||||
skipped_count = sum(1 for event in events if event.status == "skipped")
|
||||
total_attempts = len(events)
|
||||
success_rate = success_count / total_attempts if total_attempts else 1.0
|
||||
last_event_at = events[-1].timestamp if events else None
|
||||
|
||||
endpoint_monitors.append(
|
||||
EndpointHealthMonitor(
|
||||
endpoint_id=endpoint.id,
|
||||
api_format=endpoint.api_format,
|
||||
is_active=endpoint.is_active,
|
||||
total_attempts=total_attempts,
|
||||
success_count=success_count,
|
||||
failed_count=failed_count,
|
||||
skipped_count=skipped_count,
|
||||
success_rate=success_rate,
|
||||
last_event_at=last_event_at,
|
||||
events=events,
|
||||
)
|
||||
)
|
||||
|
||||
response = ProviderEndpointHealthMonitorResponse(
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.display_name or provider.name,
|
||||
generated_at=now,
|
||||
endpoints=endpoint_monitors,
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="provider_health_monitor",
|
||||
provider_id=self.provider_id,
|
||||
endpoint_count=len(endpoint_monitors),
|
||||
lookback_hours=self.lookback_hours,
|
||||
per_endpoint_limit=self.per_endpoint_limit,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class AdminProviderSummaryAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
providers = (
|
||||
db.query(Provider)
|
||||
.order_by(Provider.provider_priority.asc(), Provider.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
return [_build_provider_summary(db, provider) for provider in providers]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateProviderSettingsAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
update_data: ProviderUpdateRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found", "provider")
|
||||
|
||||
update_dict = self.update_data.model_dump(exclude_unset=True)
|
||||
if "billing_type" in update_dict and update_dict["billing_type"] is not None:
|
||||
update_dict["billing_type"] = ProviderBillingType(update_dict["billing_type"])
|
||||
|
||||
for key, value in update_dict.items():
|
||||
setattr(provider, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
|
||||
admin_name = context.user.username if context.user else "admin"
|
||||
logger.info(f"Provider {provider.name} updated by {admin_name}: {update_dict}")
|
||||
|
||||
return _build_provider_summary(db, provider)
|
||||
14
src/api/admin/security/__init__.py
Normal file
14
src/api/admin/security/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
安全管理 API
|
||||
|
||||
提供 IP 黑白名单管理等安全功能
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .ip_management import router as ip_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(ip_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
202
src/api/admin/security/ip_management.py
Normal file
202
src/api/admin/security/ip_management.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
IP 安全管理接口
|
||||
|
||||
提供 IP 黑白名单管理和速率限制统计
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiMode
|
||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
|
||||
router = APIRouter(prefix="/api/admin/security/ip", tags=["IP Security"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ========== Pydantic 模型 ==========
|
||||
|
||||
|
||||
class AddIPToBlacklistRequest(BaseModel):
|
||||
"""添加 IP 到黑名单请求"""
|
||||
|
||||
ip_address: str = Field(..., description="IP 地址")
|
||||
reason: str = Field(..., min_length=1, max_length=200, description="加入黑名单的原因")
|
||||
ttl: Optional[int] = Field(None, gt=0, description="过期时间(秒),None 表示永久")
|
||||
|
||||
|
||||
class RemoveIPFromBlacklistRequest(BaseModel):
|
||||
"""从黑名单移除 IP 请求"""
|
||||
|
||||
ip_address: str = Field(..., description="IP 地址")
|
||||
|
||||
|
||||
class AddIPToWhitelistRequest(BaseModel):
|
||||
"""添加 IP 到白名单请求"""
|
||||
|
||||
ip_address: str = Field(..., description="IP 地址或 CIDR 格式(如 192.168.1.0/24)")
|
||||
|
||||
|
||||
class RemoveIPFromWhitelistRequest(BaseModel):
|
||||
"""从白名单移除 IP 请求"""
|
||||
|
||||
ip_address: str = Field(..., description="IP 地址")
|
||||
|
||||
|
||||
# ========== API 端点 ==========
|
||||
|
||||
|
||||
@router.post("/blacklist")
|
||||
async def add_to_blacklist(request: Request, db: Session = Depends(get_db)):
|
||||
"""Add IP to blacklist"""
|
||||
adapter = AddToBlacklistAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.delete("/blacklist/{ip_address}")
|
||||
async def remove_from_blacklist(ip_address: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Remove IP from blacklist"""
|
||||
adapter = RemoveFromBlacklistAdapter(ip_address=ip_address)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.get("/blacklist/stats")
|
||||
async def get_blacklist_stats(request: Request, db: Session = Depends(get_db)):
|
||||
"""Get blacklist statistics"""
|
||||
adapter = GetBlacklistStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.post("/whitelist")
|
||||
async def add_to_whitelist(request: Request, db: Session = Depends(get_db)):
|
||||
"""Add IP to whitelist"""
|
||||
adapter = AddToWhitelistAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.delete("/whitelist/{ip_address}")
|
||||
async def remove_from_whitelist(ip_address: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Remove IP from whitelist"""
|
||||
adapter = RemoveFromWhitelistAdapter(ip_address=ip_address)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.get("/whitelist")
|
||||
async def get_whitelist(request: Request, db: Session = Depends(get_db)):
|
||||
"""Get whitelist"""
|
||||
adapter = GetWhitelistAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
# ========== 适配器实现 ==========
|
||||
|
||||
|
||||
class AddToBlacklistAdapter(AuthenticatedApiAdapter):
|
||||
"""添加 IP 到黑名单适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
req = AddIPToBlacklistRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
success = await IPRateLimiter.add_to_blacklist(req.ip_address, req.reason, req.ttl)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="添加 IP 到黑名单失败(Redis 不可用)",
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"IP {req.ip_address} 已加入黑名单",
|
||||
"reason": req.reason,
|
||||
"ttl": req.ttl or "永久",
|
||||
}
|
||||
|
||||
|
||||
class RemoveFromBlacklistAdapter(AuthenticatedApiAdapter):
|
||||
"""从黑名单移除 IP 适配器"""
|
||||
|
||||
def __init__(self, ip_address: str):
|
||||
self.ip_address = ip_address
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
success = await IPRateLimiter.remove_from_blacklist(self.ip_address)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"IP {self.ip_address} 不在黑名单中"
|
||||
)
|
||||
|
||||
return {"success": True, "message": f"IP {self.ip_address} 已从黑名单移除"}
|
||||
|
||||
|
||||
class GetBlacklistStatsAdapter(AuthenticatedApiAdapter):
|
||||
"""获取黑名单统计适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
stats = await IPRateLimiter.get_blacklist_stats()
|
||||
return stats
|
||||
|
||||
|
||||
class AddToWhitelistAdapter(AuthenticatedApiAdapter):
|
||||
"""添加 IP 到白名单适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
req = AddIPToWhitelistRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
success = await IPRateLimiter.add_to_whitelist(req.ip_address)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"添加 IP 到白名单失败(无效的 IP 格式或 Redis 不可用)",
|
||||
)
|
||||
|
||||
return {"success": True, "message": f"IP {req.ip_address} 已加入白名单"}
|
||||
|
||||
|
||||
class RemoveFromWhitelistAdapter(AuthenticatedApiAdapter):
|
||||
"""从白名单移除 IP 适配器"""
|
||||
|
||||
def __init__(self, ip_address: str):
|
||||
self.ip_address = ip_address
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
success = await IPRateLimiter.remove_from_whitelist(self.ip_address)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"IP {self.ip_address} 不在白名单中"
|
||||
)
|
||||
|
||||
return {"success": True, "message": f"IP {self.ip_address} 已从白名单移除"}
|
||||
|
||||
|
||||
class GetWhitelistAdapter(AuthenticatedApiAdapter):
|
||||
"""获取白名单适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
whitelist = await IPRateLimiter.get_whitelist()
|
||||
return {"whitelist": list(whitelist), "total": len(whitelist)}
|
||||
312
src/api/admin/system.py
Normal file
312
src/api/admin/system.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""系统设置API端点。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
|
||||
from src.database import get_db
|
||||
from src.models.api import SystemSettingsRequest, SystemSettingsResponse
|
||||
from src.models.database import ApiKey, Provider, Usage, User
|
||||
from src.services.system.config import SystemConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/settings")
|
||||
async def get_system_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取系统设置(管理员)"""
|
||||
|
||||
adapter = AdminGetSystemSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/settings")
|
||||
async def update_system_settings(http_request: Request, db: Session = Depends(get_db)):
|
||||
"""更新系统设置(管理员)"""
|
||||
|
||||
adapter = AdminUpdateSystemSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/configs")
|
||||
async def get_all_system_configs(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取所有系统配置(管理员)"""
|
||||
|
||||
adapter = AdminGetAllConfigsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/configs/{key}")
|
||||
async def get_system_config(key: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""获取特定系统配置(管理员)"""
|
||||
|
||||
adapter = AdminGetSystemConfigAdapter(key=key)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/configs/{key}")
|
||||
async def set_system_config(
|
||||
key: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""设置系统配置(管理员)"""
|
||||
|
||||
adapter = AdminSetSystemConfigAdapter(key=key)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/configs/{key}")
|
||||
async def delete_system_config(key: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""删除系统配置(管理员)"""
|
||||
|
||||
adapter = AdminDeleteSystemConfigAdapter(key=key)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_system_stats(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminSystemStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/cleanup")
|
||||
async def trigger_cleanup(request: Request, db: Session = Depends(get_db)):
|
||||
"""Manually trigger usage record cleanup task"""
|
||||
adapter = AdminTriggerCleanupAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/api-formats")
|
||||
async def get_api_formats(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取所有可用的API格式列表"""
|
||||
adapter = AdminGetApiFormatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- 系统设置适配器 --------
|
||||
|
||||
|
||||
class AdminGetSystemSettingsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
default_provider = SystemConfigService.get_default_provider(db)
|
||||
default_model = SystemConfigService.get_config(db, "default_model")
|
||||
enable_usage_tracking = (
|
||||
SystemConfigService.get_config(db, "enable_usage_tracking", "true") == "true"
|
||||
)
|
||||
|
||||
return SystemSettingsResponse(
|
||||
default_provider=default_provider,
|
||||
default_model=default_model,
|
||||
enable_usage_tracking=enable_usage_tracking,
|
||||
)
|
||||
|
||||
|
||||
class AdminUpdateSystemSettingsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
settings_request = SystemSettingsRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
if settings_request.default_provider is not None:
|
||||
provider = (
|
||||
db.query(Provider)
|
||||
.filter(
|
||||
Provider.name == settings_request.default_provider,
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not provider and settings_request.default_provider != "":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"提供商 '{settings_request.default_provider}' 不存在或未启用",
|
||||
)
|
||||
|
||||
if settings_request.default_provider:
|
||||
SystemConfigService.set_default_provider(db, settings_request.default_provider)
|
||||
else:
|
||||
SystemConfigService.delete_config(db, "default_provider")
|
||||
|
||||
if settings_request.default_model is not None:
|
||||
if settings_request.default_model:
|
||||
SystemConfigService.set_config(db, "default_model", settings_request.default_model)
|
||||
else:
|
||||
SystemConfigService.delete_config(db, "default_model")
|
||||
|
||||
if settings_request.enable_usage_tracking is not None:
|
||||
SystemConfigService.set_config(
|
||||
db,
|
||||
"enable_usage_tracking",
|
||||
str(settings_request.enable_usage_tracking).lower(),
|
||||
)
|
||||
|
||||
return {"message": "系统设置更新成功"}
|
||||
|
||||
|
||||
class AdminGetAllConfigsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
return SystemConfigService.get_all_configs(context.db)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetSystemConfigAdapter(AdminApiAdapter):
|
||||
key: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
value = SystemConfigService.get_config(context.db, self.key)
|
||||
if value is None:
|
||||
raise NotFoundException(f"配置项 '{self.key}' 不存在")
|
||||
return {"key": self.key, "value": value}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminSetSystemConfigAdapter(AdminApiAdapter):
|
||||
key: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
payload = context.ensure_json_body()
|
||||
config = SystemConfigService.set_config(
|
||||
context.db,
|
||||
self.key,
|
||||
payload.get("value"),
|
||||
payload.get("description"),
|
||||
)
|
||||
|
||||
return {
|
||||
"key": config.key,
|
||||
"value": config.value,
|
||||
"description": config.description,
|
||||
"updated_at": config.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteSystemConfigAdapter(AdminApiAdapter):
|
||||
key: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
deleted = SystemConfigService.delete_config(context.db, self.key)
|
||||
if not deleted:
|
||||
raise NotFoundException(f"配置项 '{self.key}' 不存在")
|
||||
return {"message": f"配置项 '{self.key}' 已删除"}
|
||||
|
||||
|
||||
class AdminSystemStatsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
total_users = db.query(User).count()
|
||||
active_users = db.query(User).filter(User.is_active.is_(True)).count()
|
||||
total_providers = db.query(Provider).count()
|
||||
active_providers = db.query(Provider).filter(Provider.is_active.is_(True)).count()
|
||||
total_api_keys = db.query(ApiKey).count()
|
||||
total_requests = db.query(Usage).count()
|
||||
|
||||
return {
|
||||
"users": {"total": total_users, "active": active_users},
|
||||
"providers": {"total": total_providers, "active": active_providers},
|
||||
"api_keys": total_api_keys,
|
||||
"requests": total_requests,
|
||||
}
|
||||
|
||||
|
||||
class AdminTriggerCleanupAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""手动触发清理任务"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
from src.services.system.cleanup_scheduler import get_cleanup_scheduler
|
||||
|
||||
db = context.db
|
||||
|
||||
# 获取清理前的统计信息
|
||||
total_before = db.query(Usage).count()
|
||||
with_body_before = (
|
||||
db.query(Usage)
|
||||
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
|
||||
.count()
|
||||
)
|
||||
with_headers_before = (
|
||||
db.query(Usage)
|
||||
.filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
|
||||
.count()
|
||||
)
|
||||
|
||||
# 触发清理
|
||||
cleanup_scheduler = get_cleanup_scheduler()
|
||||
await cleanup_scheduler._perform_cleanup()
|
||||
|
||||
# 获取清理后的统计信息
|
||||
total_after = db.query(Usage).count()
|
||||
with_body_after = (
|
||||
db.query(Usage)
|
||||
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
|
||||
.count()
|
||||
)
|
||||
with_headers_after = (
|
||||
db.query(Usage)
|
||||
.filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "清理任务执行完成",
|
||||
"stats": {
|
||||
"total_records": {
|
||||
"before": total_before,
|
||||
"after": total_after,
|
||||
"deleted": total_before - total_after,
|
||||
},
|
||||
"body_fields": {
|
||||
"before": with_body_before,
|
||||
"after": with_body_after,
|
||||
"cleaned": with_body_before - with_body_after,
|
||||
},
|
||||
"header_fields": {
|
||||
"before": with_headers_before,
|
||||
"after": with_headers_after,
|
||||
"cleaned": with_headers_before - with_headers_after,
|
||||
},
|
||||
},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class AdminGetApiFormatsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""获取所有可用的API格式"""
|
||||
from src.core.api_format_metadata import API_FORMAT_DEFINITIONS
|
||||
from src.core.enums import APIFormat
|
||||
|
||||
_ = context # 参数保留以符合接口规范
|
||||
|
||||
formats = []
|
||||
for api_format in APIFormat:
|
||||
definition = API_FORMAT_DEFINITIONS.get(api_format)
|
||||
formats.append(
|
||||
{
|
||||
"value": api_format.value,
|
||||
"label": api_format.value,
|
||||
"default_path": definition.default_path if definition else "/",
|
||||
"aliases": list(definition.aliases) if definition else [],
|
||||
}
|
||||
)
|
||||
|
||||
return {"formats": formats}
|
||||
5
src/api/admin/usage/__init__.py
Normal file
5
src/api/admin/usage/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Usage admin routes export."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
818
src/api/admin/usage/routes.py
Normal file
818
src/api/admin/usage/routes.py
Normal file
@@ -0,0 +1,818 @@
|
||||
"""管理员使用情况统计路由。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.database import get_db
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
Provider,
|
||||
ProviderAPIKey,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
Usage,
|
||||
User,
|
||||
)
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/usage", tags=["Admin - Usage"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ==================== RESTful Routes ====================
|
||||
|
||||
|
||||
@router.get("/aggregation/stats")
|
||||
async def get_usage_aggregation(
|
||||
request: Request,
|
||||
group_by: str = Query(..., description="Aggregation dimension: model, user, provider, or api_format"),
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get usage aggregation by specified dimension.
|
||||
|
||||
- group_by=model: Aggregate by model
|
||||
- group_by=user: Aggregate by user
|
||||
- group_by=provider: Aggregate by provider
|
||||
- group_by=api_format: Aggregate by API format
|
||||
"""
|
||||
if group_by == "model":
|
||||
adapter = AdminUsageByModelAdapter(start_date=start_date, end_date=end_date, limit=limit)
|
||||
elif group_by == "user":
|
||||
adapter = AdminUsageByUserAdapter(start_date=start_date, end_date=end_date, limit=limit)
|
||||
elif group_by == "provider":
|
||||
adapter = AdminUsageByProviderAdapter(start_date=start_date, end_date=end_date, limit=limit)
|
||||
elif group_by == "api_format":
|
||||
adapter = AdminUsageByApiFormatAdapter(start_date=start_date, end_date=end_date, limit=limit)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid group_by value: {group_by}. Must be one of: model, user, provider, api_format"
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_usage_stats(
|
||||
request: Request,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminUsageStatsAdapter(start_date=start_date, end_date=end_date)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/records")
|
||||
async def get_usage_records(
|
||||
request: Request,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
status: Optional[str] = None, # stream, standard, error
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminUsageRecordsAdapter(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
model=model,
|
||||
provider=provider,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/active")
|
||||
async def get_active_requests(
|
||||
request: Request,
|
||||
ids: Optional[str] = Query(None, description="逗号分隔的请求 ID 列表,用于查询特定请求的状态"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取活跃请求的状态(轻量级接口,用于前端轮询)
|
||||
|
||||
- 如果提供 ids 参数,只返回这些 ID 对应请求的最新状态
|
||||
- 如果不提供 ids,返回所有 pending/streaming 状态的请求
|
||||
"""
|
||||
adapter = AdminActiveRequestsAdapter(ids=ids)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# NOTE: This route must be defined AFTER all other routes to avoid matching
|
||||
# routes like /stats, /records, /active, etc.
|
||||
@router.get("/{usage_id}")
|
||||
async def get_usage_detail(
|
||||
usage_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get detailed information of a specific usage record.
|
||||
|
||||
Includes request/response headers and body.
|
||||
"""
|
||||
adapter = AdminUsageDetailAdapter(usage_id=usage_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AdminUsageStatsAdapter(AdminApiAdapter):
|
||||
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime]):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
query = db.query(Usage)
|
||||
if self.start_date:
|
||||
query = query.filter(Usage.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
query = query.filter(Usage.created_at <= self.end_date)
|
||||
|
||||
total_stats = query.with_entities(
|
||||
func.count(Usage.id).label("total_requests"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost"),
|
||||
func.sum(Usage.actual_total_cost_usd).label("total_actual_cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
|
||||
).first()
|
||||
|
||||
# 缓存统计
|
||||
cache_stats = query.with_entities(
|
||||
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
|
||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
|
||||
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
|
||||
).first()
|
||||
|
||||
# 错误统计
|
||||
error_count = query.filter(
|
||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||
).count()
|
||||
|
||||
activity_heatmap = UsageService.get_daily_activity(
|
||||
db=db,
|
||||
window_days=365,
|
||||
include_actual_cost=True,
|
||||
)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="usage_stats",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||
)
|
||||
|
||||
total_requests = total_stats.total_requests if total_stats else 0
|
||||
avg_response_time_ms = float(total_stats.avg_response_time_ms or 0) if total_stats else 0
|
||||
avg_response_time = avg_response_time_ms / 1000.0
|
||||
|
||||
return {
|
||||
"total_requests": total_requests,
|
||||
"total_tokens": int(total_stats.total_tokens or 0),
|
||||
"total_cost": float(total_stats.total_cost or 0),
|
||||
"total_actual_cost": float(total_stats.total_actual_cost or 0),
|
||||
"avg_response_time": round(avg_response_time, 2),
|
||||
"error_count": error_count,
|
||||
"error_rate": (
|
||||
round((error_count / total_requests) * 100, 2) if total_requests > 0 else 0
|
||||
),
|
||||
"cache_stats": {
|
||||
"cache_creation_tokens": (
|
||||
int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
|
||||
),
|
||||
"cache_read_tokens": int(cache_stats.cache_read_tokens or 0) if cache_stats else 0,
|
||||
"cache_creation_cost": (
|
||||
float(cache_stats.cache_creation_cost or 0) if cache_stats else 0
|
||||
),
|
||||
"cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
|
||||
},
|
||||
"activity_heatmap": activity_heatmap,
|
||||
}
|
||||
|
||||
|
||||
class AdminUsageByModelAdapter(AdminApiAdapter):
|
||||
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.limit = limit
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
query = db.query(
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("request_count"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost"),
|
||||
func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
|
||||
)
|
||||
# 过滤掉 pending/streaming 状态的请求(尚未完成的请求不应计入统计)
|
||||
query = query.filter(Usage.status.notin_(["pending", "streaming"]))
|
||||
# 过滤掉 unknown/pending provider(请求未到达任何提供商)
|
||||
query = query.filter(Usage.provider.notin_(["unknown", "pending"]))
|
||||
|
||||
if self.start_date:
|
||||
query = query.filter(Usage.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
query = query.filter(Usage.created_at <= self.end_date)
|
||||
|
||||
query = query.group_by(Usage.model).order_by(func.count(Usage.id).desc()).limit(self.limit)
|
||||
stats = query.all()
|
||||
context.add_audit_metadata(
|
||||
action="usage_by_model",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||
limit=self.limit,
|
||||
result_count=len(stats),
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"model": model,
|
||||
"request_count": count,
|
||||
"total_tokens": int(tokens or 0),
|
||||
"total_cost": float(cost or 0),
|
||||
"actual_cost": float(actual_cost or 0),
|
||||
}
|
||||
for model, count, tokens, cost, actual_cost in stats
|
||||
]
|
||||
|
||||
|
||||
class AdminUsageByUserAdapter(AdminApiAdapter):
|
||||
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.limit = limit
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
query = (
|
||||
db.query(
|
||||
User.id,
|
||||
User.email,
|
||||
User.username,
|
||||
func.count(Usage.id).label("request_count"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost"),
|
||||
)
|
||||
.join(Usage, Usage.user_id == User.id)
|
||||
.group_by(User.id, User.email, User.username)
|
||||
)
|
||||
|
||||
if self.start_date:
|
||||
query = query.filter(Usage.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
query = query.filter(Usage.created_at <= self.end_date)
|
||||
|
||||
query = query.order_by(func.count(Usage.id).desc()).limit(self.limit)
|
||||
stats = query.all()
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="usage_by_user",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||
limit=self.limit,
|
||||
result_count=len(stats),
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"user_id": user_id,
|
||||
"email": email,
|
||||
"username": username,
|
||||
"request_count": count,
|
||||
"total_tokens": int(tokens or 0),
|
||||
"total_cost": float(cost or 0),
|
||||
}
|
||||
for user_id, email, username, count, tokens, cost in stats
|
||||
]
|
||||
|
||||
|
||||
class AdminUsageByProviderAdapter(AdminApiAdapter):
|
||||
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.limit = limit
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 从 request_candidates 表统计每个 Provider 的尝试次数和成功率
|
||||
# 这样可以正确统计 Fallback 场景(一个请求可能尝试多个 Provider)
|
||||
from sqlalchemy import case, Integer
|
||||
|
||||
attempt_query = db.query(
|
||||
RequestCandidate.provider_id,
|
||||
func.count(RequestCandidate.id).label("attempt_count"),
|
||||
func.sum(
|
||||
case((RequestCandidate.status == "success", 1), else_=0)
|
||||
).label("success_count"),
|
||||
func.sum(
|
||||
case((RequestCandidate.status == "failed", 1), else_=0)
|
||||
).label("failed_count"),
|
||||
func.avg(RequestCandidate.latency_ms).label("avg_latency_ms"),
|
||||
).filter(
|
||||
RequestCandidate.provider_id.isnot(None),
|
||||
# 只统计实际执行的尝试(排除 available/skipped 状态)
|
||||
RequestCandidate.status.in_(["success", "failed"]),
|
||||
)
|
||||
|
||||
if self.start_date:
|
||||
attempt_query = attempt_query.filter(RequestCandidate.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
attempt_query = attempt_query.filter(RequestCandidate.created_at <= self.end_date)
|
||||
|
||||
attempt_stats = (
|
||||
attempt_query.group_by(RequestCandidate.provider_id)
|
||||
.order_by(func.count(RequestCandidate.id).desc())
|
||||
.limit(self.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 从 Usage 表获取 token 和费用统计(基于成功的请求)
|
||||
usage_query = db.query(
|
||||
Usage.provider_id,
|
||||
func.count(Usage.id).label("request_count"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost"),
|
||||
func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
|
||||
).filter(
|
||||
Usage.provider_id.isnot(None),
|
||||
# 过滤掉 pending/streaming 状态的请求
|
||||
Usage.status.notin_(["pending", "streaming"]),
|
||||
)
|
||||
|
||||
if self.start_date:
|
||||
usage_query = usage_query.filter(Usage.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
usage_query = usage_query.filter(Usage.created_at <= self.end_date)
|
||||
|
||||
usage_stats = usage_query.group_by(Usage.provider_id).all()
|
||||
usage_map = {str(u.provider_id): u for u in usage_stats}
|
||||
|
||||
# 获取所有相关的 Provider ID
|
||||
provider_ids = set()
|
||||
for stat in attempt_stats:
|
||||
if stat.provider_id:
|
||||
provider_ids.add(stat.provider_id)
|
||||
for stat in usage_stats:
|
||||
if stat.provider_id:
|
||||
provider_ids.add(stat.provider_id)
|
||||
|
||||
# 获取 Provider 名称映射
|
||||
provider_map = {}
|
||||
if provider_ids:
|
||||
providers_data = (
|
||||
db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
|
||||
)
|
||||
provider_map = {str(p.id): p.name for p in providers_data}
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="usage_by_provider",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||
limit=self.limit,
|
||||
result_count=len(attempt_stats),
|
||||
)
|
||||
|
||||
result = []
|
||||
for stat in attempt_stats:
|
||||
provider_id_str = str(stat.provider_id) if stat.provider_id else None
|
||||
attempt_count = stat.attempt_count or 0
|
||||
success_count = int(stat.success_count or 0)
|
||||
failed_count = int(stat.failed_count or 0)
|
||||
success_rate = (success_count / attempt_count * 100) if attempt_count > 0 else 0
|
||||
|
||||
# 从 usage_map 获取 token 和费用信息
|
||||
usage_stat = usage_map.get(provider_id_str)
|
||||
|
||||
result.append({
|
||||
"provider_id": provider_id_str,
|
||||
"provider": provider_map.get(provider_id_str, "Unknown"),
|
||||
"request_count": attempt_count, # 尝试次数
|
||||
"total_tokens": int(usage_stat.total_tokens or 0) if usage_stat else 0,
|
||||
"total_cost": float(usage_stat.total_cost or 0) if usage_stat else 0,
|
||||
"actual_cost": float(usage_stat.actual_cost or 0) if usage_stat else 0,
|
||||
"avg_response_time_ms": float(stat.avg_latency_ms or 0),
|
||||
"success_rate": round(success_rate, 2),
|
||||
"error_count": failed_count,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class AdminUsageByApiFormatAdapter(AdminApiAdapter):
|
||||
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.limit = limit
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
query = db.query(
|
||||
Usage.api_format,
|
||||
func.count(Usage.id).label("request_count"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost"),
|
||||
func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
|
||||
)
|
||||
# 过滤掉 pending/streaming 状态的请求
|
||||
query = query.filter(Usage.status.notin_(["pending", "streaming"]))
|
||||
# 过滤掉 unknown/pending provider
|
||||
query = query.filter(Usage.provider.notin_(["unknown", "pending"]))
|
||||
# 只统计有 api_format 的记录
|
||||
query = query.filter(Usage.api_format.isnot(None))
|
||||
|
||||
if self.start_date:
|
||||
query = query.filter(Usage.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
query = query.filter(Usage.created_at <= self.end_date)
|
||||
|
||||
query = (
|
||||
query.group_by(Usage.api_format)
|
||||
.order_by(func.count(Usage.id).desc())
|
||||
.limit(self.limit)
|
||||
)
|
||||
stats = query.all()
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="usage_by_api_format",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||
limit=self.limit,
|
||||
result_count=len(stats),
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"api_format": api_format or "unknown",
|
||||
"request_count": count,
|
||||
"total_tokens": int(tokens or 0),
|
||||
"total_cost": float(cost or 0),
|
||||
"actual_cost": float(actual_cost or 0),
|
||||
"avg_response_time_ms": float(avg_response_time or 0),
|
||||
}
|
||||
for api_format, count, tokens, cost, actual_cost, avg_response_time in stats
|
||||
]
|
||||
|
||||
|
||||
class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
start_date: Optional[datetime],
|
||||
end_date: Optional[datetime],
|
||||
user_id: Optional[str],
|
||||
username: Optional[str],
|
||||
model: Optional[str],
|
||||
provider: Optional[str],
|
||||
status: Optional[str],
|
||||
limit: int,
|
||||
offset: int,
|
||||
):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.user_id = user_id
|
||||
self.username = username
|
||||
self.model = model
|
||||
self.provider = provider
|
||||
self.status = status
|
||||
self.limit = limit
|
||||
self.offset = offset
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
query = (
|
||||
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
|
||||
.outerjoin(User, Usage.user_id == User.id)
|
||||
.outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
|
||||
.outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
|
||||
)
|
||||
if self.user_id:
|
||||
query = query.filter(Usage.user_id == self.user_id)
|
||||
if self.username:
|
||||
# 支持用户名模糊搜索
|
||||
query = query.filter(User.username.ilike(f"%{self.username}%"))
|
||||
if self.model:
|
||||
# 支持模型名模糊搜索
|
||||
query = query.filter(Usage.model.ilike(f"%{self.model}%"))
|
||||
if self.provider:
|
||||
# 支持提供商名称搜索(通过 Provider 表)
|
||||
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
|
||||
query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
|
||||
if self.status:
|
||||
# 状态筛选
|
||||
# 旧的筛选值(基于 is_stream 和 status_code):stream, standard, error
|
||||
# 新的筛选值(基于 status 字段):pending, streaming, completed, failed, active
|
||||
if self.status == "stream":
|
||||
query = query.filter(Usage.is_stream == True) # noqa: E712
|
||||
elif self.status == "standard":
|
||||
query = query.filter(Usage.is_stream == False) # noqa: E712
|
||||
elif self.status == "error":
|
||||
query = query.filter(
|
||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||
)
|
||||
elif self.status in ("pending", "streaming", "completed", "failed"):
|
||||
# 新的状态筛选:直接按 status 字段过滤
|
||||
query = query.filter(Usage.status == self.status)
|
||||
elif self.status == "active":
|
||||
# 活跃请求:pending 或 streaming 状态
|
||||
query = query.filter(Usage.status.in_(["pending", "streaming"]))
|
||||
if self.start_date:
|
||||
query = query.filter(Usage.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
query = query.filter(Usage.created_at <= self.end_date)
|
||||
|
||||
total = query.count()
|
||||
records = (
|
||||
query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
||||
)
|
||||
|
||||
request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
|
||||
fallback_map = {}
|
||||
if request_ids:
|
||||
# 只统计实际执行的候选(success 或 failed),不包括 skipped/pending/available
|
||||
executed_counts = (
|
||||
db.query(RequestCandidate.request_id, func.count(RequestCandidate.id))
|
||||
.filter(
|
||||
RequestCandidate.request_id.in_(request_ids),
|
||||
RequestCandidate.status.in_(["success", "failed"]),
|
||||
)
|
||||
.group_by(RequestCandidate.request_id)
|
||||
.all()
|
||||
)
|
||||
# 如果实际执行的候选数 > 1,说明发生了 Provider 切换
|
||||
fallback_map = {req_id: count > 1 for req_id, count in executed_counts}
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="usage_records",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
model=self.model,
|
||||
provider=self.provider,
|
||||
status=self.status,
|
||||
limit=self.limit,
|
||||
offset=self.offset,
|
||||
total=total,
|
||||
)
|
||||
|
||||
# 构建 provider_id -> Provider 名称的映射,避免 N+1 查询
|
||||
provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
|
||||
provider_map = {}
|
||||
if provider_ids:
|
||||
providers_data = (
|
||||
db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
|
||||
)
|
||||
provider_map = {str(p.id): p.name for p in providers_data}
|
||||
|
||||
data = []
|
||||
for usage, user, endpoint, api_key in records:
|
||||
actual_cost = (
|
||||
float(usage.actual_total_cost_usd)
|
||||
if usage.actual_total_cost_usd is not None
|
||||
else 0.0
|
||||
)
|
||||
rate_multiplier = (
|
||||
float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0
|
||||
)
|
||||
|
||||
# 提供商名称优先级:关联的 Provider 表 > usage.provider 字段
|
||||
provider_name = usage.provider
|
||||
if usage.provider_id and str(usage.provider_id) in provider_map:
|
||||
provider_name = provider_map[str(usage.provider_id)]
|
||||
|
||||
data.append(
|
||||
{
|
||||
"id": usage.id,
|
||||
"user_id": user.id if user else None,
|
||||
"user_email": user.email if user else "已删除用户",
|
||||
"username": user.username if user else "已删除用户",
|
||||
"provider": provider_name,
|
||||
"model": usage.model,
|
||||
"target_model": usage.target_model, # 映射后的目标模型名
|
||||
"input_tokens": usage.input_tokens,
|
||||
"output_tokens": usage.output_tokens,
|
||||
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
"cache_read_input_tokens": usage.cache_read_input_tokens,
|
||||
"total_tokens": usage.total_tokens,
|
||||
"cost": float(usage.total_cost_usd),
|
||||
"actual_cost": actual_cost,
|
||||
"rate_multiplier": rate_multiplier,
|
||||
"response_time_ms": usage.response_time_ms,
|
||||
"created_at": usage.created_at.isoformat(),
|
||||
"is_stream": usage.is_stream,
|
||||
"input_price_per_1m": usage.input_price_per_1m,
|
||||
"output_price_per_1m": usage.output_price_per_1m,
|
||||
"cache_creation_price_per_1m": usage.cache_creation_price_per_1m,
|
||||
"cache_read_price_per_1m": usage.cache_read_price_per_1m,
|
||||
"status_code": usage.status_code,
|
||||
"error_message": usage.error_message,
|
||||
"status": usage.status, # 请求状态: pending, streaming, completed, failed
|
||||
"has_fallback": fallback_map.get(usage.request_id, False),
|
||||
"api_format": usage.api_format
|
||||
or (endpoint.api_format if endpoint and endpoint.api_format else None),
|
||||
"api_key_name": api_key.name if api_key else None,
|
||||
"request_metadata": usage.request_metadata, # Provider 响应元数据
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"records": data,
|
||||
"total": total,
|
||||
"limit": self.limit,
|
||||
"offset": self.offset,
|
||||
}
|
||||
|
||||
|
||||
class AdminActiveRequestsAdapter(AdminApiAdapter):
|
||||
"""轻量级活跃请求状态查询适配器"""
|
||||
|
||||
def __init__(self, ids: Optional[str]):
|
||||
self.ids = ids
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
if self.ids:
|
||||
# 查询指定 ID 的请求状态
|
||||
id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
|
||||
if not id_list:
|
||||
return {"requests": []}
|
||||
|
||||
records = (
|
||||
db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
|
||||
.filter(Usage.id.in_(id_list))
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
# 查询所有活跃请求(pending 或 streaming)
|
||||
records = (
|
||||
db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
|
||||
.filter(Usage.status.in_(["pending", "streaming"]))
|
||||
.order_by(Usage.created_at.desc())
|
||||
.limit(50)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"requests": [
|
||||
{
|
||||
"id": r.id,
|
||||
"status": r.status,
|
||||
"input_tokens": r.input_tokens,
|
||||
"output_tokens": r.output_tokens,
|
||||
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
|
||||
"response_time_ms": r.response_time_ms,
|
||||
}
|
||||
for r in records
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUsageDetailAdapter(AdminApiAdapter):
|
||||
"""Get detailed usage record with request/response body"""
|
||||
|
||||
usage_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
usage_record = db.query(Usage).filter(Usage.id == self.usage_id).first()
|
||||
if not usage_record:
|
||||
raise HTTPException(status_code=404, detail="Usage record not found")
|
||||
|
||||
user = db.query(User).filter(User.id == usage_record.user_id).first()
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == usage_record.api_key_id).first()
|
||||
|
||||
# 获取阶梯计费信息
|
||||
tiered_pricing_info = await self._get_tiered_pricing_info(db, usage_record)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="usage_detail",
|
||||
usage_id=self.usage_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": usage_record.id,
|
||||
"request_id": usage_record.request_id,
|
||||
"user": {
|
||||
"id": user.id if user else None,
|
||||
"username": user.username if user else "Unknown",
|
||||
"email": user.email if user else None,
|
||||
},
|
||||
"api_key": {
|
||||
"id": api_key.id if api_key else None,
|
||||
"name": api_key.name if api_key else None,
|
||||
"display": api_key.get_display_key() if api_key else None,
|
||||
},
|
||||
"provider": usage_record.provider,
|
||||
"api_format": usage_record.api_format,
|
||||
"model": usage_record.model,
|
||||
"target_model": usage_record.target_model,
|
||||
"tokens": {
|
||||
"input": usage_record.input_tokens,
|
||||
"output": usage_record.output_tokens,
|
||||
"total": usage_record.total_tokens,
|
||||
},
|
||||
"cost": {
|
||||
"input": usage_record.input_cost_usd,
|
||||
"output": usage_record.output_cost_usd,
|
||||
"total": usage_record.total_cost_usd,
|
||||
},
|
||||
"cache_creation_input_tokens": usage_record.cache_creation_input_tokens,
|
||||
"cache_read_input_tokens": usage_record.cache_read_input_tokens,
|
||||
"cache_creation_cost": getattr(usage_record, "cache_creation_cost_usd", 0.0),
|
||||
"cache_read_cost": getattr(usage_record, "cache_read_cost_usd", 0.0),
|
||||
"request_cost": getattr(usage_record, "request_cost_usd", 0.0),
|
||||
"input_price_per_1m": usage_record.input_price_per_1m,
|
||||
"output_price_per_1m": usage_record.output_price_per_1m,
|
||||
"cache_creation_price_per_1m": usage_record.cache_creation_price_per_1m,
|
||||
"cache_read_price_per_1m": usage_record.cache_read_price_per_1m,
|
||||
"price_per_request": usage_record.price_per_request,
|
||||
"request_type": usage_record.request_type,
|
||||
"is_stream": usage_record.is_stream,
|
||||
"status_code": usage_record.status_code,
|
||||
"error_message": usage_record.error_message,
|
||||
"response_time_ms": usage_record.response_time_ms,
|
||||
"created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
|
||||
"request_headers": usage_record.request_headers,
|
||||
"request_body": usage_record.get_request_body(),
|
||||
"provider_request_headers": usage_record.provider_request_headers,
|
||||
"response_headers": usage_record.response_headers,
|
||||
"response_body": usage_record.get_response_body(),
|
||||
"metadata": usage_record.request_metadata,
|
||||
"tiered_pricing": tiered_pricing_info,
|
||||
}
|
||||
|
||||
async def _get_tiered_pricing_info(self, db, usage_record) -> dict | None:
|
||||
"""获取阶梯计费信息"""
|
||||
from src.services.model.cost import ModelCostService
|
||||
|
||||
# 计算总输入上下文(用于阶梯判定):输入 + 缓存创建 + 缓存读取
|
||||
input_tokens = usage_record.input_tokens or 0
|
||||
cache_creation_tokens = usage_record.cache_creation_input_tokens or 0
|
||||
cache_read_tokens = usage_record.cache_read_input_tokens or 0
|
||||
total_input_context = input_tokens + cache_creation_tokens + cache_read_tokens
|
||||
|
||||
# 尝试获取模型的阶梯配置(带来源信息)
|
||||
cost_service = ModelCostService(db)
|
||||
pricing_result = await cost_service.get_tiered_pricing_with_source_async(
|
||||
usage_record.provider, usage_record.model
|
||||
)
|
||||
|
||||
if not pricing_result:
|
||||
return None
|
||||
|
||||
tiered_pricing = pricing_result.get("pricing")
|
||||
pricing_source = pricing_result.get("source") # 'provider' 或 'global'
|
||||
|
||||
if not tiered_pricing or not tiered_pricing.get("tiers"):
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
# 找到命中的阶梯
|
||||
tier_index = None
|
||||
matched_tier = None
|
||||
for i, tier in enumerate(tiers):
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_context <= up_to:
|
||||
tier_index = i
|
||||
matched_tier = tier
|
||||
break
|
||||
|
||||
# 如果都没匹配,使用最后一个阶梯
|
||||
if tier_index is None and tiers:
|
||||
tier_index = len(tiers) - 1
|
||||
matched_tier = tiers[-1]
|
||||
|
||||
return {
|
||||
"total_input_context": total_input_context,
|
||||
"tier_index": tier_index,
|
||||
"tier_count": len(tiers),
|
||||
"current_tier": matched_tier,
|
||||
"tiers": tiers,
|
||||
"source": pricing_source, # 定价来源: 'provider' 或 'global'
|
||||
}
|
||||
5
src/api/admin/users/__init__.py
Normal file
5
src/api/admin/users/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""User admin routes export."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
488
src/api/admin/users/routes.py
Normal file
488
src/api/admin/users/routes.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""用户管理 API 端点。"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.admin_requests import UpdateUserRequest
|
||||
from src.models.api import CreateApiKeyRequest, CreateUserRequest
|
||||
from src.models.database import ApiKey, User, UserRole
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
from src.services.user.service import UserService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/admin/users", tags=["Admin - Users"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# 管理员端点
|
||||
@router.post("")
|
||||
async def create_user_endpoint(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminCreateUserAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
role: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminListUsersAdapter(skip=skip, limit=limit, role=role, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{user_id}")
|
||||
async def get_user(user_id: str, request: Request, db: Session = Depends(get_db)): # UUID
|
||||
adapter = AdminGetUserAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{user_id}")
|
||||
async def update_user(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminUpdateUserAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
async def delete_user(user_id: str, request: Request, db: Session = Depends(get_db)): # UUID
|
||||
adapter = AdminDeleteUserAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{user_id}/quota")
|
||||
async def reset_user_quota(user_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Reset user quota (set used_usd to 0)"""
|
||||
adapter = AdminResetUserQuotaAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{user_id}/api-keys")
|
||||
async def get_user_api_keys(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
is_active: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取用户的所有API Keys(不包括独立Keys)"""
|
||||
adapter = AdminGetUserKeysAdapter(user_id=user_id, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{user_id}/api-keys")
|
||||
async def create_user_api_key(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""为用户创建API Key"""
|
||||
adapter = AdminCreateUserKeyAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{user_id}/api-keys/{key_id}")
|
||||
async def delete_user_api_key(
|
||||
user_id: str,
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除用户的API Key"""
|
||||
adapter = AdminDeleteUserKeyAdapter(user_id=user_id, key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ============== 管理员适配器实现 ==============
|
||||
|
||||
|
||||
class AdminCreateUserAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
request = CreateUserRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
try:
|
||||
role = (
|
||||
request.role if hasattr(request.role, "value") else UserRole[request.role.upper()]
|
||||
)
|
||||
except (KeyError, AttributeError):
|
||||
raise InvalidRequestException("角色参数不合法")
|
||||
|
||||
try:
|
||||
user = UserService.create_user(
|
||||
db=db,
|
||||
email=request.email,
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
role=role,
|
||||
quota_usd=request.quota_usd,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise InvalidRequestException(str(exc))
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="create_user",
|
||||
target_user_id=user.id,
|
||||
target_email=user.email,
|
||||
target_username=user.username,
|
||||
target_role=user.role.value,
|
||||
quota_usd=user.quota_usd,
|
||||
is_active=user.is_active,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_models": user.allowed_models,
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
"total_usd": getattr(user, "total_usd", 0),
|
||||
"is_active": user.is_active,
|
||||
"created_at": user.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class AdminListUsersAdapter(AdminApiAdapter):
|
||||
def __init__(self, skip: int, limit: int, role: Optional[str], is_active: Optional[bool]):
|
||||
self.skip = skip
|
||||
self.limit = limit
|
||||
self.role = role
|
||||
self.is_active = is_active
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
role_enum = UserRole[self.role.upper()] if self.role else None
|
||||
users = UserService.list_users(db, self.skip, self.limit, role_enum, self.is_active)
|
||||
return [
|
||||
{
|
||||
"id": u.id,
|
||||
"email": u.email,
|
||||
"username": u.username,
|
||||
"role": u.role.value,
|
||||
"quota_usd": u.quota_usd,
|
||||
"used_usd": u.used_usd,
|
||||
"total_usd": getattr(u, "total_usd", 0),
|
||||
"is_active": u.is_active,
|
||||
"created_at": u.created_at.isoformat(),
|
||||
}
|
||||
for u in users
|
||||
]
|
||||
|
||||
|
||||
class AdminGetUserAdapter(AdminApiAdapter):
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = UserService.get_user(db, self.user_id)
|
||||
if not user:
|
||||
raise NotFoundException("用户不存在", "user")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="get_user_detail",
|
||||
target_user_id=user.id,
|
||||
target_role=user.role.value,
|
||||
include_history=bool(user.last_login_at),
|
||||
)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_models": user.allowed_models,
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
"total_usd": getattr(user, "total_usd", 0),
|
||||
"is_active": user.is_active,
|
||||
"created_at": user.created_at.isoformat(),
|
||||
"updated_at": user.updated_at.isoformat() if user.updated_at else None,
|
||||
"last_login_at": user.last_login_at.isoformat() if user.last_login_at else None,
|
||||
}
|
||||
|
||||
|
||||
class AdminUpdateUserAdapter(AdminApiAdapter):
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
existing_user = UserService.get_user(db, self.user_id)
|
||||
if not existing_user:
|
||||
raise NotFoundException("用户不存在", "user")
|
||||
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
request = UpdateUserRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if "role" in update_data and update_data["role"]:
|
||||
if hasattr(update_data["role"], "value"):
|
||||
update_data["role"] = update_data["role"]
|
||||
else:
|
||||
update_data["role"] = UserRole[update_data["role"].upper()]
|
||||
|
||||
user = UserService.update_user(db, self.user_id, **update_data)
|
||||
if not user:
|
||||
raise NotFoundException("用户不存在", "user")
|
||||
|
||||
changed_fields = list(update_data.keys())
|
||||
context.add_audit_metadata(
|
||||
action="update_user",
|
||||
target_user_id=user.id,
|
||||
updated_fields=changed_fields,
|
||||
role_before=existing_user.role.value if existing_user.role else None,
|
||||
role_after=user.role.value,
|
||||
quota_usd=user.quota_usd,
|
||||
is_active=user.is_active,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_models": user.allowed_models,
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
"total_usd": getattr(user, "total_usd", 0),
|
||||
"is_active": user.is_active,
|
||||
"created_at": user.created_at.isoformat(),
|
||||
"updated_at": user.updated_at.isoformat() if user.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class AdminDeleteUserAdapter(AdminApiAdapter):
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = UserService.get_user(db, self.user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||
|
||||
if user.role == UserRole.ADMIN:
|
||||
admin_count = db.query(User).filter(User.role == UserRole.ADMIN).count()
|
||||
if admin_count <= 1:
|
||||
raise InvalidRequestException("不能删除最后一个管理员账户")
|
||||
|
||||
success = UserService.delete_user(db, self.user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="delete_user",
|
||||
target_user_id=user.id,
|
||||
target_email=user.email,
|
||||
target_role=user.role.value,
|
||||
)
|
||||
|
||||
return {"message": "用户删除成功"}
|
||||
|
||||
|
||||
class AdminResetUserQuotaAdapter(AdminApiAdapter):
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = UserService.get_user(db, self.user_id)
|
||||
if not user:
|
||||
raise NotFoundException("用户不存在", "user")
|
||||
|
||||
user.used_usd = 0.0
|
||||
user.total_usd = getattr(user, "total_usd", 0)
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="reset_user_quota",
|
||||
target_user_id=user.id,
|
||||
quota_usd=user.quota_usd,
|
||||
used_usd=user.used_usd,
|
||||
total_usd=user.total_usd,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "配额已重置",
|
||||
"user_id": user.id,
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
"total_usd": user.total_usd,
|
||||
}
|
||||
|
||||
|
||||
class AdminGetUserKeysAdapter(AdminApiAdapter):
|
||||
"""获取用户的API Keys"""
|
||||
|
||||
def __init__(self, user_id: str, is_active: Optional[bool]):
|
||||
self.user_id = user_id
|
||||
self.is_active = is_active
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 验证用户存在
|
||||
user = db.query(User).filter(User.id == self.user_id).first()
|
||||
if not user:
|
||||
raise NotFoundException("用户不存在", "user")
|
||||
|
||||
# 获取用户的Keys(不包括独立Keys)
|
||||
api_keys = ApiKeyService.list_user_api_keys(
|
||||
db=db, user_id=self.user_id, is_active=self.is_active
|
||||
)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="list_user_api_keys",
|
||||
target_user_id=self.user_id,
|
||||
total=len(api_keys),
|
||||
)
|
||||
|
||||
return {
|
||||
"api_keys": [
|
||||
{
|
||||
"id": key.id,
|
||||
"name": key.name,
|
||||
"key_display": key.get_display_key(),
|
||||
"is_active": key.is_active,
|
||||
"total_requests": key.total_requests,
|
||||
"total_cost_usd": float(key.total_cost_usd or 0),
|
||||
"rate_limit": key.rate_limit,
|
||||
"expires_at": key.expires_at.isoformat() if key.expires_at else None,
|
||||
"last_used_at": key.last_used_at.isoformat() if key.last_used_at else None,
|
||||
"created_at": key.created_at.isoformat(),
|
||||
}
|
||||
for key in api_keys
|
||||
],
|
||||
"total": len(api_keys),
|
||||
"user_email": user.email,
|
||||
"username": user.username,
|
||||
}
|
||||
|
||||
|
||||
class AdminCreateUserKeyAdapter(AdminApiAdapter):
|
||||
"""为用户创建API Key"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
key_data = CreateApiKeyRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
# 验证用户存在
|
||||
user = db.query(User).filter(User.id == self.user_id).first()
|
||||
if not user:
|
||||
raise NotFoundException("用户不存在", "user")
|
||||
|
||||
# 为用户创建Key(不是独立Key)
|
||||
api_key, plain_key = ApiKeyService.create_api_key(
|
||||
db=db,
|
||||
user_id=self.user_id,
|
||||
name=key_data.name,
|
||||
allowed_providers=key_data.allowed_providers,
|
||||
allowed_models=key_data.allowed_models,
|
||||
rate_limit=key_data.rate_limit or 100,
|
||||
expire_days=key_data.expire_days,
|
||||
initial_balance_usd=None, # 普通Key不设置余额限制
|
||||
is_standalone=False, # 不是独立Key
|
||||
)
|
||||
|
||||
logger.info(f"管理员为用户创建API Key: 用户 {user.email}, Key ID {api_key.id}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="create_user_api_key",
|
||||
target_user_id=self.user_id,
|
||||
key_id=api_key.id,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"key": plain_key, # 只在创建时返回
|
||||
"name": api_key.name,
|
||||
"key_display": api_key.get_display_key(),
|
||||
"rate_limit": api_key.rate_limit,
|
||||
"expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
|
||||
"created_at": api_key.created_at.isoformat(),
|
||||
"message": "API Key创建成功,请妥善保存完整密钥",
|
||||
}
|
||||
|
||||
|
||||
class AdminDeleteUserKeyAdapter(AdminApiAdapter):
|
||||
"""删除用户的API Key"""
|
||||
|
||||
def __init__(self, user_id: str, key_id: str):
|
||||
self.user_id = user_id
|
||||
self.key_id = key_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 验证Key存在且属于该用户
|
||||
api_key = (
|
||||
db.query(ApiKey)
|
||||
.filter(
|
||||
ApiKey.id == self.key_id,
|
||||
ApiKey.user_id == self.user_id,
|
||||
ApiKey.is_standalone == False, # 只能删除普通Key
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundException("API Key不存在或不属于该用户", "api_key")
|
||||
|
||||
db.delete(api_key)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"管理员删除用户API Key: 用户ID {self.user_id}, Key ID {self.key_id}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="delete_user_api_key",
|
||||
target_user_id=self.user_id,
|
||||
key_id=self.key_id,
|
||||
)
|
||||
|
||||
return {"message": "API Key已删除"}
|
||||
10
src/api/announcements/__init__.py
Normal file
10
src/api/announcements/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
'"""Announcement system routers."""'
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import router as announcement_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(announcement_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
297
src/api/announcements/routes.py
Normal file
297
src/api/announcements/routes.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""公告系统 API 端点。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import CreateAnnouncementRequest, UpdateAnnouncementRequest
|
||||
from src.models.database import User
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.system.announcement import AnnouncementService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/announcements", tags=["Announcements"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ============== 公共端点(所有用户可访问) ==============
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_announcements(
|
||||
request: Request,
|
||||
active_only: bool = Query(True, description="只返回有效公告"),
|
||||
limit: int = Query(50, description="返回数量限制"),
|
||||
offset: int = Query(0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取公告列表(包含已读状态)"""
|
||||
adapter = ListAnnouncementsAdapter(active_only=active_only, limit=limit, offset=offset)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/active")
|
||||
async def get_active_announcements(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取当前有效的公告(首页展示)"""
|
||||
adapter = GetActiveAnnouncementsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{announcement_id}")
|
||||
async def get_announcement(
|
||||
announcement_id: str, # UUID
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取单个公告详情"""
|
||||
adapter = GetAnnouncementAdapter(announcement_id=announcement_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{announcement_id}/read-status")
|
||||
async def mark_announcement_as_read(
|
||||
announcement_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Mark announcement as read"""
|
||||
adapter = MarkAnnouncementReadAdapter(announcement_id=announcement_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ============== 管理员端点 ==============
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_announcement(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建公告(管理员权限)"""
|
||||
adapter = CreateAnnouncementAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{announcement_id}")
|
||||
async def update_announcement(
|
||||
announcement_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新公告(管理员权限)"""
|
||||
adapter = UpdateAnnouncementAdapter(announcement_id=announcement_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{announcement_id}")
|
||||
async def delete_announcement(
|
||||
announcement_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除公告(管理员权限)"""
|
||||
adapter = DeleteAnnouncementAdapter(announcement_id=announcement_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ============== 用户公告端点 ==============
|
||||
|
||||
|
||||
@router.get("/users/me/unread-count")
|
||||
async def get_my_unread_announcement_count(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取我的未读公告数量"""
|
||||
adapter = UnreadAnnouncementCountAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ============== Pipeline 适配器 ==============
|
||||
|
||||
|
||||
class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
||||
"""允许匿名访问,但可选解析Bearer以获取用户上下文。"""
|
||||
|
||||
mode = ApiMode.PUBLIC
|
||||
|
||||
async def authorize(self, context): # type: ignore[override]
|
||||
context.extra["optional_user"] = await self._resolve_optional_user(context)
|
||||
return None
|
||||
|
||||
async def _resolve_optional_user(self, context) -> Optional[User]:
|
||||
if context.user:
|
||||
return context.user
|
||||
|
||||
authorization = context.request.headers.get("authorization")
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await AuthService.verify_token(token)
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
user = (
|
||||
context.db.query(User).filter(User.id == user_id, User.is_active.is_(True)).first()
|
||||
)
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_optional_user(self, context) -> Optional[User]:
|
||||
return context.extra.get("optional_user")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListAnnouncementsAdapter(AnnouncementOptionalAuthAdapter):
|
||||
active_only: bool
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
optional_user = self.get_optional_user(context)
|
||||
return AnnouncementService.get_announcements(
|
||||
db=context.db,
|
||||
user_id=optional_user.id if optional_user else None,
|
||||
active_only=self.active_only,
|
||||
include_read_status=True if optional_user else False,
|
||||
limit=self.limit,
|
||||
offset=self.offset,
|
||||
)
|
||||
|
||||
|
||||
class GetActiveAnnouncementsAdapter(AnnouncementOptionalAuthAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
optional_user = self.get_optional_user(context)
|
||||
return AnnouncementService.get_active_announcements(
|
||||
db=context.db,
|
||||
user_id=optional_user.id if optional_user else None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetAnnouncementAdapter(AnnouncementOptionalAuthAdapter):
|
||||
announcement_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
announcement = AnnouncementService.get_announcement(context.db, self.announcement_id)
|
||||
return {
|
||||
"id": announcement.id,
|
||||
"title": announcement.title,
|
||||
"content": announcement.content,
|
||||
"type": announcement.type,
|
||||
"priority": announcement.priority,
|
||||
"is_pinned": announcement.is_pinned,
|
||||
"author": {"id": announcement.author.id, "username": announcement.author.username},
|
||||
"start_time": announcement.start_time,
|
||||
"end_time": announcement.end_time,
|
||||
"created_at": announcement.created_at,
|
||||
"updated_at": announcement.updated_at,
|
||||
}
|
||||
|
||||
|
||||
class AnnouncementUserAdapter(AuthenticatedApiAdapter):
|
||||
"""需要登录但不要求管理员的公告适配器基类。"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MarkAnnouncementReadAdapter(AnnouncementUserAdapter):
|
||||
def __init__(self, announcement_id: str):
|
||||
self.announcement_id = announcement_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
AnnouncementService.mark_as_read(context.db, self.announcement_id, context.user.id)
|
||||
return {"message": "公告已标记为已读"}
|
||||
|
||||
|
||||
class UnreadAnnouncementCountAdapter(AnnouncementUserAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
result = AnnouncementService.get_announcements(
|
||||
db=context.db,
|
||||
user_id=context.user.id,
|
||||
active_only=True,
|
||||
include_read_status=True,
|
||||
limit=1,
|
||||
offset=0,
|
||||
)
|
||||
return {"unread_count": result.get("unread_count", 0)}
|
||||
|
||||
|
||||
class CreateAnnouncementAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
req = CreateAnnouncementRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
announcement = AnnouncementService.create_announcement(
|
||||
db=context.db,
|
||||
author_id=context.user.id,
|
||||
title=req.title,
|
||||
content=req.content,
|
||||
type=req.type,
|
||||
priority=req.priority,
|
||||
is_pinned=req.is_pinned,
|
||||
start_time=req.start_time,
|
||||
end_time=req.end_time,
|
||||
)
|
||||
return {"id": announcement.id, "title": announcement.title, "message": "公告创建成功"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateAnnouncementAdapter(AdminApiAdapter):
|
||||
announcement_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
req = UpdateAnnouncementRequest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
AnnouncementService.update_announcement(
|
||||
db=context.db,
|
||||
announcement_id=self.announcement_id,
|
||||
user_id=context.user.id,
|
||||
title=req.title,
|
||||
content=req.content,
|
||||
type=req.type,
|
||||
priority=req.priority,
|
||||
is_active=req.is_active,
|
||||
is_pinned=req.is_pinned,
|
||||
start_time=req.start_time,
|
||||
end_time=req.end_time,
|
||||
)
|
||||
return {"message": "公告更新成功"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteAnnouncementAdapter(AdminApiAdapter):
|
||||
announcement_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
AnnouncementService.delete_announcement(context.db, self.announcement_id, context.user.id)
|
||||
return {"message": "公告已删除"}
|
||||
10
src/api/auth/__init__.py
Normal file
10
src/api/auth/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Authentication route group."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import router as auth_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(auth_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
353
src/api/auth/routes.py
Normal file
353
src/api/auth/routes.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
认证相关API端点
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
LoginRequest,
|
||||
LoginResponse,
|
||||
LogoutResponse,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
RegisterRequest,
|
||||
RegisterResponse,
|
||||
)
|
||||
from src.models.database import AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.user.service import UserService
|
||||
from src.utils.request_utils import get_client_ip, get_user_agent
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Authentication"])
|
||||
security = HTTPBearer()
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# API端点
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AuthLoginAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
async def refresh_token(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AuthRefreshAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/register", response_model=RegisterResponse)
|
||||
async def register(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AuthRegisterAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_current_user_info(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AuthCurrentUserAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/password")
|
||||
async def change_password(request: Request, db: Session = Depends(get_db)):
|
||||
"""Change current user's password"""
|
||||
adapter = AuthChangePasswordAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=LogoutResponse)
|
||||
async def logout(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AuthLogoutAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ============== 适配器实现 ==============
|
||||
|
||||
|
||||
class AuthPublicAdapter(ApiAdapter):
|
||||
mode = ApiMode.PUBLIC
|
||||
|
||||
def authorize(self, context): # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
class AuthLoginAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
login_request = LoginRequest.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
errors = []
|
||||
for error in exc.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
errors.append(f"{field}: {error['msg']}")
|
||||
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
|
||||
|
||||
client_ip = get_client_ip(context.request)
|
||||
user_agent = get_user_agent(context.request)
|
||||
|
||||
# IP 速率限制检查(登录接口:5次/分钟)
|
||||
allowed, remaining, reset_after = await IPRateLimiter.check_limit(client_ip, "login")
|
||||
if not allowed:
|
||||
logger.warning(f"登录请求超过速率限制: IP={client_ip}, 剩余={remaining}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"登录请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
user = await AuthService.authenticate_user(db, login_request.email, login_request.password)
|
||||
if not user:
|
||||
AuditService.log_login_attempt(
|
||||
db=db,
|
||||
email=login_request.email,
|
||||
success=False,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
error_reason="邮箱或密码错误",
|
||||
)
|
||||
db.commit()
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="邮箱或密码错误")
|
||||
|
||||
AuditService.log_login_attempt(
|
||||
db=db,
|
||||
email=login_request.email,
|
||||
success=True,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
user_id=user.id,
|
||||
)
|
||||
db.commit()
|
||||
|
||||
access_token = AuthService.create_access_token(
|
||||
data={
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"role": user.role.value,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
}
|
||||
)
|
||||
refresh_token = AuthService.create_refresh_token(
|
||||
data={"user_id": user.id, "email": user.email}
|
||||
)
|
||||
response = LoginResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=86400,
|
||||
user_id=user.id,
|
||||
email=user.email,
|
||||
username=user.username,
|
||||
role=user.role.value,
|
||||
)
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
class AuthRefreshAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
refresh_request = RefreshTokenRequest.model_validate(payload)
|
||||
client_ip = get_client_ip(context.request)
|
||||
user_agent = get_user_agent(context.request)
|
||||
|
||||
try:
|
||||
token_payload = await AuthService.verify_token(
|
||||
refresh_request.refresh_token, token_type="refresh"
|
||||
)
|
||||
user_id = token_payload.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的刷新令牌"
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的刷新令牌"
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户已禁用")
|
||||
|
||||
new_access_token = AuthService.create_access_token(
|
||||
data={
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"role": user.role.value,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
}
|
||||
)
|
||||
new_refresh_token = AuthService.create_refresh_token(
|
||||
data={"user_id": user.id, "email": user.email}
|
||||
)
|
||||
logger.info(f"令牌刷新成功: {user.email}")
|
||||
return RefreshTokenResponse(
|
||||
access_token=new_access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=86400,
|
||||
).model_dump()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"刷新令牌失败: {exc}")
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="刷新令牌失败")
|
||||
|
||||
|
||||
class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from ..models.database import SystemConfig
|
||||
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
register_request = RegisterRequest.model_validate(payload)
|
||||
client_ip = get_client_ip(context.request)
|
||||
user_agent = get_user_agent(context.request)
|
||||
|
||||
# IP 速率限制检查(注册接口:3次/分钟)
|
||||
allowed, remaining, reset_after = await IPRateLimiter.check_limit(client_ip, "register")
|
||||
if not allowed:
|
||||
logger.warning(f"注册请求超过速率限制: IP={client_ip}, 剩余={remaining}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first()
|
||||
if allow_registration and not allow_registration.value:
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.UNAUTHORIZED_ACCESS,
|
||||
description=f"Registration attempt rejected - registration disabled: {register_request.email}",
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
metadata={"email": register_request.email, "reason": "registration_disabled"},
|
||||
)
|
||||
db.commit()
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="系统暂不开放注册")
|
||||
|
||||
try:
|
||||
user = UserService.create_user(
|
||||
db=db,
|
||||
email=register_request.email,
|
||||
username=register_request.username,
|
||||
password=register_request.password,
|
||||
role=UserRole.USER,
|
||||
)
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.USER_CREATED,
|
||||
description=f"User registered: {user.email}",
|
||||
user_id=user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
metadata={"email": user.email, "username": user.username, "role": user.role.value},
|
||||
)
|
||||
db.commit()
|
||||
return RegisterResponse(
|
||||
user_id=user.id,
|
||||
email=user.email,
|
||||
username=user.username,
|
||||
message="注册成功",
|
||||
).model_dump()
|
||||
except ValueError as exc:
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.UNAUTHORIZED_ACCESS,
|
||||
description=f"Registration failed: {register_request.email} - {exc}",
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
metadata={"email": register_request.email, "error": str(exc)},
|
||||
)
|
||||
db.commit()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
|
||||
|
||||
|
||||
class AuthCurrentUserAdapter(AuthenticatedApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
user = context.user
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"is_active": user.is_active,
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
"total_usd": user.total_usd,
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_models": user.allowed_models,
|
||||
"created_at": user.created_at.isoformat(),
|
||||
"last_login_at": user.last_login_at.isoformat() if user.last_login_at else None,
|
||||
}
|
||||
|
||||
|
||||
class AuthChangePasswordAdapter(AuthenticatedApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
payload = context.ensure_json_body()
|
||||
old_password = payload.get("old_password")
|
||||
new_password = payload.get("new_password")
|
||||
if not old_password or not new_password:
|
||||
raise HTTPException(status_code=400, detail="必须提供旧密码和新密码")
|
||||
user = context.user
|
||||
if not user.verify_password(old_password):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="旧密码错误")
|
||||
if len(new_password) < 8:
|
||||
raise InvalidRequestException("密码长度至少8位")
|
||||
user.set_password(new_password)
|
||||
context.db.commit()
|
||||
logger.info(f"用户修改密码: {user.email}")
|
||||
return {"message": "密码修改成功"}
|
||||
|
||||
|
||||
class AuthLogoutAdapter(AuthenticatedApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""用户登出,将 Token 加入黑名单"""
|
||||
user = context.user
|
||||
client_ip = get_client_ip(context.request)
|
||||
|
||||
# 从 Authorization header 获取 Token
|
||||
auth_header = context.request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="缺少认证令牌")
|
||||
|
||||
token = auth_header.replace("Bearer ", "")
|
||||
|
||||
# 将 Token 加入黑名单
|
||||
success = await AuthService.logout(token)
|
||||
|
||||
if success:
|
||||
# 记录审计日志
|
||||
AuditService.log_event(
|
||||
db=context.db,
|
||||
event_type=AuditEventType.LOGOUT,
|
||||
description=f"User logged out: {user.email}",
|
||||
user_id=user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=get_user_agent(context.request),
|
||||
metadata={"user_id": user.id, "email": user.email},
|
||||
)
|
||||
context.db.commit()
|
||||
|
||||
logger.info(f"用户登出成功: {user.email}")
|
||||
|
||||
return LogoutResponse(message="登出成功", success=True).model_dump()
|
||||
else:
|
||||
logger.warning(f"用户登出失败(Redis不可用): {user.email}")
|
||||
return LogoutResponse(message="登出成功(降级模式)", success=False).model_dump()
|
||||
82
src/api/base/adapter.py
Normal file
82
src/api/base/adapter.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from .context import ApiRequestContext
|
||||
|
||||
|
||||
class ApiMode(str, Enum):
|
||||
STANDARD = "standard"
|
||||
PROXY = "proxy"
|
||||
ADMIN = "admin"
|
||||
USER = "user" # JWT 认证的普通用户(不要求管理员权限)
|
||||
PUBLIC = "public"
|
||||
|
||||
|
||||
class ApiAdapter(ABC):
|
||||
"""所有API格式适配器的抽象基类。"""
|
||||
|
||||
name: str = "base"
|
||||
mode: ApiMode = ApiMode.STANDARD
|
||||
api_format: Optional[str] = None # 对应 Provider API 格式提示
|
||||
audit_log_enabled: bool = True
|
||||
audit_success_event = None
|
||||
audit_failure_event = None
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, context: ApiRequestContext) -> Response:
|
||||
"""处理请求并返回 FastAPI Response。"""
|
||||
|
||||
def authorize(self, context: ApiRequestContext) -> None:
|
||||
"""可选的授权钩子,默认允许通过。"""
|
||||
return None
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
从请求中提取客户端 API 密钥。
|
||||
|
||||
子类应覆盖此方法以支持各自的认证头格式。
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
|
||||
Returns:
|
||||
提取的 API 密钥,如果未找到则返回 None
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_audit_metadata(
|
||||
self,
|
||||
context: ApiRequestContext,
|
||||
*,
|
||||
success: bool,
|
||||
status_code: Optional[int],
|
||||
error: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""允许适配器在审计日志中追加自定义字段。"""
|
||||
return {}
|
||||
|
||||
def detect_capability_requirements(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
检测请求中隐含的能力需求(子类可覆盖)
|
||||
|
||||
不同 API 格式有不同的能力声明方式,例如:
|
||||
- Claude: anthropic-beta: context-1m-xxx 表示需要 1M 上下文
|
||||
- 其他格式可能有不同的请求头或请求体字段
|
||||
|
||||
Args:
|
||||
headers: 请求头字典
|
||||
request_body: 请求体字典(可选)
|
||||
|
||||
Returns:
|
||||
检测到的能力需求,如 {"context_1m": True}
|
||||
"""
|
||||
return {}
|
||||
29
src/api/base/admin_adapter.py
Normal file
29
src/api/base/admin_adapter.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.models.database import UserRole
|
||||
|
||||
from .adapter import ApiAdapter, ApiMode
|
||||
from .context import ApiRequestContext
|
||||
|
||||
|
||||
class AdminApiAdapter(ApiAdapter):
|
||||
"""管理员端点适配器基类,提供统一的权限校验。"""
|
||||
|
||||
mode = ApiMode.ADMIN
|
||||
required_roles: tuple[UserRole, ...] = (UserRole.ADMIN,)
|
||||
|
||||
def authorize(self, context: ApiRequestContext) -> None:
|
||||
user = context.user
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 检查是否使用独立余额Key访问管理接口
|
||||
if context.api_key and context.api_key.is_standalone:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="独立余额Key不允许访问管理接口,仅可用于代理请求"
|
||||
)
|
||||
|
||||
if not any(user.role == role for role in self.required_roles):
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
13
src/api/base/authenticated_adapter.py
Normal file
13
src/api/base/authenticated_adapter.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .adapter import ApiAdapter, ApiMode
|
||||
|
||||
|
||||
class AuthenticatedApiAdapter(ApiAdapter):
|
||||
"""通用需要登录的适配器基类。"""
|
||||
|
||||
mode = ApiMode.USER
|
||||
|
||||
def authorize(self, context): # type: ignore[override]
|
||||
if not context.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
116
src/api/base/context.py
Normal file
116
src/api/base/context.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, User
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiRequestContext:
|
||||
"""统一的API请求上下文,贯穿Pipeline与格式适配器。"""
|
||||
|
||||
request: Request
|
||||
db: Session
|
||||
user: Optional[User]
|
||||
api_key: Optional[ApiKey]
|
||||
request_id: str
|
||||
start_time: float
|
||||
client_ip: str
|
||||
user_agent: str
|
||||
original_headers: Dict[str, str]
|
||||
query_params: Dict[str, str]
|
||||
raw_body: bytes | None = None
|
||||
json_body: Optional[Dict[str, Any]] = None
|
||||
quota_remaining: Optional[float] = None
|
||||
mode: str = "standard" # standard / proxy
|
||||
api_format_hint: Optional[str] = None
|
||||
|
||||
# URL 路径参数(如 Gemini API 的 /v1beta/models/{model}:generateContent)
|
||||
path_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 供适配器扩展的状态存储
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
audit_metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def ensure_json_body(self) -> Dict[str, Any]:
|
||||
"""确保请求体已解析为JSON并返回。"""
|
||||
if self.json_body is not None:
|
||||
return self.json_body
|
||||
|
||||
if not self.raw_body:
|
||||
raise HTTPException(status_code=400, detail="请求体不能为空")
|
||||
|
||||
try:
|
||||
self.json_body = json.loads(self.raw_body.decode("utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.warning(f"解析JSON失败: {exc}")
|
||||
raise HTTPException(status_code=400, detail="请求体必须是合法的JSON") from exc
|
||||
|
||||
return self.json_body
|
||||
|
||||
def add_audit_metadata(self, **values: Any) -> None:
|
||||
"""向审计日志附加字段(会自动过滤 None)。"""
|
||||
for key, value in values.items():
|
||||
if value is not None:
|
||||
self.audit_metadata[key] = value
|
||||
|
||||
def extend_audit_metadata(self, data: Dict[str, Any]) -> None:
|
||||
"""批量附加审计字段。"""
|
||||
for key, value in data.items():
|
||||
if value is not None:
|
||||
self.audit_metadata[key] = value
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
request: Request,
|
||||
db: Session,
|
||||
user: Optional[User],
|
||||
api_key: Optional[ApiKey],
|
||||
raw_body: Optional[bytes] = None,
|
||||
mode: str = "standard",
|
||||
api_format_hint: Optional[str] = None,
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> "ApiRequestContext":
|
||||
"""创建上下文实例并提前读取必要的元数据。"""
|
||||
request_id = getattr(request.state, "request_id", None) or str(uuid.uuid4())[:8]
|
||||
setattr(request.state, "request_id", request_id)
|
||||
|
||||
start_time = time.time()
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
context = cls(
|
||||
request=request,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
start_time=start_time,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
original_headers=dict(request.headers),
|
||||
query_params=dict(request.query_params),
|
||||
raw_body=raw_body,
|
||||
mode=mode,
|
||||
api_format_hint=api_format_hint,
|
||||
path_params=path_params or {},
|
||||
)
|
||||
|
||||
# 便于插件/日志引用
|
||||
request.state.request_id = request_id
|
||||
if user:
|
||||
request.state.user_id = user.id
|
||||
if api_key:
|
||||
request.state.api_key_id = api_key.id
|
||||
|
||||
return context
|
||||
49
src/api/base/pagination.py
Normal file
49
src/api/base/pagination.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List, Sequence, Tuple, TypeVar
|
||||
|
||||
from sqlalchemy.orm import Query
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaginationMeta:
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
count: int
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def paginate_query(query: Query, limit: int, offset: int) -> Tuple[int, List[T]]:
|
||||
"""
|
||||
对 SQLAlchemy 查询应用 limit/offset,并返回总数与结果列表。
|
||||
"""
|
||||
total = query.order_by(None).count()
|
||||
records = query.offset(offset).limit(limit).all()
|
||||
return total, records
|
||||
|
||||
|
||||
def paginate_sequence(
|
||||
items: Sequence[T], limit: int, offset: int
|
||||
) -> Tuple[List[T], PaginationMeta]:
|
||||
"""
|
||||
对内存序列应用分页,返回切片和元数据。
|
||||
"""
|
||||
total = len(items)
|
||||
sliced = list(items[offset : offset + limit])
|
||||
meta = PaginationMeta(total=total, limit=limit, offset=offset, count=len(sliced))
|
||||
return sliced, meta
|
||||
|
||||
|
||||
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict:
|
||||
"""
|
||||
构建标准分页响应 payload。
|
||||
"""
|
||||
payload = {"items": items, "meta": meta.to_dict()}
|
||||
payload.update(extra)
|
||||
return payload
|
||||
387
src/api/base/pipeline.py
Normal file
387
src/api/base/pipeline.py
Normal file
@@ -0,0 +1,387 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
from .adapter import ApiAdapter, ApiMode
|
||||
from .context import ApiRequestContext
|
||||
|
||||
|
||||
|
||||
class ApiRequestPipeline:
|
||||
"""负责统一执行认证、配额校验、上下文构建等通用逻辑的管道。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auth_service: AuthService = AuthService,
|
||||
usage_service: UsageService = UsageService,
|
||||
audit_service: AuditService = AuditService,
|
||||
):
|
||||
self.auth_service = auth_service
|
||||
self.usage_service = usage_service
|
||||
self.audit_service = audit_service
|
||||
|
||||
async def run(
|
||||
self,
|
||||
adapter: ApiAdapter,
|
||||
http_request: Request,
|
||||
db: Session,
|
||||
*,
|
||||
mode: ApiMode = ApiMode.STANDARD,
|
||||
api_format_hint: Optional[str] = None,
|
||||
path_params: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
logger.debug(f"[Pipeline] START | path={http_request.url.path}")
|
||||
logger.debug(f"[Pipeline] Running with mode={mode}, adapter={adapter.__class__.__name__}, "
|
||||
f"adapter.mode={adapter.mode}, path={http_request.url.path}")
|
||||
if mode == ApiMode.ADMIN:
|
||||
user = await self._authenticate_admin(http_request, db)
|
||||
api_key = None
|
||||
elif mode == ApiMode.USER:
|
||||
user = await self._authenticate_user(http_request, db)
|
||||
api_key = None
|
||||
elif mode == ApiMode.PUBLIC:
|
||||
user = None
|
||||
api_key = None
|
||||
else:
|
||||
logger.debug("[Pipeline] 调用 _authenticate_client")
|
||||
user, api_key = self._authenticate_client(http_request, db, adapter)
|
||||
logger.debug(f"[Pipeline] 认证完成 | user={user.username if user else None}")
|
||||
|
||||
raw_body = None
|
||||
if http_request.method in {"POST", "PUT", "PATCH"}:
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
# 添加30秒超时防止卡死
|
||||
raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
|
||||
logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("读取请求体超时(30s),可能客户端未发送完整请求体")
|
||||
raise HTTPException(
|
||||
status_code=408, detail="Request timeout: body not received within 30 seconds"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")
|
||||
|
||||
context = ApiRequestContext.build(
|
||||
request=http_request,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
raw_body=raw_body,
|
||||
mode=mode.value,
|
||||
api_format_hint=api_format_hint,
|
||||
path_params=path_params,
|
||||
)
|
||||
logger.debug(f"[Pipeline] Context构建完成 | adapter={adapter.name} | request_id={context.request_id}")
|
||||
|
||||
if mode != ApiMode.ADMIN and user:
|
||||
context.quota_remaining = self._calculate_quota_remaining(user)
|
||||
|
||||
logger.debug(f"[Pipeline] Adapter={adapter.name} | RequestID={context.request_id}")
|
||||
|
||||
logger.debug(f"[Pipeline] Calling authorize on {adapter.__class__.__name__}, user={context.user}")
|
||||
# authorize 可能是异步的,需要检查并 await
|
||||
authorize_result = adapter.authorize(context)
|
||||
if hasattr(authorize_result, "__await__"):
|
||||
await authorize_result
|
||||
|
||||
try:
|
||||
response = await adapter.handle(context)
|
||||
status_code = getattr(response, "status_code", None)
|
||||
self._record_audit_event(context, adapter, success=True, status_code=status_code)
|
||||
return response
|
||||
except HTTPException as exc:
|
||||
err_detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||
self._record_audit_event(
|
||||
context,
|
||||
adapter,
|
||||
success=False,
|
||||
status_code=exc.status_code,
|
||||
error=err_detail,
|
||||
)
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._record_audit_event(
|
||||
context,
|
||||
adapter,
|
||||
success=False,
|
||||
status_code=500,
|
||||
error=str(exc),
|
||||
)
|
||||
raise
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Internal helpers
|
||||
# --------------------------------------------------------------------- #
|
||||
|
||||
def _authenticate_client(
|
||||
self, request: Request, db: Session, adapter: ApiAdapter
|
||||
) -> Tuple[User, ApiKey]:
|
||||
logger.debug("[Pipeline._authenticate_client] 开始")
|
||||
# 使用 adapter 的 extract_api_key 方法,支持不同 API 格式的认证头
|
||||
client_api_key = adapter.extract_api_key(request)
|
||||
logger.debug(f"[Pipeline._authenticate_client] 提取API密钥完成 | key_prefix={client_api_key[:8] if client_api_key else None}...")
|
||||
if not client_api_key:
|
||||
raise HTTPException(status_code=401, detail="请提供API密钥")
|
||||
|
||||
logger.debug("[Pipeline._authenticate_client] 调用 auth_service.authenticate_api_key")
|
||||
auth_result = self.auth_service.authenticate_api_key(db, client_api_key)
|
||||
logger.debug(f"[Pipeline._authenticate_client] 认证结果 | result={bool(auth_result)}")
|
||||
if not auth_result:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
user, api_key = auth_result
|
||||
if not user or not api_key:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
request.state.user_id = user.id
|
||||
request.state.api_key_id = api_key.id
|
||||
|
||||
# 检查配额或余额(支持独立Key)
|
||||
quota_ok, message = self.usage_service.check_user_quota(db, user, api_key=api_key)
|
||||
if not quota_ok:
|
||||
# 根据Key类型计算剩余额度
|
||||
if api_key.is_standalone:
|
||||
# 独立Key:显示剩余余额
|
||||
remaining = (
|
||||
None
|
||||
if api_key.current_balance_usd is None
|
||||
else float(api_key.current_balance_usd - (api_key.balance_used_usd or 0))
|
||||
)
|
||||
else:
|
||||
# 普通Key:显示用户配额剩余
|
||||
remaining = (
|
||||
None
|
||||
if user.quota_usd is None or user.quota_usd < 0
|
||||
else float(user.quota_usd - user.used_usd)
|
||||
)
|
||||
raise QuotaExceededException(quota_type="USD", remaining=remaining)
|
||||
|
||||
return user, api_key
|
||||
|
||||
async def _authenticate_admin(self, request: Request, db: Session) -> User:
|
||||
authorization = request.headers.get("authorization")
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Admin token 验证失败: {exc}")
|
||||
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
request.state.user_id = user.id
|
||||
return user
|
||||
|
||||
async def _authenticate_user(self, request: Request, db: Session) -> User:
|
||||
"""JWT 认证普通用户(不要求管理员权限)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"User token 验证失败: {exc}")
|
||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
request.state.user_id = user.id
|
||||
return user
|
||||
|
||||
def _calculate_quota_remaining(self, user: Optional[User]) -> Optional[float]:
|
||||
if not user:
|
||||
return None
|
||||
if user.quota_usd is None or user.quota_usd < 0:
|
||||
return None
|
||||
return max(float(user.quota_usd - user.used_usd), 0.0)
|
||||
|
||||
def _record_audit_event(
|
||||
self,
|
||||
context: ApiRequestContext,
|
||||
adapter: ApiAdapter,
|
||||
*,
|
||||
success: bool,
|
||||
status_code: Optional[int] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
if not getattr(adapter, "audit_log_enabled", True):
|
||||
return
|
||||
|
||||
bind = context.db.get_bind()
|
||||
if bind is None:
|
||||
return
|
||||
|
||||
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
|
||||
if not event_type:
|
||||
if not success and status_code == 401:
|
||||
event_type = AuditEventType.UNAUTHORIZED_ACCESS
|
||||
else:
|
||||
event_type = (
|
||||
AuditEventType.REQUEST_SUCCESS if success else AuditEventType.REQUEST_FAILED
|
||||
)
|
||||
|
||||
metadata = self._build_audit_metadata(
|
||||
context=context,
|
||||
adapter=adapter,
|
||||
success=success,
|
||||
status_code=status_code,
|
||||
error=error,
|
||||
)
|
||||
|
||||
SessionMaker = sessionmaker(bind=bind)
|
||||
audit_session = SessionMaker()
|
||||
try:
|
||||
self.audit_service.log_event(
|
||||
db=audit_session,
|
||||
event_type=event_type,
|
||||
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
|
||||
user_id=context.user.id if context.user else None,
|
||||
api_key_id=context.api_key.id if context.api_key else None,
|
||||
ip_address=context.client_ip,
|
||||
user_agent=context.user_agent,
|
||||
request_id=context.request_id,
|
||||
status_code=status_code,
|
||||
error_message=error,
|
||||
metadata=metadata,
|
||||
)
|
||||
audit_session.commit()
|
||||
except Exception as exc:
|
||||
audit_session.rollback()
|
||||
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
|
||||
finally:
|
||||
audit_session.close()
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
context: ApiRequestContext,
|
||||
adapter: ApiAdapter,
|
||||
*,
|
||||
success: bool,
|
||||
status_code: Optional[int],
|
||||
error: Optional[str],
|
||||
) -> dict:
|
||||
duration_ms = max((time.time() - context.start_time) * 1000, 0.0)
|
||||
request = context.request
|
||||
path_params = {}
|
||||
try:
|
||||
path_params = dict(getattr(request, "path_params", {}) or {})
|
||||
except Exception:
|
||||
path_params = {}
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"path": request.url.path,
|
||||
"path_params": path_params,
|
||||
"method": request.method,
|
||||
"adapter": adapter.name,
|
||||
"adapter_class": adapter.__class__.__name__,
|
||||
"adapter_mode": getattr(adapter.mode, "value", str(adapter.mode)),
|
||||
"mode": context.mode,
|
||||
"api_format_hint": context.api_format_hint,
|
||||
"query": context.query_params,
|
||||
"duration_ms": round(duration_ms, 2),
|
||||
"request_body_bytes": len(context.raw_body or b""),
|
||||
"has_body": bool(context.raw_body),
|
||||
"request_content_type": request.headers.get("content-type"),
|
||||
"quota_remaining": context.quota_remaining,
|
||||
"success": success,
|
||||
}
|
||||
if status_code is not None:
|
||||
metadata["status_code"] = status_code
|
||||
|
||||
if context.user and getattr(context.user, "role", None):
|
||||
role = context.user.role
|
||||
metadata["user_role"] = getattr(role, "value", role)
|
||||
|
||||
if context.api_key:
|
||||
if getattr(context.api_key, "name", None):
|
||||
metadata["api_key_name"] = context.api_key.name
|
||||
# 使用脱敏后的密钥显示
|
||||
if hasattr(context.api_key, "get_display_key"):
|
||||
metadata["api_key_display"] = context.api_key.get_display_key()
|
||||
|
||||
extra_details: dict[str, Any] = {}
|
||||
if context.audit_metadata:
|
||||
extra_details.update(context.audit_metadata)
|
||||
|
||||
try:
|
||||
adapter_details = adapter.get_audit_metadata(
|
||||
context,
|
||||
success=success,
|
||||
status_code=status_code,
|
||||
error=error,
|
||||
)
|
||||
if adapter_details:
|
||||
extra_details.update(adapter_details)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Audit] Adapter metadata failed: {adapter.__class__.__name__}: {exc}")
|
||||
|
||||
if extra_details:
|
||||
metadata["details"] = extra_details
|
||||
|
||||
if error:
|
||||
metadata["error"] = error
|
||||
|
||||
return self._sanitize_metadata(metadata)
|
||||
|
||||
def _sanitize_metadata(self, value: Any, depth: int = 0):
|
||||
if value is None:
|
||||
return None
|
||||
if depth > 5:
|
||||
return str(value)
|
||||
if isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
if isinstance(value, dict):
|
||||
sanitized = {}
|
||||
for key, val in value.items():
|
||||
cleaned = self._sanitize_metadata(val, depth + 1)
|
||||
if cleaned is not None:
|
||||
sanitized[str(key)] = cleaned
|
||||
return sanitized
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [self._sanitize_metadata(item, depth + 1) for item in value]
|
||||
if hasattr(value, "isoformat"):
|
||||
try:
|
||||
return value.isoformat()
|
||||
except Exception:
|
||||
return str(value)
|
||||
return str(value)
|
||||
10
src/api/dashboard/__init__.py
Normal file
10
src/api/dashboard/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
'"""Dashboard API routers."""'
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import router as dashboard_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(dashboard_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
905
src/api/dashboard/routes.py
Normal file
905
src/api/dashboard/routes.py
Normal file
@@ -0,0 +1,905 @@
|
||||
"""仪表盘统计 API 端点。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import UserRole
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, Provider, RequestCandidate, StatsDaily, Usage
|
||||
from src.models.database import User as DBUser
|
||||
from src.services.system.stats_aggregator import StatsAggregatorService
|
||||
from src.utils.cache_decorator import cache_result
|
||||
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
def format_tokens(num: int) -> str:
|
||||
"""格式化 Token 数量,自动转换 K/M 单位"""
|
||||
if num < 1000:
|
||||
return str(num)
|
||||
if num < 1000000:
|
||||
thousands = num / 1000
|
||||
if thousands >= 100:
|
||||
return f"{round(thousands)}K"
|
||||
elif thousands >= 10:
|
||||
return f"{thousands:.1f}K"
|
||||
else:
|
||||
return f"{thousands:.2f}K"
|
||||
millions = num / 1000000
|
||||
if millions >= 100:
|
||||
return f"{round(millions)}M"
|
||||
elif millions >= 10:
|
||||
return f"{millions:.1f}M"
|
||||
else:
|
||||
return f"{millions:.2f}M"
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_dashboard_stats(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = DashboardStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/recent-requests")
|
||||
async def get_recent_requests(
|
||||
request: Request,
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = DashboardRecentRequestsAdapter(limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# NOTE: /request-detail/{request_id} has been moved to /api/admin/usage/{id}
|
||||
# The old route is removed. Use dashboardApi.getRequestDetail() which now calls the new API.
|
||||
|
||||
|
||||
@router.get("/provider-status")
|
||||
async def get_provider_status(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = DashboardProviderStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/daily-stats")
|
||||
async def get_daily_stats(
|
||||
request: Request,
|
||||
days: int = Query(7, ge=1, le=30),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = DashboardDailyStatsAdapter(days=days)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class DashboardAdapter(ApiAdapter):
|
||||
"""需要登录的仪表盘适配器基类。"""
|
||||
|
||||
mode = ApiMode.ADMIN
|
||||
|
||||
def authorize(self, context): # type: ignore[override]
|
||||
if not context.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
|
||||
class DashboardStatsAdapter(DashboardAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
user = context.user
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
adapter = (
|
||||
AdminDashboardStatsAdapter()
|
||||
if user.role == UserRole.ADMIN
|
||||
else UserDashboardStatsAdapter()
|
||||
)
|
||||
return await adapter.handle(context)
|
||||
|
||||
|
||||
class AdminDashboardStatsAdapter(AdminApiAdapter):
|
||||
@cache_result(key_prefix="dashboard:admin:stats", ttl=60, user_specific=False)
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""管理员仪表盘统计 - 使用预聚合数据优化性能"""
|
||||
db = context.db
|
||||
now = datetime.now(timezone.utc)
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
yesterday = today - timedelta(days=1)
|
||||
last_month = today - timedelta(days=30)
|
||||
|
||||
# ==================== 使用预聚合数据 ====================
|
||||
# 从 stats_summary + 今日实时数据获取全局统计
|
||||
combined_stats = StatsAggregatorService.get_combined_stats(db)
|
||||
|
||||
all_time_requests = combined_stats["total_requests"]
|
||||
all_time_success_requests = combined_stats["success_requests"]
|
||||
all_time_error_requests = combined_stats["error_requests"]
|
||||
all_time_input_tokens = combined_stats["input_tokens"]
|
||||
all_time_output_tokens = combined_stats["output_tokens"]
|
||||
all_time_cache_creation = combined_stats["cache_creation_tokens"]
|
||||
all_time_cache_read = combined_stats["cache_read_tokens"]
|
||||
all_time_cost = combined_stats["total_cost"]
|
||||
all_time_actual_cost = combined_stats["actual_total_cost"]
|
||||
|
||||
# 用户/API Key 统计
|
||||
total_users = combined_stats.get("total_users") or db.query(func.count(DBUser.id)).scalar()
|
||||
active_users = combined_stats.get("active_users") or (
|
||||
db.query(func.count(DBUser.id)).filter(DBUser.is_active.is_(True)).scalar()
|
||||
)
|
||||
total_api_keys = combined_stats.get("total_api_keys") or db.query(func.count(ApiKey.id)).scalar()
|
||||
active_api_keys = combined_stats.get("active_api_keys") or (
|
||||
db.query(func.count(ApiKey.id)).filter(ApiKey.is_active.is_(True)).scalar()
|
||||
)
|
||||
|
||||
# ==================== 今日实时统计 ====================
|
||||
today_stats = StatsAggregatorService.get_today_realtime_stats(db)
|
||||
requests_today = today_stats["total_requests"]
|
||||
cost_today = today_stats["total_cost"]
|
||||
actual_cost_today = today_stats["actual_total_cost"]
|
||||
input_tokens_today = today_stats["input_tokens"]
|
||||
output_tokens_today = today_stats["output_tokens"]
|
||||
cache_creation_today = today_stats["cache_creation_tokens"]
|
||||
cache_read_today = today_stats["cache_read_tokens"]
|
||||
tokens_today = (
|
||||
input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today
|
||||
)
|
||||
|
||||
# ==================== 昨日统计(从预聚合表获取)====================
|
||||
yesterday_stats = db.query(StatsDaily).filter(StatsDaily.date == yesterday).first()
|
||||
if yesterday_stats:
|
||||
requests_yesterday = yesterday_stats.total_requests
|
||||
cost_yesterday = yesterday_stats.total_cost
|
||||
input_tokens_yesterday = yesterday_stats.input_tokens
|
||||
output_tokens_yesterday = yesterday_stats.output_tokens
|
||||
cache_creation_yesterday = yesterday_stats.cache_creation_tokens
|
||||
cache_read_yesterday = yesterday_stats.cache_read_tokens
|
||||
else:
|
||||
# 如果没有预聚合数据,回退到实时查询
|
||||
requests_yesterday = (
|
||||
db.query(func.count(Usage.id))
|
||||
.filter(Usage.created_at >= yesterday, Usage.created_at < today)
|
||||
.scalar() or 0
|
||||
)
|
||||
cost_yesterday = (
|
||||
db.query(func.sum(Usage.total_cost_usd))
|
||||
.filter(Usage.created_at >= yesterday, Usage.created_at < today)
|
||||
.scalar() or 0
|
||||
)
|
||||
yesterday_token_stats = (
|
||||
db.query(
|
||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||
func.sum(Usage.output_tokens).label("output_tokens"),
|
||||
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
|
||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||
)
|
||||
.filter(Usage.created_at >= yesterday, Usage.created_at < today)
|
||||
.first()
|
||||
)
|
||||
input_tokens_yesterday = int(yesterday_token_stats.input_tokens or 0) if yesterday_token_stats else 0
|
||||
output_tokens_yesterday = int(yesterday_token_stats.output_tokens or 0) if yesterday_token_stats else 0
|
||||
cache_creation_yesterday = int(yesterday_token_stats.cache_creation_tokens or 0) if yesterday_token_stats else 0
|
||||
cache_read_yesterday = int(yesterday_token_stats.cache_read_tokens or 0) if yesterday_token_stats else 0
|
||||
|
||||
# ==================== 本月统计(从预聚合表聚合)====================
|
||||
monthly_stats = (
|
||||
db.query(
|
||||
func.sum(StatsDaily.total_requests).label("total_requests"),
|
||||
func.sum(StatsDaily.error_requests).label("error_requests"),
|
||||
func.sum(StatsDaily.total_cost).label("total_cost"),
|
||||
func.sum(StatsDaily.actual_total_cost).label("actual_total_cost"),
|
||||
func.sum(StatsDaily.input_tokens + StatsDaily.output_tokens +
|
||||
StatsDaily.cache_creation_tokens + StatsDaily.cache_read_tokens).label("total_tokens"),
|
||||
func.sum(StatsDaily.cache_creation_tokens).label("cache_creation_tokens"),
|
||||
func.sum(StatsDaily.cache_read_tokens).label("cache_read_tokens"),
|
||||
func.sum(StatsDaily.cache_creation_cost).label("cache_creation_cost"),
|
||||
func.sum(StatsDaily.cache_read_cost).label("cache_read_cost"),
|
||||
func.sum(StatsDaily.fallback_count).label("fallback_count"),
|
||||
)
|
||||
.filter(StatsDaily.date >= last_month, StatsDaily.date < today)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 本月数据 = 预聚合月数据 + 今日实时数据
|
||||
if monthly_stats and monthly_stats.total_requests:
|
||||
total_requests = int(monthly_stats.total_requests or 0) + requests_today
|
||||
error_requests = int(monthly_stats.error_requests or 0) + today_stats["error_requests"]
|
||||
total_cost = float(monthly_stats.total_cost or 0) + cost_today
|
||||
total_actual_cost = float(monthly_stats.actual_total_cost or 0) + actual_cost_today
|
||||
total_tokens = int(monthly_stats.total_tokens or 0) + tokens_today
|
||||
cache_creation_tokens = int(monthly_stats.cache_creation_tokens or 0) + cache_creation_today
|
||||
cache_read_tokens = int(monthly_stats.cache_read_tokens or 0) + cache_read_today
|
||||
cache_creation_cost = float(monthly_stats.cache_creation_cost or 0)
|
||||
cache_read_cost = float(monthly_stats.cache_read_cost or 0)
|
||||
fallback_count = int(monthly_stats.fallback_count or 0)
|
||||
else:
|
||||
# 回退到实时查询(没有预聚合数据时)
|
||||
total_requests = (
|
||||
db.query(func.count(Usage.id)).filter(Usage.created_at >= last_month).scalar() or 0
|
||||
)
|
||||
total_cost = (
|
||||
db.query(func.sum(Usage.total_cost_usd)).filter(Usage.created_at >= last_month).scalar() or 0
|
||||
)
|
||||
total_actual_cost = (
|
||||
db.query(func.sum(Usage.actual_total_cost_usd))
|
||||
.filter(Usage.created_at >= last_month).scalar() or 0
|
||||
)
|
||||
error_requests = (
|
||||
db.query(func.count(Usage.id))
|
||||
.filter(
|
||||
Usage.created_at >= last_month,
|
||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None)),
|
||||
).scalar() or 0
|
||||
)
|
||||
total_tokens = (
|
||||
db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= last_month).scalar() or 0
|
||||
)
|
||||
cache_stats = (
|
||||
db.query(
|
||||
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
|
||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
|
||||
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
|
||||
)
|
||||
.filter(Usage.created_at >= last_month)
|
||||
.first()
|
||||
)
|
||||
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
|
||||
cache_read_tokens = int(cache_stats.cache_read_tokens or 0) if cache_stats else 0
|
||||
cache_creation_cost = float(cache_stats.cache_creation_cost or 0) if cache_stats else 0
|
||||
cache_read_cost = float(cache_stats.cache_read_cost or 0) if cache_stats else 0
|
||||
|
||||
# Fallback 统计
|
||||
fallback_subquery = (
|
||||
db.query(
|
||||
RequestCandidate.request_id, func.count(RequestCandidate.id).label("executed_count")
|
||||
)
|
||||
.filter(
|
||||
RequestCandidate.created_at >= last_month,
|
||||
RequestCandidate.status.in_(["success", "failed"]),
|
||||
)
|
||||
.group_by(RequestCandidate.request_id)
|
||||
.subquery()
|
||||
)
|
||||
fallback_count = (
|
||||
db.query(func.count())
|
||||
.select_from(fallback_subquery)
|
||||
.filter(fallback_subquery.c.executed_count > 1)
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
# ==================== 系统健康指标 ====================
|
||||
error_rate = round((error_requests / total_requests) * 100, 2) if total_requests > 0 else 0
|
||||
|
||||
# 平均响应时间(仅查询今日数据,降低查询成本)
|
||||
avg_response_time = (
|
||||
db.query(func.avg(Usage.response_time_ms))
|
||||
.filter(
|
||||
Usage.created_at >= today,
|
||||
Usage.status_code == 200,
|
||||
Usage.response_time_ms.isnot(None),
|
||||
)
|
||||
.scalar() or 0
|
||||
)
|
||||
avg_response_time_seconds = float(avg_response_time) / 1000.0
|
||||
|
||||
# 缓存命中率
|
||||
total_input_with_cache = all_time_input_tokens + all_time_cache_read
|
||||
cache_hit_rate = (
|
||||
round((all_time_cache_read / total_input_with_cache) * 100, 1)
|
||||
if total_input_with_cache > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"stats": [
|
||||
{
|
||||
"name": "总请求",
|
||||
"value": f"{all_time_requests:,}",
|
||||
"subValue": f"有效 {all_time_success_requests:,} / 异常 {all_time_error_requests:,}",
|
||||
"change": (
|
||||
f"+{requests_today}"
|
||||
if requests_today > requests_yesterday
|
||||
else str(requests_today)
|
||||
),
|
||||
"changeType": (
|
||||
"increase"
|
||||
if requests_today > requests_yesterday
|
||||
else ("decrease" if requests_today < requests_yesterday else "neutral")
|
||||
),
|
||||
"icon": "Activity",
|
||||
},
|
||||
{
|
||||
"name": "总费用",
|
||||
"value": f"${all_time_cost:.2f}",
|
||||
"subValue": f"倍率后 ${all_time_actual_cost:.2f}",
|
||||
"change": (
|
||||
f"+${cost_today:.2f}"
|
||||
if cost_today > cost_yesterday
|
||||
else f"${cost_today:.2f}"
|
||||
),
|
||||
"changeType": (
|
||||
"increase"
|
||||
if cost_today > cost_yesterday
|
||||
else ("decrease" if cost_today < cost_yesterday else "neutral")
|
||||
),
|
||||
"icon": "DollarSign",
|
||||
},
|
||||
{
|
||||
"name": "总Token",
|
||||
"value": format_tokens(
|
||||
all_time_input_tokens
|
||||
+ all_time_output_tokens
|
||||
+ all_time_cache_creation
|
||||
+ all_time_cache_read
|
||||
),
|
||||
"subValue": f"输入 {format_tokens(all_time_input_tokens)} / 输出 {format_tokens(all_time_output_tokens)}",
|
||||
"change": (
|
||||
f"+{format_tokens(input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)}"
|
||||
if (input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)
|
||||
> (input_tokens_yesterday + output_tokens_yesterday + cache_creation_yesterday + cache_read_yesterday)
|
||||
else format_tokens(input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)
|
||||
),
|
||||
"changeType": (
|
||||
"increase"
|
||||
if (input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)
|
||||
> (input_tokens_yesterday + output_tokens_yesterday + cache_creation_yesterday + cache_read_yesterday)
|
||||
else (
|
||||
"decrease"
|
||||
if (input_tokens_today + output_tokens_today + cache_creation_today + cache_read_today)
|
||||
< (input_tokens_yesterday + output_tokens_yesterday + cache_creation_yesterday + cache_read_yesterday)
|
||||
else "neutral"
|
||||
)
|
||||
),
|
||||
"icon": "Hash",
|
||||
},
|
||||
{
|
||||
"name": "总缓存",
|
||||
"value": format_tokens(all_time_cache_creation + all_time_cache_read),
|
||||
"subValue": f"创建 {format_tokens(all_time_cache_creation)} / 读取 {format_tokens(all_time_cache_read)}",
|
||||
"change": (
|
||||
f"+{format_tokens(cache_creation_today + cache_read_today)}"
|
||||
if (cache_creation_today + cache_read_today)
|
||||
> (cache_creation_yesterday + cache_read_yesterday)
|
||||
else format_tokens(cache_creation_today + cache_read_today)
|
||||
),
|
||||
"changeType": (
|
||||
"increase"
|
||||
if (cache_creation_today + cache_read_today)
|
||||
> (cache_creation_yesterday + cache_read_yesterday)
|
||||
else (
|
||||
"decrease"
|
||||
if (cache_creation_today + cache_read_today)
|
||||
< (cache_creation_yesterday + cache_read_yesterday)
|
||||
else "neutral"
|
||||
)
|
||||
),
|
||||
"extraBadge": f"命中率 {cache_hit_rate}%",
|
||||
"icon": "Database",
|
||||
},
|
||||
],
|
||||
"today": {
|
||||
"requests": requests_today,
|
||||
"cost": cost_today,
|
||||
"actual_cost": actual_cost_today,
|
||||
"tokens": tokens_today,
|
||||
"cache_creation_tokens": cache_creation_today,
|
||||
"cache_read_tokens": cache_read_today,
|
||||
},
|
||||
"api_keys": {"total": total_api_keys, "active": active_api_keys},
|
||||
"tokens": {"month": total_tokens},
|
||||
"token_breakdown": {
|
||||
"input": all_time_input_tokens,
|
||||
"output": all_time_output_tokens,
|
||||
"cache_creation": all_time_cache_creation,
|
||||
"cache_read": all_time_cache_read,
|
||||
},
|
||||
"system_health": {
|
||||
"avg_response_time": round(avg_response_time_seconds, 2),
|
||||
"error_rate": error_rate,
|
||||
"error_requests": error_requests,
|
||||
"fallback_count": fallback_count,
|
||||
"total_requests": total_requests,
|
||||
},
|
||||
"cost_stats": {
|
||||
"total_cost": round(total_cost, 4),
|
||||
"total_actual_cost": round(total_actual_cost, 4),
|
||||
"cost_savings": round(total_cost - total_actual_cost, 4),
|
||||
},
|
||||
"cache_stats": {
|
||||
"cache_creation_tokens": cache_creation_tokens,
|
||||
"cache_read_tokens": cache_read_tokens,
|
||||
"cache_creation_cost": round(cache_creation_cost, 4),
|
||||
"cache_read_cost": round(cache_read_cost, 4),
|
||||
"total_cache_tokens": cache_creation_tokens + cache_read_tokens,
|
||||
},
|
||||
"users": {
|
||||
"total": total_users,
|
||||
"active": active_users,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class UserDashboardStatsAdapter(DashboardAdapter):
|
||||
@cache_result(key_prefix="dashboard:user:stats", ttl=30, user_specific=True)
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
now = datetime.now(timezone.utc)
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
last_month = today - timedelta(days=30)
|
||||
yesterday = today - timedelta(days=1)
|
||||
|
||||
user_api_keys = db.query(func.count(ApiKey.id)).filter(ApiKey.user_id == user.id).scalar()
|
||||
active_keys = (
|
||||
db.query(func.count(ApiKey.id))
|
||||
.filter(and_(ApiKey.user_id == user.id, ApiKey.is_active.is_(True)))
|
||||
.scalar()
|
||||
)
|
||||
|
||||
# 全局 Token 统计
|
||||
all_time_token_stats = (
|
||||
db.query(
|
||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||
func.sum(Usage.output_tokens).label("output_tokens"),
|
||||
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
|
||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||
)
|
||||
.filter(Usage.user_id == user.id)
|
||||
.first()
|
||||
)
|
||||
all_time_input_tokens = (
|
||||
int(all_time_token_stats.input_tokens or 0) if all_time_token_stats else 0
|
||||
)
|
||||
all_time_output_tokens = (
|
||||
int(all_time_token_stats.output_tokens or 0) if all_time_token_stats else 0
|
||||
)
|
||||
all_time_cache_creation = (
|
||||
int(all_time_token_stats.cache_creation_tokens or 0) if all_time_token_stats else 0
|
||||
)
|
||||
all_time_cache_read = (
|
||||
int(all_time_token_stats.cache_read_tokens or 0) if all_time_token_stats else 0
|
||||
)
|
||||
|
||||
# 本月请求统计
|
||||
user_requests = (
|
||||
db.query(func.count(Usage.id))
|
||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
|
||||
.scalar()
|
||||
)
|
||||
user_cost = (
|
||||
db.query(func.sum(Usage.total_cost_usd))
|
||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# 今日统计
|
||||
requests_today = (
|
||||
db.query(func.count(Usage.id))
|
||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= today))
|
||||
.scalar()
|
||||
)
|
||||
cost_today = (
|
||||
db.query(func.sum(Usage.total_cost_usd))
|
||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= today))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
tokens_today = (
|
||||
db.query(func.sum(Usage.total_tokens))
|
||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= today))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# 昨日统计(用于计算变化)
|
||||
requests_yesterday = (
|
||||
db.query(func.count(Usage.id))
|
||||
.filter(
|
||||
and_(
|
||||
Usage.user_id == user.id,
|
||||
Usage.created_at >= yesterday,
|
||||
Usage.created_at < today,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
# 缓存统计(本月)
|
||||
cache_stats = (
|
||||
db.query(
|
||||
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
|
||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||
func.sum(Usage.input_tokens).label("total_input_tokens"),
|
||||
)
|
||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
|
||||
.first()
|
||||
)
|
||||
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
|
||||
cache_read_tokens = int(cache_stats.cache_read_tokens or 0) if cache_stats else 0
|
||||
|
||||
# 计算缓存命中率:cache_read / (input_tokens + cache_read)
|
||||
# input_tokens 是实际发送给模型的输入(不含缓存读取),cache_read 是从缓存读取的
|
||||
# 总输入 = input_tokens + cache_read,缓存命中率 = cache_read / 总输入
|
||||
total_input_with_cache = all_time_input_tokens + all_time_cache_read
|
||||
cache_hit_rate = (
|
||||
round((all_time_cache_read / total_input_with_cache) * 100, 1)
|
||||
if total_input_with_cache > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
# 今日缓存统计
|
||||
cache_stats_today = (
|
||||
db.query(
|
||||
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
|
||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||
)
|
||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= today))
|
||||
.first()
|
||||
)
|
||||
cache_creation_tokens_today = (
|
||||
int(cache_stats_today.cache_creation_tokens or 0) if cache_stats_today else 0
|
||||
)
|
||||
cache_read_tokens_today = (
|
||||
int(cache_stats_today.cache_read_tokens or 0) if cache_stats_today else 0
|
||||
)
|
||||
|
||||
# 配额状态
|
||||
if user.quota_usd is None:
|
||||
quota_value = "无限制"
|
||||
quota_change = f"已用 ${user.used_usd:.2f}"
|
||||
quota_high = False
|
||||
elif user.quota_usd and user.quota_usd > 0:
|
||||
percent = min(100, int((user.used_usd / user.quota_usd) * 100))
|
||||
quota_value = "无限制"
|
||||
quota_change = f"已用 ${user.used_usd:.2f}"
|
||||
quota_high = percent > 80
|
||||
else:
|
||||
quota_value = "0%"
|
||||
quota_change = f"已用 ${user.used_usd:.2f}"
|
||||
quota_high = False
|
||||
|
||||
return {
|
||||
"stats": [
|
||||
{
|
||||
"name": "API 密钥",
|
||||
"value": f"{active_keys}/{user_api_keys}",
|
||||
"icon": "Key",
|
||||
},
|
||||
{
|
||||
"name": "本月请求",
|
||||
"value": f"{user_requests:,}",
|
||||
"change": f"今日 {requests_today}",
|
||||
"changeType": (
|
||||
"increase"
|
||||
if requests_today > requests_yesterday
|
||||
else ("decrease" if requests_today < requests_yesterday else "neutral")
|
||||
),
|
||||
"icon": "Activity",
|
||||
},
|
||||
{
|
||||
"name": "配额使用",
|
||||
"value": quota_value,
|
||||
"change": quota_change,
|
||||
"changeType": "increase" if quota_high else "neutral",
|
||||
"icon": "TrendingUp",
|
||||
},
|
||||
{
|
||||
"name": "本月费用",
|
||||
"value": f"${user_cost:.2f}",
|
||||
"icon": "DollarSign",
|
||||
},
|
||||
],
|
||||
"today": {
|
||||
"requests": requests_today,
|
||||
"cost": cost_today,
|
||||
"tokens": tokens_today,
|
||||
"cache_creation_tokens": cache_creation_tokens_today,
|
||||
"cache_read_tokens": cache_read_tokens_today,
|
||||
},
|
||||
# 全局 Token 详细分类(与管理员端对齐)
|
||||
"token_breakdown": {
|
||||
"input": all_time_input_tokens,
|
||||
"output": all_time_output_tokens,
|
||||
"cache_creation": all_time_cache_creation,
|
||||
"cache_read": all_time_cache_read,
|
||||
},
|
||||
# 用户视角:缓存使用情况
|
||||
"cache_stats": {
|
||||
"cache_creation_tokens": cache_creation_tokens,
|
||||
"cache_read_tokens": cache_read_tokens,
|
||||
"cache_hit_rate": cache_hit_rate,
|
||||
"total_cache_tokens": cache_creation_tokens + cache_read_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DashboardRecentRequestsAdapter(DashboardAdapter):
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
query = db.query(Usage)
|
||||
if user.role != UserRole.ADMIN:
|
||||
query = query.filter(Usage.user_id == user.id)
|
||||
|
||||
recent_requests = query.order_by(Usage.created_at.desc()).limit(self.limit).all()
|
||||
|
||||
results = []
|
||||
for req in recent_requests:
|
||||
owner = db.query(DBUser).filter(DBUser.id == req.user_id).first()
|
||||
results.append(
|
||||
{
|
||||
"id": req.id,
|
||||
"user": owner.username if owner else "Unknown",
|
||||
"model": req.model or "N/A",
|
||||
"tokens": req.total_tokens,
|
||||
"time": req.created_at.strftime("%H:%M") if req.created_at else None,
|
||||
"is_stream": req.is_stream,
|
||||
}
|
||||
)
|
||||
|
||||
return {"requests": results}
|
||||
|
||||
|
||||
# NOTE: DashboardRequestDetailAdapter has been moved to AdminUsageDetailAdapter
|
||||
# in src/api/admin/usage/routes.py
|
||||
|
||||
|
||||
class DashboardProviderStatusAdapter(DashboardAdapter):
|
||||
@cache_result(key_prefix="dashboard:provider:status", ttl=60, user_specific=False)
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
providers = db.query(Provider).filter(Provider.is_active.is_(True)).all()
|
||||
since = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
|
||||
entries = []
|
||||
for provider in providers:
|
||||
count = (
|
||||
db.query(func.count(Usage.id))
|
||||
.filter(and_(Usage.provider == provider.name, Usage.created_at >= since))
|
||||
.scalar()
|
||||
)
|
||||
entries.append(
|
||||
{
|
||||
"name": provider.name,
|
||||
"status": "active" if provider.is_active else "inactive",
|
||||
"requests": count,
|
||||
}
|
||||
)
|
||||
|
||||
entries.sort(key=lambda x: x["requests"], reverse=True)
|
||||
limit = 10 if user.role == UserRole.ADMIN else 5
|
||||
return {"providers": entries[:limit]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DashboardDailyStatsAdapter(DashboardAdapter):
|
||||
days: int
|
||||
|
||||
@cache_result(key_prefix="dashboard:daily:stats", ttl=300, user_specific=True)
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = now.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
start_date = (end_date - timedelta(days=self.days - 1)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
# ==================== 使用预聚合数据优化 ====================
|
||||
if is_admin:
|
||||
# 管理员:从 stats_daily 获取历史数据
|
||||
daily_stats = (
|
||||
db.query(StatsDaily)
|
||||
.filter(and_(StatsDaily.date >= start_date, StatsDaily.date < today))
|
||||
.order_by(StatsDaily.date.asc())
|
||||
.all()
|
||||
)
|
||||
stats_map = {
|
||||
stat.date.replace(tzinfo=timezone.utc).date().isoformat(): {
|
||||
"requests": stat.total_requests,
|
||||
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
|
||||
"cost": stat.total_cost,
|
||||
"avg_response_time": stat.avg_response_time_ms / 1000.0 if stat.avg_response_time_ms else 0,
|
||||
"unique_models": getattr(stat, 'unique_models', 0) or 0,
|
||||
"unique_providers": getattr(stat, 'unique_providers', 0) or 0,
|
||||
"fallback_count": stat.fallback_count or 0,
|
||||
}
|
||||
for stat in daily_stats
|
||||
}
|
||||
|
||||
# 今日实时数据
|
||||
today_stats = StatsAggregatorService.get_today_realtime_stats(db)
|
||||
today_str = today.date().isoformat()
|
||||
if today_stats["total_requests"] > 0:
|
||||
# 今日平均响应时间需要单独查询
|
||||
today_avg_rt = (
|
||||
db.query(func.avg(Usage.response_time_ms))
|
||||
.filter(Usage.created_at >= today, Usage.response_time_ms.isnot(None))
|
||||
.scalar() or 0
|
||||
)
|
||||
# 今日 unique_models 和 unique_providers
|
||||
today_unique_models = (
|
||||
db.query(func.count(func.distinct(Usage.model)))
|
||||
.filter(Usage.created_at >= today)
|
||||
.scalar() or 0
|
||||
)
|
||||
today_unique_providers = (
|
||||
db.query(func.count(func.distinct(Usage.provider)))
|
||||
.filter(Usage.created_at >= today)
|
||||
.scalar() or 0
|
||||
)
|
||||
# 今日 fallback_count
|
||||
today_fallback_count = (
|
||||
db.query(func.count())
|
||||
.select_from(
|
||||
db.query(RequestCandidate.request_id)
|
||||
.filter(
|
||||
RequestCandidate.created_at >= today,
|
||||
RequestCandidate.status.in_(["success", "failed"]),
|
||||
)
|
||||
.group_by(RequestCandidate.request_id)
|
||||
.having(func.count(RequestCandidate.id) > 1)
|
||||
.subquery()
|
||||
)
|
||||
.scalar() or 0
|
||||
)
|
||||
stats_map[today_str] = {
|
||||
"requests": today_stats["total_requests"],
|
||||
"tokens": (today_stats["input_tokens"] + today_stats["output_tokens"] +
|
||||
today_stats["cache_creation_tokens"] + today_stats["cache_read_tokens"]),
|
||||
"cost": today_stats["total_cost"],
|
||||
"avg_response_time": float(today_avg_rt) / 1000.0 if today_avg_rt else 0,
|
||||
"unique_models": today_unique_models,
|
||||
"unique_providers": today_unique_providers,
|
||||
"fallback_count": today_fallback_count,
|
||||
}
|
||||
else:
|
||||
# 普通用户:仍需实时查询(用户级预聚合可选)
|
||||
query = db.query(Usage).filter(
|
||||
and_(
|
||||
Usage.user_id == user.id,
|
||||
Usage.created_at >= start_date,
|
||||
Usage.created_at <= end_date,
|
||||
)
|
||||
)
|
||||
|
||||
user_daily_stats = (
|
||||
query.with_entities(
|
||||
func.date(Usage.created_at).label("date"),
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
)
|
||||
.group_by(func.date(Usage.created_at))
|
||||
.order_by(func.date(Usage.created_at).asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
stats_map = {
|
||||
stat.date.isoformat(): {
|
||||
"requests": stat.requests or 0,
|
||||
"tokens": int(stat.tokens or 0),
|
||||
"cost": float(stat.cost or 0),
|
||||
"avg_response_time": float(stat.avg_response_time or 0) / 1000.0 if stat.avg_response_time else 0,
|
||||
}
|
||||
for stat in user_daily_stats
|
||||
}
|
||||
|
||||
# 构建完整日期序列
|
||||
current_date = start_date.date()
|
||||
formatted: List[dict] = []
|
||||
while current_date <= end_date.date():
|
||||
date_str = current_date.isoformat()
|
||||
stat = stats_map.get(date_str)
|
||||
if stat:
|
||||
formatted.append({
|
||||
"date": date_str,
|
||||
"requests": stat["requests"],
|
||||
"tokens": stat["tokens"],
|
||||
"cost": stat["cost"],
|
||||
"avg_response_time": stat["avg_response_time"],
|
||||
"unique_models": stat.get("unique_models", 0),
|
||||
"unique_providers": stat.get("unique_providers", 0),
|
||||
"fallback_count": stat.get("fallback_count", 0),
|
||||
})
|
||||
else:
|
||||
formatted.append({
|
||||
"date": date_str,
|
||||
"requests": 0,
|
||||
"tokens": 0,
|
||||
"cost": 0.0,
|
||||
"avg_response_time": 0.0,
|
||||
"unique_models": 0,
|
||||
"unique_providers": 0,
|
||||
"fallback_count": 0,
|
||||
})
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# ==================== 模型统计(仍需实时查询)====================
|
||||
model_query = db.query(Usage)
|
||||
if not is_admin:
|
||||
model_query = model_query.filter(Usage.user_id == user.id)
|
||||
model_query = model_query.filter(
|
||||
and_(Usage.created_at >= start_date, Usage.created_at <= end_date)
|
||||
)
|
||||
|
||||
model_stats = (
|
||||
model_query.with_entities(
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
)
|
||||
.group_by(Usage.model)
|
||||
.order_by(func.sum(Usage.total_cost_usd).desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
model_summary = [
|
||||
{
|
||||
"model": stat.model,
|
||||
"requests": stat.requests or 0,
|
||||
"tokens": int(stat.tokens or 0),
|
||||
"cost": float(stat.cost or 0),
|
||||
"avg_response_time": (
|
||||
float(stat.avg_response_time or 0) / 1000.0 if stat.avg_response_time else 0
|
||||
),
|
||||
"cost_per_request": float(stat.cost or 0) / max(stat.requests or 1, 1),
|
||||
"tokens_per_request": int(stat.tokens or 0) / max(stat.requests or 1, 1),
|
||||
}
|
||||
for stat in model_stats
|
||||
]
|
||||
|
||||
daily_model_stats = (
|
||||
model_query.with_entities(
|
||||
func.date(Usage.created_at).label("date"),
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost"),
|
||||
)
|
||||
.group_by(func.date(Usage.created_at), Usage.model)
|
||||
.order_by(func.date(Usage.created_at).desc(), func.sum(Usage.total_cost_usd).desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
breakdown = {}
|
||||
for stat in daily_model_stats:
|
||||
date_str = stat.date.isoformat()
|
||||
breakdown.setdefault(date_str, []).append(
|
||||
{
|
||||
"model": stat.model,
|
||||
"requests": stat.requests or 0,
|
||||
"tokens": int(stat.tokens or 0),
|
||||
"cost": float(stat.cost or 0),
|
||||
}
|
||||
)
|
||||
|
||||
for item in formatted:
|
||||
item["model_breakdown"] = breakdown.get(item["date"], [])
|
||||
|
||||
return {
|
||||
"daily_stats": formatted,
|
||||
"model_summary": model_summary,
|
||||
"period": {
|
||||
"start_date": start_date.date().isoformat(),
|
||||
"end_date": end_date.date().isoformat(),
|
||||
"days": self.days,
|
||||
},
|
||||
}
|
||||
99
src/api/handlers/__init__.py
Normal file
99
src/api/handlers/__init__.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
API Handlers - 请求处理器
|
||||
|
||||
按 API 格式组织的 Adapter 和 Handler:
|
||||
- Adapter: 请求验证、格式转换、错误处理
|
||||
- Handler: 业务逻辑、调用 Provider、记录用量
|
||||
|
||||
支持的格式:
|
||||
- claude: Claude Chat API (/v1/messages)
|
||||
- claude_cli: Claude CLI 透传模式
|
||||
- openai: OpenAI Chat API (/v1/chat/completions)
|
||||
- openai_cli: OpenAI CLI 透传模式
|
||||
|
||||
注意:Handler 基类和具体 Handler 使用延迟导入以避免循环依赖。
|
||||
"""
|
||||
|
||||
# Adapter 基类(不会引起循环导入,可以直接导入)
|
||||
from src.api.handlers.base import (
|
||||
ChatAdapterBase,
|
||||
CliAdapterBase,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Adapter 基类
|
||||
"ChatAdapterBase",
|
||||
"CliAdapterBase",
|
||||
# Handler 基类(延迟导入)
|
||||
"ChatHandlerBase",
|
||||
"CliMessageHandlerBase",
|
||||
"BaseMessageHandler",
|
||||
"MessageHandlerProtocol",
|
||||
"MessageTelemetry",
|
||||
"StreamContext",
|
||||
# Claude
|
||||
"ClaudeChatAdapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
"build_claude_adapter",
|
||||
"ClaudeChatHandler",
|
||||
# Claude CLI
|
||||
"ClaudeCliAdapter",
|
||||
"ClaudeCliMessageHandler",
|
||||
# OpenAI
|
||||
"OpenAIChatAdapter",
|
||||
"OpenAIChatHandler",
|
||||
# OpenAI CLI
|
||||
"OpenAICliAdapter",
|
||||
"OpenAICliMessageHandler",
|
||||
]
|
||||
|
||||
# 延迟导入映射表
|
||||
_LAZY_IMPORTS = {
|
||||
# Handler 基类
|
||||
"ChatHandlerBase": ("src.api.handlers.base.chat_handler_base", "ChatHandlerBase"),
|
||||
"CliMessageHandlerBase": (
|
||||
"src.api.handlers.base.cli_handler_base",
|
||||
"CliMessageHandlerBase",
|
||||
),
|
||||
"StreamContext": ("src.api.handlers.base.cli_handler_base", "StreamContext"),
|
||||
"BaseMessageHandler": ("src.api.handlers.base.base_handler", "BaseMessageHandler"),
|
||||
"MessageHandlerProtocol": (
|
||||
"src.api.handlers.base.base_handler",
|
||||
"MessageHandlerProtocol",
|
||||
),
|
||||
"MessageTelemetry": ("src.api.handlers.base.base_handler", "MessageTelemetry"),
|
||||
# Claude
|
||||
"ClaudeChatAdapter": ("src.api.handlers.claude.adapter", "ClaudeChatAdapter"),
|
||||
"ClaudeTokenCountAdapter": (
|
||||
"src.api.handlers.claude.adapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
),
|
||||
"build_claude_adapter": ("src.api.handlers.claude.adapter", "build_claude_adapter"),
|
||||
"ClaudeChatHandler": ("src.api.handlers.claude.handler", "ClaudeChatHandler"),
|
||||
# Claude CLI
|
||||
"ClaudeCliAdapter": ("src.api.handlers.claude_cli.adapter", "ClaudeCliAdapter"),
|
||||
"ClaudeCliMessageHandler": (
|
||||
"src.api.handlers.claude_cli.handler",
|
||||
"ClaudeCliMessageHandler",
|
||||
),
|
||||
# OpenAI
|
||||
"OpenAIChatAdapter": ("src.api.handlers.openai.adapter", "OpenAIChatAdapter"),
|
||||
"OpenAIChatHandler": ("src.api.handlers.openai.handler", "OpenAIChatHandler"),
|
||||
# OpenAI CLI
|
||||
"OpenAICliAdapter": ("src.api.handlers.openai_cli.adapter", "OpenAICliAdapter"),
|
||||
"OpenAICliMessageHandler": (
|
||||
"src.api.handlers.openai_cli.handler",
|
||||
"OpenAICliMessageHandler",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""延迟导入以避免循环依赖"""
|
||||
if name in _LAZY_IMPORTS:
|
||||
module_path, attr_name = _LAZY_IMPORTS[name]
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, attr_name)
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
68
src/api/handlers/base/__init__.py
Normal file
68
src/api/handlers/base/__init__.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Handler 基类模块
|
||||
|
||||
提供 Adapter、Handler 的抽象基类,以及请求构建器和响应解析器。
|
||||
|
||||
注意:Handler 基类(ChatHandlerBase, CliMessageHandlerBase 等)不在这里导出,
|
||||
因为它们依赖 services.usage.stream,而后者又需要导入 response_parser,
|
||||
会形成循环导入。请直接从具体模块导入 Handler 基类。
|
||||
"""
|
||||
|
||||
# Chat Adapter 基类(不会引起循环导入)
|
||||
from src.api.handlers.base.chat_adapter_base import (
|
||||
ChatAdapterBase,
|
||||
get_adapter_class,
|
||||
get_adapter_instance,
|
||||
list_registered_formats,
|
||||
register_adapter,
|
||||
)
|
||||
|
||||
# CLI Adapter 基类
|
||||
from src.api.handlers.base.cli_adapter_base import (
|
||||
CliAdapterBase,
|
||||
get_cli_adapter_class,
|
||||
get_cli_adapter_instance,
|
||||
list_registered_cli_formats,
|
||||
register_cli_adapter,
|
||||
)
|
||||
|
||||
# 请求构建器
|
||||
from src.api.handlers.base.request_builder import (
|
||||
SENSITIVE_HEADERS,
|
||||
PassthroughRequestBuilder,
|
||||
RequestBuilder,
|
||||
build_passthrough_request,
|
||||
)
|
||||
|
||||
# 响应解析器
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ParsedChunk,
|
||||
ParsedResponse,
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Chat Adapter
|
||||
"ChatAdapterBase",
|
||||
"register_adapter",
|
||||
"get_adapter_class",
|
||||
"get_adapter_instance",
|
||||
"list_registered_formats",
|
||||
# CLI Adapter
|
||||
"CliAdapterBase",
|
||||
"register_cli_adapter",
|
||||
"get_cli_adapter_class",
|
||||
"get_cli_adapter_instance",
|
||||
"list_registered_cli_formats",
|
||||
# 请求构建器
|
||||
"RequestBuilder",
|
||||
"PassthroughRequestBuilder",
|
||||
"build_passthrough_request",
|
||||
"SENSITIVE_HEADERS",
|
||||
# 响应解析器
|
||||
"ResponseParser",
|
||||
"ParsedChunk",
|
||||
"ParsedResponse",
|
||||
"StreamStats",
|
||||
]
|
||||
363
src/api/handlers/base/base_handler.py
Normal file
363
src/api/handlers/base/base_handler.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
基础消息处理器,封装通用的编排、转换、遥测逻辑。
|
||||
|
||||
接口约定:
|
||||
- process_stream: 处理流式请求,返回 StreamingResponse
|
||||
- process_sync: 处理非流式请求,返回 JSONResponse
|
||||
|
||||
签名规范(推荐):
|
||||
async def process_stream(
|
||||
self,
|
||||
request: Any, # 解析后的请求模型
|
||||
http_request: Request, # FastAPI Request 对象
|
||||
original_headers: Dict[str, str], # 原始请求头
|
||||
original_request_body: Dict[str, Any], # 原始请求体
|
||||
query_params: Optional[Dict[str, str]] = None, # 查询参数
|
||||
) -> StreamingResponse: ...
|
||||
|
||||
async def process_sync(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> JSONResponse: ...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.clients.redis_client import get_redis_client_sync
|
||||
from src.core.api_format_metadata import resolve_api_format
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.logger import logger
|
||||
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
|
||||
from src.services.provider.format import normalize_api_format
|
||||
from src.services.system.audit import audit_service
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
|
||||
|
||||
class MessageTelemetry:
|
||||
"""
|
||||
负责记录 Usage/Audit,避免处理器里重复代码。
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
|
||||
self.db = db
|
||||
self.user = user
|
||||
self.api_key = api_key
|
||||
self.request_id = request_id
|
||||
self.client_ip = client_ip
|
||||
|
||||
async def calculate_cost(
|
||||
self,
|
||||
provider: str,
|
||||
model: str,
|
||||
*,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
) -> float:
|
||||
input_price, output_price = await UsageService.get_model_price_async(
|
||||
self.db, provider, model
|
||||
)
|
||||
_, _, _, _, _, _, total_cost = UsageService.calculate_cost(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
input_price,
|
||||
output_price,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
*await UsageService.get_cache_prices_async(self.db, provider, model, input_price),
|
||||
)
|
||||
return total_cost
|
||||
|
||||
async def record_success(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
response_time_ms: int,
|
||||
status_code: int,
|
||||
request_body: Dict[str, Any],
|
||||
request_headers: Dict[str, Any],
|
||||
response_body: Any,
|
||||
response_headers: Dict[str, Any],
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
is_stream: bool = False,
|
||||
provider_request_headers: Optional[Dict[str, Any]] = None,
|
||||
# Provider 侧追踪信息(用于记录真实成本)
|
||||
provider_id: Optional[str] = None,
|
||||
provider_endpoint_id: Optional[str] = None,
|
||||
provider_api_key_id: Optional[str] = None,
|
||||
api_format: Optional[str] = None,
|
||||
# 模型映射信息
|
||||
target_model: Optional[str] = None,
|
||||
# Provider 响应元数据(如 Gemini 的 modelVersion)
|
||||
response_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> float:
|
||||
total_cost = await self.calculate_cost(
|
||||
provider,
|
||||
model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
)
|
||||
|
||||
await UsageService.record_usage(
|
||||
db=self.db,
|
||||
user=self.user,
|
||||
api_key=self.api_key,
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_tokens,
|
||||
cache_read_input_tokens=cache_read_tokens,
|
||||
request_type="chat",
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
provider_request_headers=provider_request_headers or {},
|
||||
response_headers=response_headers,
|
||||
response_body=response_body,
|
||||
request_id=self.request_id,
|
||||
# Provider 侧追踪信息(用于记录真实成本)
|
||||
provider_id=provider_id,
|
||||
provider_endpoint_id=provider_endpoint_id,
|
||||
provider_api_key_id=provider_api_key_id,
|
||||
# 模型映射信息
|
||||
target_model=target_model,
|
||||
# Provider 响应元数据
|
||||
metadata=response_metadata,
|
||||
)
|
||||
|
||||
if self.user and self.api_key:
|
||||
audit_service.log_api_request(
|
||||
db=self.db,
|
||||
user_id=self.user.id,
|
||||
api_key_id=self.api_key.id,
|
||||
request_id=self.request_id,
|
||||
model=model,
|
||||
provider=provider,
|
||||
success=True,
|
||||
ip_address=self.client_ip,
|
||||
status_code=status_code,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost_usd=total_cost,
|
||||
)
|
||||
|
||||
return total_cost
|
||||
|
||||
async def record_failure(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
response_time_ms: int,
|
||||
status_code: int,
|
||||
error_message: str,
|
||||
request_body: Dict[str, Any],
|
||||
request_headers: Dict[str, Any],
|
||||
is_stream: bool,
|
||||
api_format: Optional[str] = None,
|
||||
provider_request_headers: Optional[Dict[str, Any]] = None,
|
||||
# 预估 token 信息(来自 message_start 事件,用于中断请求的成本估算)
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
response_body: Optional[Dict[str, Any]] = None,
|
||||
# 模型映射信息
|
||||
target_model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
记录失败请求
|
||||
|
||||
注意:Provider 链路信息(provider_id, endpoint_id, key_id)不在此处记录,
|
||||
因为 RequestCandidate 表已经记录了完整的请求链路追踪信息。
|
||||
|
||||
Args:
|
||||
input_tokens: 预估输入 tokens(来自 message_start,用于中断请求的成本估算)
|
||||
output_tokens: 预估输出 tokens(来自已收到的内容)
|
||||
cache_creation_tokens: 缓存创建 tokens
|
||||
cache_read_tokens: 缓存读取 tokens
|
||||
response_body: 响应体(如果有部分响应)
|
||||
target_model: 映射后的目标模型名(如果发生了映射)
|
||||
"""
|
||||
provider_name = provider or "unknown"
|
||||
if provider_name == "unknown":
|
||||
logger.warning(f"[Telemetry] Recording failure with unknown provider (request_id={self.request_id})")
|
||||
|
||||
await UsageService.record_usage(
|
||||
db=self.db,
|
||||
user=self.user,
|
||||
api_key=self.api_key,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_tokens,
|
||||
cache_read_input_tokens=cache_read_tokens,
|
||||
request_type="chat",
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
provider_request_headers=provider_request_headers or {},
|
||||
response_headers={},
|
||||
response_body=response_body or {"error": error_message},
|
||||
request_id=self.request_id,
|
||||
# 模型映射信息
|
||||
target_model=target_model,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessageHandlerProtocol(Protocol):
|
||||
"""
|
||||
消息处理器协议 - 定义标准接口
|
||||
|
||||
ChatHandlerBase 使用完整签名(含 request, http_request)。
|
||||
CliMessageHandlerBase 使用简化签名(仅 original_request_body, original_headers)。
|
||||
"""
|
||||
|
||||
async def process_stream(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> StreamingResponse:
|
||||
"""处理流式请求"""
|
||||
...
|
||||
|
||||
async def process_sync(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> JSONResponse:
|
||||
"""处理非流式请求"""
|
||||
...
|
||||
|
||||
|
||||
class BaseMessageHandler:
|
||||
"""
|
||||
消息处理器基类,所有具体格式的 handler 可以继承它。
|
||||
|
||||
子类需要实现:
|
||||
- process_stream: 处理流式请求
|
||||
- process_sync: 处理非流式请求
|
||||
|
||||
推荐使用 MessageHandlerProtocol 中定义的签名。
|
||||
"""
|
||||
|
||||
# Adapter 检测器类型
|
||||
AdapterDetectorType = Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
db: Session,
|
||||
user,
|
||||
api_key,
|
||||
request_id: str,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
allowed_api_formats: Optional[list[str]] = None,
|
||||
adapter_detector: Optional[AdapterDetectorType] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.user = user
|
||||
self.api_key = api_key
|
||||
self.request_id = request_id
|
||||
self.client_ip = client_ip
|
||||
self.user_agent = user_agent
|
||||
self.start_time = start_time
|
||||
self.allowed_api_formats = allowed_api_formats or [APIFormat.CLAUDE.value]
|
||||
self.primary_api_format = normalize_api_format(self.allowed_api_formats[0])
|
||||
self.adapter_detector = adapter_detector
|
||||
|
||||
redis_client = get_redis_client_sync()
|
||||
self.orchestrator = FallbackOrchestrator(db, redis_client)
|
||||
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
|
||||
|
||||
def elapsed_ms(self) -> int:
|
||||
return int((time.time() - self.start_time) * 1000)
|
||||
|
||||
def _resolve_capability_requirements(
|
||||
self,
|
||||
model_name: str,
|
||||
request_headers: Optional[Dict[str, str]] = None,
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
解析请求的能力需求
|
||||
|
||||
来源:
|
||||
1. 用户模型级配置 (User.model_capability_settings)
|
||||
2. 用户 API Key 强制配置 (ApiKey.force_capabilities)
|
||||
3. 请求头 X-Require-Capability
|
||||
4. Adapter 的 detect_capability_requirements(如 Claude 的 anthropic-beta)
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
request_headers: 请求头
|
||||
request_body: 请求体(可选)
|
||||
|
||||
Returns:
|
||||
能力需求字典
|
||||
"""
|
||||
from src.services.capability.resolver import CapabilityResolver
|
||||
|
||||
return CapabilityResolver.resolve_requirements(
|
||||
user=self.user,
|
||||
user_api_key=self.api_key,
|
||||
model_name=model_name,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
adapter_detector=self.adapter_detector,
|
||||
)
|
||||
|
||||
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
|
||||
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
|
||||
if provider_type:
|
||||
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
|
||||
return self.primary_api_format
|
||||
|
||||
def build_provider_payload(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
*,
|
||||
mapped_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建发送给 Provider 的请求体,替换 model 名称"""
|
||||
payload = dict(original_body)
|
||||
if mapped_model:
|
||||
payload["model"] = mapped_model
|
||||
return payload
|
||||
724
src/api/handlers/base/chat_adapter_base.py
Normal file
724
src/api/handlers/base/chat_adapter_base.py
Normal file
@@ -0,0 +1,724 @@
|
||||
"""
|
||||
Chat Adapter 通用基类
|
||||
|
||||
提供 Chat 格式(进行请求验证和标准化)的通用适配器逻辑:
|
||||
- 请求解析和验证
|
||||
- 审计日志记录
|
||||
- 错误处理和响应格式化
|
||||
- Handler 创建和调用
|
||||
- 计费策略(支持不同 API 格式的差异化计费)
|
||||
|
||||
子类只需提供:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: 对应的 ChatHandlerBase 子类
|
||||
- _validate_request_body(): 可选覆盖请求验证逻辑
|
||||
- _build_audit_metadata(): 可选覆盖审计元数据构建
|
||||
- compute_total_input_context(): 可选覆盖总输入上下文计算(用于阶梯计费判定)
|
||||
"""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.exceptions import (
|
||||
InvalidRequestException,
|
||||
ModelNotSupportedException,
|
||||
ProviderAuthException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderRateLimitException,
|
||||
ProviderTimeoutException,
|
||||
ProxyException,
|
||||
QuotaExceededException,
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
|
||||
|
||||
class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
Chat Adapter 通用基类
|
||||
|
||||
提供 Chat 格式的通用适配器逻辑,子类只需配置:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: ChatHandlerBase 子类
|
||||
- name: 适配器名称
|
||||
"""
|
||||
|
||||
# 子类必须覆盖
|
||||
FORMAT_ID: str = "UNKNOWN"
|
||||
HANDLER_CLASS: Type[ChatHandlerBase]
|
||||
|
||||
# 适配器配置
|
||||
name: str = "chat.base"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
self.response_normalizer = None
|
||||
# 可选启用响应规范化
|
||||
self._init_response_normalizer()
|
||||
|
||||
def _init_response_normalizer(self):
|
||||
"""初始化响应规范化器 - 子类可覆盖"""
|
||||
try:
|
||||
from src.services.provider.response_normalizer import ResponseNormalizer
|
||||
|
||||
self.response_normalizer = ResponseNormalizer()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
"""处理 Chat API 请求"""
|
||||
http_request = context.request
|
||||
user = context.user
|
||||
api_key = context.api_key
|
||||
db = context.db
|
||||
request_id = context.request_id
|
||||
quota_remaining_value = context.quota_remaining
|
||||
start_time = context.start_time
|
||||
client_ip = context.client_ip
|
||||
user_agent = context.user_agent
|
||||
original_headers = context.original_headers
|
||||
query_params = context.query_params
|
||||
|
||||
original_request_body = context.ensure_json_body()
|
||||
|
||||
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
|
||||
if context.path_params:
|
||||
original_request_body = self._merge_path_params(
|
||||
original_request_body, context.path_params
|
||||
)
|
||||
|
||||
# 验证和解析请求
|
||||
request_obj = self._validate_request_body(original_request_body, context.path_params)
|
||||
if isinstance(request_obj, JSONResponse):
|
||||
return request_obj
|
||||
|
||||
stream = getattr(request_obj, "stream", False)
|
||||
model = getattr(request_obj, "model", "unknown")
|
||||
|
||||
# 添加审计元数据
|
||||
audit_metadata = self._build_audit_metadata(original_request_body, request_obj)
|
||||
context.add_audit_metadata(**audit_metadata)
|
||||
|
||||
# 格式化额度显示
|
||||
quota_display = (
|
||||
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
|
||||
)
|
||||
|
||||
# 请求开始日志
|
||||
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
|
||||
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
|
||||
|
||||
try:
|
||||
# 检查客户端连接
|
||||
if await http_request.is_disconnected():
|
||||
logger.warning("客户端连接断开")
|
||||
raise HTTPException(status_code=499, detail="Client disconnected")
|
||||
|
||||
# 创建 Handler
|
||||
handler = self._create_handler(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
# 处理请求
|
||||
if stream:
|
||||
return await handler.process_stream(
|
||||
request=request_obj,
|
||||
http_request=http_request,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
query_params=query_params,
|
||||
)
|
||||
return await handler.process_sync(
|
||||
request=request_obj,
|
||||
http_request=http_request,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
query_params=query_params,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except (
|
||||
ModelNotSupportedException,
|
||||
QuotaExceededException,
|
||||
InvalidRequestException,
|
||||
) as e:
|
||||
logger.info(f"客户端请求错误: {e.error_type}")
|
||||
return self._error_response(
|
||||
status_code=e.status_code,
|
||||
error_type=(
|
||||
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
|
||||
),
|
||||
message=e.message,
|
||||
)
|
||||
|
||||
except (
|
||||
ProviderAuthException,
|
||||
ProviderRateLimitException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderTimeoutException,
|
||||
UpstreamClientException,
|
||||
) as e:
|
||||
return await self._handle_provider_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return await self._handle_unexpected_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _create_handler(
|
||||
self,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
request_id: str,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
):
|
||||
"""创建 Handler 实例 - 子类可覆盖"""
|
||||
return self.HANDLER_CLASS(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
allowed_api_formats=self.allowed_api_formats,
|
||||
response_normalizer=self.response_normalizer,
|
||||
enable_response_normalization=self.response_normalizer is not None,
|
||||
adapter_detector=self.detect_capability_requirements,
|
||||
)
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - 子类可覆盖
|
||||
|
||||
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典
|
||||
|
||||
Returns:
|
||||
合并后的请求体字典
|
||||
"""
|
||||
merged = original_request_body.copy()
|
||||
for key, value in path_params.items():
|
||||
if key not in merged:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
@abstractmethod
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""
|
||||
验证请求体 - 子类必须实现
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数(如 Gemini 的 stream 通过 URL 端点传入)
|
||||
|
||||
Returns:
|
||||
验证后的请求对象,或 JSONResponse 错误响应
|
||||
"""
|
||||
pass
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
|
||||
"""
|
||||
提取消息数量 - 子类可覆盖
|
||||
|
||||
默认实现:从 messages 字段提取
|
||||
"""
|
||||
messages = payload.get("messages", [])
|
||||
if hasattr(request_obj, "messages"):
|
||||
messages = request_obj.messages
|
||||
return len(messages) if isinstance(messages, list) else 0
|
||||
|
||||
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""
|
||||
构建审计日志元数据 - 子类可覆盖
|
||||
"""
|
||||
model = getattr(request_obj, "model", payload.get("model", "unknown"))
|
||||
stream = getattr(request_obj, "stream", payload.get("stream", False))
|
||||
messages_count = self._extract_message_count(payload, request_obj)
|
||||
|
||||
return {
|
||||
"action": f"{self.FORMAT_ID.lower()}_request",
|
||||
"model": model,
|
||||
"stream": bool(stream),
|
||||
"max_tokens": getattr(request_obj, "max_tokens", payload.get("max_tokens")),
|
||||
"messages_count": messages_count,
|
||||
"temperature": getattr(request_obj, "temperature", payload.get("temperature")),
|
||||
"top_p": getattr(request_obj, "top_p", payload.get("top_p")),
|
||||
}
|
||||
|
||||
async def _handle_provider_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理 Provider 相关异常"""
|
||||
logger.debug(f"Caught provider exception: {type(e).__name__}")
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 确定错误消息
|
||||
if isinstance(e, ProviderAuthException):
|
||||
error_message = (
|
||||
f"提供商认证失败: {str(e)}"
|
||||
if result.metadata.provider != "unknown"
|
||||
else "服务端错误: 无可用提供商"
|
||||
)
|
||||
result.error_message = error_message
|
||||
|
||||
# 处理上游客户端错误(如图片处理失败)
|
||||
if isinstance(e, UpstreamClientException):
|
||||
# 返回 400 状态码和清晰的错误消息
|
||||
result.status_code = e.status_code
|
||||
result.error_message = e.message
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
# 根据异常类型确定错误类型
|
||||
if isinstance(e, UpstreamClientException):
|
||||
error_type = "invalid_request_error"
|
||||
elif result.status_code == 503:
|
||||
error_type = "internal_server_error"
|
||||
else:
|
||||
error_type = "rate_limit_exceeded"
|
||||
|
||||
return self._error_response(
|
||||
status_code=result.status_code,
|
||||
error_type=error_type,
|
||||
message=result.error_message or str(e),
|
||||
)
|
||||
|
||||
async def _handle_unexpected_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理未预期的异常"""
|
||||
if isinstance(e, ProxyException):
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
|
||||
else:
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
|
||||
exception=e,
|
||||
extra_data={
|
||||
"exception_class": e.__class__.__name__,
|
||||
"processing_stage": "request_processing",
|
||||
"model": model,
|
||||
"stream": stream,
|
||||
"traceback_preview": str(traceback.format_exc())[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
# 对于未预期的异常,强制设置状态码为 500
|
||||
result.status_code = 500
|
||||
result.error_type = "internal_error"
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
try:
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
except Exception as record_error:
|
||||
logger.error(f"记录失败请求时出错: {record_error}")
|
||||
|
||||
return self._error_response(
|
||||
status_code=500,
|
||||
error_type="internal_server_error",
|
||||
message="处理请求时发生内部错误")
|
||||
|
||||
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
||||
"""生成错误响应 - 子类可覆盖以自定义格式"""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算总输入上下文(用于阶梯计费判定)
|
||||
|
||||
默认实现:input_tokens + cache_read_input_tokens
|
||||
子类可覆盖此方法实现不同的计算逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
|
||||
|
||||
Returns:
|
||||
总输入上下文 token 数
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_price_per_1m: Optional[float],
|
||||
cache_read_price_per_1m: Optional[float],
|
||||
price_per_request: Optional[float],
|
||||
tiered_pricing: Optional[dict] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
input_price_per_1m: 输入价格(每 1M tokens)
|
||||
output_price_per_1m: 输出价格(每 1M tokens)
|
||||
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
||||
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
||||
price_per_request: 按次计费价格
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
包含各项成本的字典:
|
||||
{
|
||||
"input_cost": float,
|
||||
"output_cost": float,
|
||||
"cache_creation_cost": float,
|
||||
"cache_read_cost": float,
|
||||
"cache_cost": float,
|
||||
"request_cost": float,
|
||||
"total_cost": float,
|
||||
"tier_index": Optional[int], # 命中的阶梯索引
|
||||
}
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""
|
||||
根据总输入 token 数确定价格阶梯
|
||||
|
||||
Args:
|
||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
||||
total_input_tokens: 总输入 token 数
|
||||
|
||||
Returns:
|
||||
匹配的阶梯配置
|
||||
"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
||||
# =========================================================================
|
||||
|
||||
_ADAPTER_REGISTRY: Dict[str, Type["ChatAdapterBase"]] = {}
|
||||
_ADAPTERS_LOADED = False
|
||||
|
||||
|
||||
def register_adapter(adapter_class: Type["ChatAdapterBase"]) -> Type["ChatAdapterBase"]:
|
||||
"""
|
||||
注册 Adapter 类到注册表
|
||||
|
||||
用法:
|
||||
@register_adapter
|
||||
class ClaudeChatAdapter(ChatAdapterBase):
|
||||
FORMAT_ID = "CLAUDE"
|
||||
...
|
||||
|
||||
Args:
|
||||
adapter_class: Adapter 类
|
||||
|
||||
Returns:
|
||||
注册的 Adapter 类(支持作为装饰器使用)
|
||||
"""
|
||||
format_id = adapter_class.FORMAT_ID
|
||||
if format_id and format_id != "UNKNOWN":
|
||||
_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
|
||||
return adapter_class
|
||||
|
||||
|
||||
def _ensure_adapters_loaded():
|
||||
"""确保所有 Adapter 已被加载(触发注册)"""
|
||||
global _ADAPTERS_LOADED
|
||||
if _ADAPTERS_LOADED:
|
||||
return
|
||||
|
||||
# 导入各个 Adapter 模块以触发 @register_adapter 装饰器
|
||||
try:
|
||||
from src.api.handlers.claude import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.openai import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.gemini import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_ADAPTERS_LOADED = True
|
||||
|
||||
|
||||
def get_adapter_class(api_format: str) -> Optional[Type["ChatAdapterBase"]]:
|
||||
"""
|
||||
根据 API format 获取 Adapter 类
|
||||
|
||||
Args:
|
||||
api_format: API 格式标识(如 "CLAUDE", "OPENAI", "GEMINI")
|
||||
|
||||
Returns:
|
||||
对应的 Adapter 类,如果未找到返回 None
|
||||
"""
|
||||
_ensure_adapters_loaded()
|
||||
return _ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
|
||||
|
||||
|
||||
def get_adapter_instance(api_format: str) -> Optional["ChatAdapterBase"]:
|
||||
"""
|
||||
根据 API format 获取 Adapter 实例
|
||||
|
||||
Args:
|
||||
api_format: API 格式标识
|
||||
|
||||
Returns:
|
||||
Adapter 实例,如果未找到返回 None
|
||||
"""
|
||||
adapter_class = get_adapter_class(api_format)
|
||||
if adapter_class:
|
||||
return adapter_class()
|
||||
return None
|
||||
|
||||
|
||||
def list_registered_formats() -> list[str]:
|
||||
"""返回所有已注册的 API 格式"""
|
||||
_ensure_adapters_loaded()
|
||||
return list(_ADAPTER_REGISTRY.keys())
|
||||
1257
src/api/handlers/base/chat_handler_base.py
Normal file
1257
src/api/handlers/base/chat_handler_base.py
Normal file
File diff suppressed because it is too large
Load Diff
648
src/api/handlers/base/cli_adapter_base.py
Normal file
648
src/api/handlers/base/cli_adapter_base.py
Normal file
@@ -0,0 +1,648 @@
|
||||
"""
|
||||
CLI Adapter 通用基类
|
||||
|
||||
提供 CLI 格式(直接透传请求)的通用适配器逻辑:
|
||||
- 请求解析和验证
|
||||
- 审计日志记录
|
||||
- 错误处理和响应格式化
|
||||
- Handler 创建和调用
|
||||
- 计费策略(支持不同 API 格式的差异化计费)
|
||||
|
||||
子类只需提供:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: 对应的 MessageHandler 类
|
||||
- 可选覆盖 _extract_message_count() 自定义消息计数逻辑
|
||||
- 可选覆盖 compute_total_input_context() 自定义总输入上下文计算
|
||||
"""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
from src.core.exceptions import (
|
||||
InvalidRequestException,
|
||||
ModelNotSupportedException,
|
||||
ProviderAuthException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderRateLimitException,
|
||||
ProviderTimeoutException,
|
||||
ProxyException,
|
||||
QuotaExceededException,
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
|
||||
|
||||
class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
CLI Adapter 通用基类
|
||||
|
||||
提供 CLI 格式的通用适配器逻辑,子类只需配置:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: MessageHandler 类
|
||||
- name: 适配器名称
|
||||
"""
|
||||
|
||||
# 子类必须覆盖
|
||||
FORMAT_ID: str = "UNKNOWN"
|
||||
HANDLER_CLASS: Type[CliMessageHandlerBase]
|
||||
|
||||
# 适配器配置
|
||||
name: str = "cli.base"
|
||||
mode = ApiMode.PROXY
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
"""处理 CLI API 请求"""
|
||||
http_request = context.request
|
||||
user = context.user
|
||||
api_key = context.api_key
|
||||
db = context.db
|
||||
request_id = context.request_id
|
||||
quota_remaining_value = context.quota_remaining
|
||||
start_time = context.start_time
|
||||
client_ip = context.client_ip
|
||||
user_agent = context.user_agent
|
||||
original_headers = context.original_headers
|
||||
query_params = context.query_params # 获取查询参数
|
||||
|
||||
original_request_body = context.ensure_json_body()
|
||||
|
||||
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
|
||||
if context.path_params:
|
||||
original_request_body = self._merge_path_params(
|
||||
original_request_body, context.path_params
|
||||
)
|
||||
|
||||
# 获取 stream:优先从请求体,其次从 path_params(如 Gemini 通过 URL 端点区分)
|
||||
stream = original_request_body.get("stream")
|
||||
if stream is None and context.path_params:
|
||||
stream = context.path_params.get("stream", False)
|
||||
stream = bool(stream)
|
||||
|
||||
# 获取 model:优先从请求体,其次从 path_params(如 Gemini 的 model 在 URL 路径中)
|
||||
model = original_request_body.get("model")
|
||||
if model is None and context.path_params:
|
||||
model = context.path_params.get("model", "unknown")
|
||||
model = model or "unknown"
|
||||
|
||||
# 提取请求元数据
|
||||
audit_metadata = self._build_audit_metadata(original_request_body, context.path_params)
|
||||
context.add_audit_metadata(**audit_metadata)
|
||||
|
||||
# 格式化额度显示
|
||||
quota_display = (
|
||||
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
|
||||
)
|
||||
|
||||
# 请求开始日志
|
||||
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
|
||||
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
|
||||
|
||||
try:
|
||||
# 检查客户端连接
|
||||
if await http_request.is_disconnected():
|
||||
logger.warning("客户端连接断开")
|
||||
raise HTTPException(status_code=499, detail="Client disconnected")
|
||||
|
||||
# 创建 Handler
|
||||
handler = self.HANDLER_CLASS(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
allowed_api_formats=self.allowed_api_formats,
|
||||
adapter_detector=self.detect_capability_requirements,
|
||||
)
|
||||
|
||||
# 处理请求
|
||||
if stream:
|
||||
return await handler.process_stream(
|
||||
original_request_body=original_request_body,
|
||||
original_headers=original_headers,
|
||||
query_params=query_params,
|
||||
path_params=context.path_params,
|
||||
)
|
||||
return await handler.process_sync(
|
||||
original_request_body=original_request_body,
|
||||
original_headers=original_headers,
|
||||
query_params=query_params,
|
||||
path_params=context.path_params,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except (
|
||||
ModelNotSupportedException,
|
||||
QuotaExceededException,
|
||||
InvalidRequestException,
|
||||
) as e:
|
||||
logger.debug(f"客户端请求错误: {e.error_type}")
|
||||
return self._error_response(
|
||||
status_code=e.status_code,
|
||||
error_type=(
|
||||
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
|
||||
),
|
||||
message=e.message,
|
||||
)
|
||||
|
||||
except (
|
||||
ProviderAuthException,
|
||||
ProviderRateLimitException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderTimeoutException,
|
||||
UpstreamClientException,
|
||||
) as e:
|
||||
return await self._handle_provider_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return await self._handle_unexpected_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - 子类可覆盖
|
||||
|
||||
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典
|
||||
|
||||
Returns:
|
||||
合并后的请求体字典
|
||||
"""
|
||||
merged = original_request_body.copy()
|
||||
for key, value in path_params.items():
|
||||
if key not in merged:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
|
||||
"""
|
||||
提取消息数量 - 子类可覆盖
|
||||
|
||||
默认实现:从 input 字段提取
|
||||
"""
|
||||
if "input" not in payload:
|
||||
return 0
|
||||
input_data = payload["input"]
|
||||
if isinstance(input_data, list):
|
||||
return len(input_data)
|
||||
if isinstance(input_data, dict) and "messages" in input_data:
|
||||
return len(input_data.get("messages", []))
|
||||
return 0
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
构建审计日志元数据 - 子类可覆盖
|
||||
|
||||
Args:
|
||||
payload: 请求体
|
||||
path_params: URL 路径参数(用于获取 model 等)
|
||||
"""
|
||||
# 优先从请求体获取 model,其次从 path_params
|
||||
model = payload.get("model")
|
||||
if model is None and path_params:
|
||||
model = path_params.get("model", "unknown")
|
||||
model = model or "unknown"
|
||||
|
||||
stream = payload.get("stream", False)
|
||||
messages_count = self._extract_message_count(payload)
|
||||
|
||||
return {
|
||||
"action": f"{self.FORMAT_ID.lower()}_request",
|
||||
"model": model,
|
||||
"stream": bool(stream),
|
||||
"max_tokens": payload.get("max_tokens"),
|
||||
"messages_count": messages_count,
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"tool_count": len(payload.get("tools") or []),
|
||||
"instructions_present": bool(payload.get("instructions")),
|
||||
}
|
||||
|
||||
async def _handle_provider_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理 Provider 相关异常"""
|
||||
logger.debug(f"Caught provider exception: {type(e).__name__}")
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 确定错误消息
|
||||
if isinstance(e, ProviderAuthException):
|
||||
error_message = (
|
||||
f"提供商认证失败: {str(e)}"
|
||||
if result.metadata.provider != "unknown"
|
||||
else "服务端错误: 无可用提供商"
|
||||
)
|
||||
result.error_message = error_message
|
||||
|
||||
# 处理上游客户端错误(如图片处理失败)
|
||||
if isinstance(e, UpstreamClientException):
|
||||
# 返回 400 状态码和清晰的错误消息
|
||||
result.status_code = e.status_code
|
||||
result.error_message = e.message
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
# 根据异常类型确定错误类型
|
||||
if isinstance(e, UpstreamClientException):
|
||||
error_type = "invalid_request_error"
|
||||
elif result.status_code == 503:
|
||||
error_type = "internal_server_error"
|
||||
else:
|
||||
error_type = "rate_limit_exceeded"
|
||||
|
||||
return self._error_response(
|
||||
status_code=result.status_code,
|
||||
error_type=error_type,
|
||||
message=result.error_message or str(e),
|
||||
)
|
||||
|
||||
async def _handle_unexpected_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理未预期的异常"""
|
||||
if isinstance(e, ProxyException):
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
|
||||
else:
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
|
||||
exception=e,
|
||||
extra_data={
|
||||
"exception_class": e.__class__.__name__,
|
||||
"processing_stage": "request_processing",
|
||||
"model": model,
|
||||
"stream": stream,
|
||||
"traceback_preview": str(traceback.format_exc())[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
# 对于未预期的异常,强制设置状态码为 500
|
||||
result.status_code = 500
|
||||
result.error_type = "internal_error"
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
return self._error_response(
|
||||
status_code=500,
|
||||
error_type="internal_server_error",
|
||||
message="处理请求时发生内部错误")
|
||||
|
||||
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
||||
"""生成错误响应"""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算总输入上下文(用于阶梯计费判定)
|
||||
|
||||
默认实现:input_tokens + cache_read_input_tokens
|
||||
子类可覆盖此方法实现不同的计算逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
|
||||
|
||||
Returns:
|
||||
总输入上下文 token 数
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_price_per_1m: Optional[float],
|
||||
cache_read_price_per_1m: Optional[float],
|
||||
price_per_request: Optional[float],
|
||||
tiered_pricing: Optional[dict] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
input_price_per_1m: 输入价格(每 1M tokens)
|
||||
output_price_per_1m: 输出价格(每 1M tokens)
|
||||
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
||||
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
||||
price_per_request: 按次计费价格
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
包含各项成本的字典
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""根据总输入 token 数确定价格阶梯"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||
# =========================================================================
|
||||
|
||||
_CLI_ADAPTER_REGISTRY: Dict[str, Type["CliAdapterBase"]] = {}
|
||||
_CLI_ADAPTERS_LOADED = False
|
||||
|
||||
|
||||
def register_cli_adapter(adapter_class: Type["CliAdapterBase"]) -> Type["CliAdapterBase"]:
|
||||
"""
|
||||
注册 CLI Adapter 类到注册表
|
||||
|
||||
用法:
|
||||
@register_cli_adapter
|
||||
class ClaudeCliAdapter(CliAdapterBase):
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
...
|
||||
"""
|
||||
format_id = adapter_class.FORMAT_ID
|
||||
if format_id and format_id != "UNKNOWN":
|
||||
_CLI_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
|
||||
return adapter_class
|
||||
|
||||
|
||||
def _ensure_cli_adapters_loaded():
|
||||
"""确保所有 CLI Adapter 已被加载(触发注册)"""
|
||||
global _CLI_ADAPTERS_LOADED
|
||||
if _CLI_ADAPTERS_LOADED:
|
||||
return
|
||||
|
||||
# 导入各个 CLI Adapter 模块以触发 @register_cli_adapter 装饰器
|
||||
try:
|
||||
from src.api.handlers.claude_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.openai_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.gemini_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_CLI_ADAPTERS_LOADED = True
|
||||
|
||||
|
||||
def get_cli_adapter_class(api_format: str) -> Optional[Type["CliAdapterBase"]]:
|
||||
"""根据 API format 获取 CLI Adapter 类"""
|
||||
_ensure_cli_adapters_loaded()
|
||||
return _CLI_ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
|
||||
|
||||
|
||||
def get_cli_adapter_instance(api_format: str) -> Optional["CliAdapterBase"]:
|
||||
"""根据 API format 获取 CLI Adapter 实例"""
|
||||
adapter_class = get_cli_adapter_class(api_format)
|
||||
if adapter_class:
|
||||
return adapter_class()
|
||||
return None
|
||||
|
||||
|
||||
def list_registered_cli_formats() -> list[str]:
|
||||
"""返回所有已注册的 CLI API 格式"""
|
||||
_ensure_cli_adapters_loaded()
|
||||
return list(_CLI_ADAPTER_REGISTRY.keys())
|
||||
1614
src/api/handlers/base/cli_handler_base.py
Normal file
1614
src/api/handlers/base/cli_handler_base.py
Normal file
File diff suppressed because it is too large
Load Diff
279
src/api/handlers/base/format_converter_registry.py
Normal file
279
src/api/handlers/base/format_converter_registry.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
格式转换器注册表
|
||||
|
||||
自动管理不同 API 格式之间的转换器,支持:
|
||||
- 请求转换:客户端格式 → Provider 格式
|
||||
- 响应转换:Provider 格式 → 客户端格式
|
||||
|
||||
使用方法:
|
||||
1. 实现 Converter 类(需要有 convert_request 和/或 convert_response 方法)
|
||||
2. 调用 registry.register() 注册转换器
|
||||
3. 在 Handler 中调用 registry.convert_request/convert_response
|
||||
|
||||
示例:
|
||||
from src.api.handlers.base.format_converter_registry import converter_registry
|
||||
|
||||
# 注册转换器
|
||||
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
|
||||
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
|
||||
|
||||
# 使用转换器
|
||||
gemini_request = converter_registry.convert_request(claude_request, "CLAUDE", "GEMINI")
|
||||
claude_response = converter_registry.convert_response(gemini_response, "GEMINI", "CLAUDE")
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Protocol, Tuple
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class RequestConverter(Protocol):
|
||||
"""请求转换器协议"""
|
||||
|
||||
def convert_request(self, request: Dict[str, Any]) -> Dict[str, Any]: ...
|
||||
|
||||
|
||||
class ResponseConverter(Protocol):
|
||||
"""响应转换器协议"""
|
||||
|
||||
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]: ...
|
||||
|
||||
|
||||
class StreamChunkConverter(Protocol):
|
||||
"""流式响应块转换器协议"""
|
||||
|
||||
def convert_stream_chunk(self, chunk: Dict[str, Any]) -> Dict[str, Any]: ...
|
||||
|
||||
|
||||
class FormatConverterRegistry:
|
||||
"""
|
||||
格式转换器注册表
|
||||
|
||||
管理不同 API 格式之间的双向转换器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# key: (source_format, target_format), value: converter instance
|
||||
self._converters: Dict[Tuple[str, str], Any] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
converter: Any,
|
||||
) -> None:
|
||||
"""
|
||||
注册格式转换器
|
||||
|
||||
Args:
|
||||
source_format: 源格式(如 "CLAUDE", "OPENAI", "GEMINI")
|
||||
target_format: 目标格式
|
||||
converter: 转换器实例(需要有 convert_request/convert_response 方法)
|
||||
"""
|
||||
key = (source_format.upper(), target_format.upper())
|
||||
self._converters[key] = converter
|
||||
logger.info(f"[ConverterRegistry] 注册转换器: {source_format} -> {target_format}")
|
||||
|
||||
def get_converter(
|
||||
self,
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
获取转换器
|
||||
|
||||
Args:
|
||||
source_format: 源格式
|
||||
target_format: 目标格式
|
||||
|
||||
Returns:
|
||||
转换器实例,如果不存在返回 None
|
||||
"""
|
||||
key = (source_format.upper(), target_format.upper())
|
||||
return self._converters.get(key)
|
||||
|
||||
def has_converter(
|
||||
self,
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> bool:
|
||||
"""检查是否存在转换器"""
|
||||
key = (source_format.upper(), target_format.upper())
|
||||
return key in self._converters
|
||||
|
||||
def convert_request(
|
||||
self,
|
||||
request: Dict[str, Any],
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
转换请求
|
||||
|
||||
Args:
|
||||
request: 原始请求字典
|
||||
source_format: 源格式(客户端格式)
|
||||
target_format: 目标格式(Provider 格式)
|
||||
|
||||
Returns:
|
||||
转换后的请求字典,如果无需转换或没有转换器则返回原始请求
|
||||
"""
|
||||
# 同格式无需转换
|
||||
if source_format.upper() == target_format.upper():
|
||||
return request
|
||||
|
||||
converter = self.get_converter(source_format, target_format)
|
||||
if converter is None:
|
||||
logger.warning(f"[ConverterRegistry] 未找到请求转换器: {source_format} -> {target_format},返回原始请求")
|
||||
return request
|
||||
|
||||
if not hasattr(converter, "convert_request"):
|
||||
logger.warning(f"[ConverterRegistry] 转换器缺少 convert_request 方法: {source_format} -> {target_format}")
|
||||
return request
|
||||
|
||||
try:
|
||||
converted = converter.convert_request(request)
|
||||
logger.debug(f"[ConverterRegistry] 请求转换成功: {source_format} -> {target_format}")
|
||||
return converted
|
||||
except Exception as e:
|
||||
logger.error(f"[ConverterRegistry] 请求转换失败: {source_format} -> {target_format}: {e}")
|
||||
return request
|
||||
|
||||
def convert_response(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
转换响应
|
||||
|
||||
Args:
|
||||
response: 原始响应字典
|
||||
source_format: 源格式(Provider 格式)
|
||||
target_format: 目标格式(客户端格式)
|
||||
|
||||
Returns:
|
||||
转换后的响应字典,如果无需转换或没有转换器则返回原始响应
|
||||
"""
|
||||
# 同格式无需转换
|
||||
if source_format.upper() == target_format.upper():
|
||||
return response
|
||||
|
||||
converter = self.get_converter(source_format, target_format)
|
||||
if converter is None:
|
||||
logger.warning(f"[ConverterRegistry] 未找到响应转换器: {source_format} -> {target_format},返回原始响应")
|
||||
return response
|
||||
|
||||
if not hasattr(converter, "convert_response"):
|
||||
logger.warning(f"[ConverterRegistry] 转换器缺少 convert_response 方法: {source_format} -> {target_format}")
|
||||
return response
|
||||
|
||||
try:
|
||||
converted = converter.convert_response(response)
|
||||
logger.debug(f"[ConverterRegistry] 响应转换成功: {source_format} -> {target_format}")
|
||||
return converted
|
||||
except Exception as e:
|
||||
logger.error(f"[ConverterRegistry] 响应转换失败: {source_format} -> {target_format}: {e}")
|
||||
return response
|
||||
|
||||
def convert_stream_chunk(
|
||||
self,
|
||||
chunk: Dict[str, Any],
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
转换流式响应块
|
||||
|
||||
Args:
|
||||
chunk: 原始流式响应块
|
||||
source_format: 源格式(Provider 格式)
|
||||
target_format: 目标格式(客户端格式)
|
||||
|
||||
Returns:
|
||||
转换后的流式响应块
|
||||
"""
|
||||
# 同格式无需转换
|
||||
if source_format.upper() == target_format.upper():
|
||||
return chunk
|
||||
|
||||
converter = self.get_converter(source_format, target_format)
|
||||
if converter is None:
|
||||
return chunk
|
||||
|
||||
# 优先使用专门的流式转换方法
|
||||
if hasattr(converter, "convert_stream_chunk"):
|
||||
try:
|
||||
return converter.convert_stream_chunk(chunk)
|
||||
except Exception as e:
|
||||
logger.error(f"[ConverterRegistry] 流式块转换失败: {source_format} -> {target_format}: {e}")
|
||||
return chunk
|
||||
|
||||
# 降级到普通响应转换
|
||||
if hasattr(converter, "convert_response"):
|
||||
try:
|
||||
return converter.convert_response(chunk)
|
||||
except Exception:
|
||||
return chunk
|
||||
|
||||
return chunk
|
||||
|
||||
def list_converters(self) -> list[Tuple[str, str]]:
|
||||
"""列出所有已注册的转换器"""
|
||||
return list(self._converters.keys())
|
||||
|
||||
|
||||
# 全局单例
|
||||
converter_registry = FormatConverterRegistry()
|
||||
|
||||
|
||||
def register_all_converters():
|
||||
"""
|
||||
注册所有内置的格式转换器
|
||||
|
||||
在应用启动时调用此函数
|
||||
"""
|
||||
# Claude <-> OpenAI
|
||||
try:
|
||||
from src.api.handlers.claude.converter import OpenAIToClaudeConverter
|
||||
from src.api.handlers.openai.converter import ClaudeToOpenAIConverter
|
||||
|
||||
converter_registry.register("OPENAI", "CLAUDE", OpenAIToClaudeConverter())
|
||||
converter_registry.register("CLAUDE", "OPENAI", ClaudeToOpenAIConverter())
|
||||
except ImportError as e:
|
||||
logger.warning(f"[ConverterRegistry] 无法加载 Claude/OpenAI 转换器: {e}")
|
||||
|
||||
# Claude <-> Gemini
|
||||
try:
|
||||
from src.api.handlers.gemini.converter import (
|
||||
ClaudeToGeminiConverter,
|
||||
GeminiToClaudeConverter,
|
||||
)
|
||||
|
||||
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
|
||||
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
|
||||
except ImportError as e:
|
||||
logger.warning(f"[ConverterRegistry] 无法加载 Claude/Gemini 转换器: {e}")
|
||||
|
||||
# OpenAI <-> Gemini
|
||||
try:
|
||||
from src.api.handlers.gemini.converter import (
|
||||
GeminiToOpenAIConverter,
|
||||
OpenAIToGeminiConverter,
|
||||
)
|
||||
|
||||
converter_registry.register("OPENAI", "GEMINI", OpenAIToGeminiConverter())
|
||||
converter_registry.register("GEMINI", "OPENAI", GeminiToOpenAIConverter())
|
||||
except ImportError as e:
|
||||
logger.warning(f"[ConverterRegistry] 无法加载 OpenAI/Gemini 转换器: {e}")
|
||||
|
||||
logger.info(f"[ConverterRegistry] 已注册 {len(converter_registry.list_converters())} 个格式转换器")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FormatConverterRegistry",
|
||||
"converter_registry",
|
||||
"register_all_converters",
|
||||
]
|
||||
465
src/api/handlers/base/parsers.py
Normal file
465
src/api/handlers/base/parsers.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
响应解析器工厂
|
||||
|
||||
直接根据格式 ID 创建对应的 ResponseParser 实现,
|
||||
不再经过 Protocol 抽象层。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ParsedChunk,
|
||||
ParsedResponse,
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIResponseParser(ResponseParser):
|
||||
"""OpenAI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
|
||||
|
||||
self._parser = OpenAIStreamParser()
|
||||
self.name = "OPENAI"
|
||||
self.api_format = "OPENAI"
|
||||
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
if not line or not line.strip():
|
||||
return None
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
else:
|
||||
data_str = line
|
||||
|
||||
parsed = self._parser.parse_line(data_str)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
chunk = ParsedChunk(
|
||||
raw_line=line,
|
||||
event_type=None,
|
||||
data=parsed,
|
||||
)
|
||||
|
||||
# 提取文本增量
|
||||
text_delta = self._parser.extract_text_delta(parsed)
|
||||
if text_delta:
|
||||
chunk.text_delta = text_delta
|
||||
stats.collected_text += text_delta
|
||||
|
||||
# 检查是否结束
|
||||
if self._parser.is_done_chunk(parsed):
|
||||
chunk.is_done = True
|
||||
stats.has_completion = True
|
||||
|
||||
stats.chunk_count += 1
|
||||
stats.data_count += 1
|
||||
|
||||
return chunk
|
||||
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
result = ParsedResponse(
|
||||
raw_response=response,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
# 提取文本内容
|
||||
choices = response.get("choices", [])
|
||||
if choices:
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content")
|
||||
if content:
|
||||
result.text_content = content
|
||||
|
||||
result.response_id = response.get("id")
|
||||
|
||||
# 提取 usage
|
||||
usage = response.get("usage", {})
|
||||
result.input_tokens = usage.get("prompt_tokens", 0)
|
||||
result.output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
# 检查错误
|
||||
if "error" in response:
|
||||
result.is_error = True
|
||||
error = response.get("error", {})
|
||||
if isinstance(error, dict):
|
||||
result.error_type = error.get("type")
|
||||
result.error_message = error.get("message")
|
||||
else:
|
||||
result.error_message = str(error)
|
||||
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
usage = response.get("usage", {})
|
||||
return {
|
||||
"input_tokens": usage.get("prompt_tokens", 0),
|
||||
"output_tokens": usage.get("completion_tokens", 0),
|
||||
"cache_creation_tokens": 0,
|
||||
"cache_read_tokens": 0,
|
||||
}
|
||||
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
choices = response.get("choices", [])
|
||||
if choices:
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content")
|
||||
if content:
|
||||
return content
|
||||
return ""
|
||||
|
||||
def is_error_response(self, response: Dict[str, Any]) -> bool:
|
||||
return "error" in response
|
||||
|
||||
|
||||
class OpenAICliResponseParser(OpenAIResponseParser):
|
||||
"""OpenAI CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "OPENAI_CLI"
|
||||
self.api_format = "OPENAI_CLI"
|
||||
|
||||
|
||||
class ClaudeResponseParser(ResponseParser):
|
||||
"""Claude 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
|
||||
|
||||
self._parser = ClaudeStreamParser()
|
||||
self.name = "CLAUDE"
|
||||
self.api_format = "CLAUDE"
|
||||
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
if not line or not line.strip():
|
||||
return None
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
else:
|
||||
data_str = line
|
||||
|
||||
parsed = self._parser.parse_line(data_str)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
chunk = ParsedChunk(
|
||||
raw_line=line,
|
||||
event_type=self._parser.get_event_type(parsed),
|
||||
data=parsed,
|
||||
)
|
||||
|
||||
# 提取文本增量
|
||||
text_delta = self._parser.extract_text_delta(parsed)
|
||||
if text_delta:
|
||||
chunk.text_delta = text_delta
|
||||
stats.collected_text += text_delta
|
||||
|
||||
# 检查是否结束
|
||||
if self._parser.is_done_event(parsed):
|
||||
chunk.is_done = True
|
||||
stats.has_completion = True
|
||||
|
||||
# 提取 usage
|
||||
usage = self._parser.extract_usage(parsed)
|
||||
if usage:
|
||||
chunk.input_tokens = usage.get("input_tokens", 0)
|
||||
chunk.output_tokens = usage.get("output_tokens", 0)
|
||||
chunk.cache_creation_tokens = usage.get("cache_creation_tokens", 0)
|
||||
chunk.cache_read_tokens = usage.get("cache_read_tokens", 0)
|
||||
|
||||
stats.input_tokens = chunk.input_tokens
|
||||
stats.output_tokens = chunk.output_tokens
|
||||
stats.cache_creation_tokens = chunk.cache_creation_tokens
|
||||
stats.cache_read_tokens = chunk.cache_read_tokens
|
||||
|
||||
# 检查错误
|
||||
if self._parser.is_error_event(parsed):
|
||||
chunk.is_error = True
|
||||
error = parsed.get("error", {})
|
||||
if isinstance(error, dict):
|
||||
chunk.error_message = error.get("message", str(error))
|
||||
else:
|
||||
chunk.error_message = str(error)
|
||||
|
||||
stats.chunk_count += 1
|
||||
stats.data_count += 1
|
||||
|
||||
return chunk
|
||||
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
result = ParsedResponse(
|
||||
raw_response=response,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
# 提取文本内容
|
||||
content = response.get("content", [])
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
result.text_content = "".join(text_parts)
|
||||
|
||||
result.response_id = response.get("id")
|
||||
|
||||
# 提取 usage
|
||||
usage = response.get("usage", {})
|
||||
result.input_tokens = usage.get("input_tokens", 0)
|
||||
result.output_tokens = usage.get("output_tokens", 0)
|
||||
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 检查错误
|
||||
if "error" in response or response.get("type") == "error":
|
||||
result.is_error = True
|
||||
error = response.get("error", {})
|
||||
if isinstance(error, dict):
|
||||
result.error_type = error.get("type")
|
||||
result.error_message = error.get("message")
|
||||
else:
|
||||
result.error_message = str(error)
|
||||
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
usage = response.get("usage", {})
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
content = response.get("content", [])
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
return "".join(text_parts)
|
||||
return ""
|
||||
|
||||
def is_error_response(self, response: Dict[str, Any]) -> bool:
|
||||
return "error" in response or response.get("type") == "error"
|
||||
|
||||
|
||||
class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||
"""Claude CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "CLAUDE_CLI"
|
||||
self.api_format = "CLAUDE_CLI"
|
||||
|
||||
|
||||
class GeminiResponseParser(ResponseParser):
|
||||
"""Gemini 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
self._parser = GeminiStreamParser()
|
||||
self.name = "GEMINI"
|
||||
self.api_format = "GEMINI"
|
||||
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
"""
|
||||
解析 Gemini SSE 行
|
||||
|
||||
Gemini 的流式响应使用 SSE 格式 (data: {...})
|
||||
"""
|
||||
if not line or not line.strip():
|
||||
return None
|
||||
|
||||
# Gemini SSE 格式: data: {...}
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
else:
|
||||
data_str = line
|
||||
|
||||
parsed = self._parser.parse_line(data_str)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
chunk = ParsedChunk(
|
||||
raw_line=line,
|
||||
event_type="content",
|
||||
data=parsed,
|
||||
)
|
||||
|
||||
# 提取文本增量
|
||||
text_delta = self._parser.extract_text_delta(parsed)
|
||||
if text_delta:
|
||||
chunk.text_delta = text_delta
|
||||
stats.collected_text += text_delta
|
||||
|
||||
# 检查是否结束
|
||||
if self._parser.is_done_event(parsed):
|
||||
chunk.is_done = True
|
||||
stats.has_completion = True
|
||||
|
||||
# 提取 usage
|
||||
usage = self._parser.extract_usage(parsed)
|
||||
if usage:
|
||||
chunk.input_tokens = usage.get("input_tokens", 0)
|
||||
chunk.output_tokens = usage.get("output_tokens", 0)
|
||||
chunk.cache_read_tokens = usage.get("cached_tokens", 0)
|
||||
|
||||
stats.input_tokens = chunk.input_tokens
|
||||
stats.output_tokens = chunk.output_tokens
|
||||
stats.cache_read_tokens = chunk.cache_read_tokens
|
||||
|
||||
# 检查错误
|
||||
if self._parser.is_error_event(parsed):
|
||||
chunk.is_error = True
|
||||
error = parsed.get("error", {})
|
||||
if isinstance(error, dict):
|
||||
chunk.error_message = error.get("message", str(error))
|
||||
else:
|
||||
chunk.error_message = str(error)
|
||||
|
||||
stats.chunk_count += 1
|
||||
stats.data_count += 1
|
||||
|
||||
return chunk
|
||||
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
result = ParsedResponse(
|
||||
raw_response=response,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
# 提取文本内容
|
||||
candidates = response.get("candidates", [])
|
||||
if candidates:
|
||||
content = candidates[0].get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
text_parts = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
result.text_content = "".join(text_parts)
|
||||
|
||||
result.response_id = response.get("modelVersion")
|
||||
|
||||
# 提取 usage(调用 GeminiStreamParser.extract_usage 作为单一实现源)
|
||||
usage = self._parser.extract_usage(response)
|
||||
if usage:
|
||||
result.input_tokens = usage.get("input_tokens", 0)
|
||||
result.output_tokens = usage.get("output_tokens", 0)
|
||||
result.cache_read_tokens = usage.get("cached_tokens", 0)
|
||||
|
||||
# 检查错误(使用增强的错误检测)
|
||||
error_info = self._parser.extract_error_info(response)
|
||||
if error_info:
|
||||
result.is_error = True
|
||||
result.error_type = error_info.get("status")
|
||||
result.error_message = error_info.get("message")
|
||||
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""
|
||||
从 Gemini 响应中提取 token 使用量
|
||||
|
||||
调用 GeminiStreamParser.extract_usage 作为单一实现源
|
||||
"""
|
||||
usage = self._parser.extract_usage(response)
|
||||
if not usage:
|
||||
return {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_tokens": 0,
|
||||
"cache_read_tokens": 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": 0,
|
||||
"cache_read_tokens": usage.get("cached_tokens", 0),
|
||||
}
|
||||
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
candidates = response.get("candidates", [])
|
||||
if candidates:
|
||||
content = candidates[0].get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
text_parts = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
return "".join(text_parts)
|
||||
return ""
|
||||
|
||||
def is_error_response(self, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断响应是否为错误响应
|
||||
|
||||
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
|
||||
"""
|
||||
return self._parser.is_error_event(response)
|
||||
|
||||
|
||||
class GeminiCliResponseParser(GeminiResponseParser):
|
||||
"""Gemini CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "GEMINI_CLI"
|
||||
self.api_format = "GEMINI_CLI"
|
||||
|
||||
|
||||
# 解析器注册表
|
||||
_PARSERS = {
|
||||
"CLAUDE": ClaudeResponseParser,
|
||||
"CLAUDE_CLI": ClaudeCliResponseParser,
|
||||
"OPENAI": OpenAIResponseParser,
|
||||
"OPENAI_CLI": OpenAICliResponseParser,
|
||||
"GEMINI": GeminiResponseParser,
|
||||
"GEMINI_CLI": GeminiCliResponseParser,
|
||||
}
|
||||
|
||||
|
||||
def get_parser_for_format(format_id: str) -> ResponseParser:
|
||||
"""
|
||||
根据格式 ID 获取 ResponseParser
|
||||
|
||||
Args:
|
||||
format_id: 格式 ID,如 "CLAUDE", "OPENAI", "CLAUDE_CLI", "OPENAI_CLI"
|
||||
|
||||
Returns:
|
||||
ResponseParser 实例
|
||||
|
||||
Raises:
|
||||
KeyError: 格式不存在
|
||||
"""
|
||||
format_id = format_id.upper()
|
||||
if format_id not in _PARSERS:
|
||||
raise KeyError(f"Unknown format: {format_id}")
|
||||
return _PARSERS[format_id]()
|
||||
|
||||
|
||||
def is_cli_format(format_id: str) -> bool:
|
||||
"""判断是否为 CLI 格式"""
|
||||
return format_id.upper().endswith("_CLI")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OpenAIResponseParser",
|
||||
"OpenAICliResponseParser",
|
||||
"ClaudeResponseParser",
|
||||
"ClaudeCliResponseParser",
|
||||
"GeminiResponseParser",
|
||||
"GeminiCliResponseParser",
|
||||
"get_parser_for_format",
|
||||
"get_parser_from_protocol",
|
||||
"is_cli_format",
|
||||
]
|
||||
207
src/api/handlers/base/request_builder.py
Normal file
207
src/api/handlers/base/request_builder.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
请求构建器 - 透传模式
|
||||
|
||||
透传模式 (Passthrough): CLI 和 Chat 等场景,原样转发请求体和头部
|
||||
- 清理敏感头部:authorization, x-api-key, host, content-length 等
|
||||
- 保留所有其他头部和请求体字段
|
||||
- 适用于:Claude CLI、OpenAI CLI、Chat API 等场景
|
||||
|
||||
使用方式:
|
||||
builder = PassthroughRequestBuilder()
|
||||
payload, headers = builder.build(original_body, original_headers, endpoint, key)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, FrozenSet, Optional, Tuple
|
||||
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
# ==============================================================================
|
||||
# 统一的头部配置常量
|
||||
# ==============================================================================
|
||||
|
||||
# 敏感头部 - 透传时需要清理(黑名单)
|
||||
# 这些头部要么包含认证信息,要么由代理层重新生成
|
||||
SENSITIVE_HEADERS: FrozenSet[str] = frozenset(
|
||||
{
|
||||
"authorization",
|
||||
"x-api-key",
|
||||
"x-goog-api-key", # Gemini API 认证头
|
||||
"host",
|
||||
"content-length",
|
||||
"transfer-encoding",
|
||||
"connection",
|
||||
# 不透传 accept-encoding,让 httpx 自己协商压缩格式
|
||||
# 避免客户端请求 brotli/zstd 但 httpx 不支持解压的问题
|
||||
"accept-encoding",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 请求构建器
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RequestBuilder(ABC):
|
||||
"""请求构建器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def build_payload(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
*,
|
||||
mapped_model: Optional[str] = None,
|
||||
is_stream: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建请求体"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_headers(
|
||||
self,
|
||||
original_headers: Dict[str, str],
|
||||
endpoint: Any,
|
||||
key: Any,
|
||||
*,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
pass
|
||||
|
||||
def build(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
original_headers: Dict[str, str],
|
||||
endpoint: Any,
|
||||
key: Any,
|
||||
*,
|
||||
mapped_model: Optional[str] = None,
|
||||
is_stream: bool = False,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
构建完整的请求(请求体 + 请求头)
|
||||
|
||||
Returns:
|
||||
Tuple[payload, headers]
|
||||
"""
|
||||
payload = self.build_payload(
|
||||
original_body,
|
||||
mapped_model=mapped_model,
|
||||
is_stream=is_stream,
|
||||
)
|
||||
headers = self.build_headers(
|
||||
original_headers,
|
||||
endpoint,
|
||||
key,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
return payload, headers
|
||||
|
||||
|
||||
class PassthroughRequestBuilder(RequestBuilder):
|
||||
"""
|
||||
透传模式请求构建器
|
||||
|
||||
适用于 CLI 等场景,尽量保持请求原样:
|
||||
- 请求体:直接复制,只修改必要字段(model, stream)
|
||||
- 请求头:清理敏感头部(黑名单),透传其他所有头部
|
||||
"""
|
||||
|
||||
def build_payload(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
*,
|
||||
mapped_model: Optional[str] = None, # noqa: ARG002 - 由 apply_mapped_model 处理
|
||||
is_stream: bool = False, # noqa: ARG002 - 保留原始值,不自动添加
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
透传请求体 - 原样复制,不做任何修改
|
||||
|
||||
透传模式下:
|
||||
- model: 由各 handler 的 apply_mapped_model 方法处理
|
||||
- stream: 保留客户端原始值(不同 API 处理方式不同)
|
||||
"""
|
||||
return dict(original_body)
|
||||
|
||||
def build_headers(
|
||||
self,
|
||||
original_headers: Dict[str, str],
|
||||
endpoint: Any,
|
||||
key: Any,
|
||||
*,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
透传请求头 - 清理敏感头部(黑名单),透传其他所有头部
|
||||
"""
|
||||
from src.core.api_format_metadata import get_auth_config, resolve_api_format
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
|
||||
# 1. 根据 API 格式自动设置认证头
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
api_format = getattr(endpoint, "api_format", None)
|
||||
resolved_format = resolve_api_format(api_format)
|
||||
auth_header, auth_type = (
|
||||
get_auth_config(resolved_format) if resolved_format else ("Authorization", "bearer")
|
||||
)
|
||||
|
||||
if auth_type == "bearer":
|
||||
headers[auth_header] = f"Bearer {decrypted_key}"
|
||||
else:
|
||||
headers[auth_header] = decrypted_key
|
||||
|
||||
# 2. 添加 endpoint 配置的额外头部
|
||||
if endpoint.headers:
|
||||
headers.update(endpoint.headers)
|
||||
|
||||
# 3. 透传原始头部(排除敏感头部 - 黑名单模式)
|
||||
if original_headers:
|
||||
for name, value in original_headers.items():
|
||||
lower_name = name.lower()
|
||||
|
||||
# 跳过敏感头部
|
||||
if lower_name in SENSITIVE_HEADERS:
|
||||
continue
|
||||
|
||||
headers[name] = value
|
||||
|
||||
# 4. 添加额外头部
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# 5. 确保有 Content-Type
|
||||
if "Content-Type" not in headers and "content-type" not in headers:
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 便捷函数
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def build_passthrough_request(
|
||||
original_body: Dict[str, Any],
|
||||
original_headers: Dict[str, str],
|
||||
endpoint: Any,
|
||||
key: Any,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
构建透传模式的请求
|
||||
|
||||
纯透传:原样复制请求体,只处理请求头(认证等)。
|
||||
model mapping 和 stream 由调用方自行处理(不同 API 格式处理方式不同)。
|
||||
"""
|
||||
builder = PassthroughRequestBuilder()
|
||||
return builder.build(
|
||||
original_body,
|
||||
original_headers,
|
||||
endpoint,
|
||||
key,
|
||||
)
|
||||
174
src/api/handlers/base/response_parser.py
Normal file
174
src/api/handlers/base/response_parser.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
响应解析器基类 - 定义统一的响应解析接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedChunk:
|
||||
"""解析后的流式数据块"""
|
||||
|
||||
# 原始数据
|
||||
raw_line: str
|
||||
event_type: Optional[str] = None
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 提取的内容
|
||||
text_delta: str = ""
|
||||
is_done: bool = False
|
||||
is_error: bool = False
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# 使用量信息(通常在最后一个 chunk 中)
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
|
||||
# 响应 ID
|
||||
response_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamStats:
|
||||
"""流式响应统计信息"""
|
||||
|
||||
# 计数
|
||||
chunk_count: int = 0
|
||||
data_count: int = 0
|
||||
|
||||
# Token 使用量
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
|
||||
# 内容
|
||||
collected_text: str = ""
|
||||
response_id: Optional[str] = None
|
||||
|
||||
# 状态
|
||||
has_completion: bool = False
|
||||
status_code: int = 200
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# Provider 信息
|
||||
provider_name: Optional[str] = None
|
||||
endpoint_id: Optional[str] = None
|
||||
key_id: Optional[str] = None
|
||||
|
||||
# 响应头和完整响应
|
||||
response_headers: Dict[str, str] = field(default_factory=dict)
|
||||
final_response: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedResponse:
|
||||
"""解析后的非流式响应"""
|
||||
|
||||
# 原始响应
|
||||
raw_response: Dict[str, Any]
|
||||
status_code: int
|
||||
|
||||
# 提取的内容
|
||||
text_content: str = ""
|
||||
response_id: Optional[str] = None
|
||||
|
||||
# 使用量
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
|
||||
# 错误信息
|
||||
is_error: bool = False
|
||||
error_type: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class ResponseParser(ABC):
|
||||
"""
|
||||
响应解析器基类
|
||||
|
||||
定义统一的接口来解析不同 API 格式的响应。
|
||||
子类需要实现具体的解析逻辑。
|
||||
"""
|
||||
|
||||
# 解析器名称(用于日志)
|
||||
name: str = "base"
|
||||
|
||||
# 支持的 API 格式
|
||||
api_format: str = "UNKNOWN"
|
||||
|
||||
@abstractmethod
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
"""
|
||||
解析单行 SSE 数据
|
||||
|
||||
Args:
|
||||
line: SSE 行数据
|
||||
stats: 流统计对象(会被更新)
|
||||
|
||||
Returns:
|
||||
解析后的数据块,如果行不包含有效数据则返回 None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
"""
|
||||
解析非流式响应
|
||||
|
||||
Args:
|
||||
response: 响应 JSON
|
||||
status_code: HTTP 状态码
|
||||
|
||||
Returns:
|
||||
解析后的响应对象
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""
|
||||
从响应中提取 token 使用量
|
||||
|
||||
Args:
|
||||
response: 响应 JSON
|
||||
|
||||
Returns:
|
||||
包含 input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens 的字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
"""
|
||||
从响应中提取文本内容
|
||||
|
||||
Args:
|
||||
response: 响应 JSON
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_error_response(self, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断响应是否为错误响应
|
||||
|
||||
Args:
|
||||
response: 响应 JSON
|
||||
|
||||
Returns:
|
||||
是否为错误响应
|
||||
"""
|
||||
return "error" in response
|
||||
|
||||
def create_stats(self) -> StreamStats:
|
||||
"""创建新的流统计对象"""
|
||||
return StreamStats()
|
||||
17
src/api/handlers/claude/__init__.py
Normal file
17
src/api/handlers/claude/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Claude Chat API 处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.claude.adapter import (
|
||||
ClaudeChatAdapter,
|
||||
ClaudeTokenCountAdapter,
|
||||
build_claude_adapter,
|
||||
)
|
||||
from src.api.handlers.claude.handler import ClaudeChatHandler
|
||||
|
||||
__all__ = [
|
||||
"ClaudeChatAdapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
"build_claude_adapter",
|
||||
"ClaudeChatHandler",
|
||||
]
|
||||
228
src/api/handlers/claude/adapter.py
Normal file
228
src/api/handlers/claude/adapter.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Claude Chat Adapter - 基于 ChatAdapterBase 的 Claude Chat API 适配器
|
||||
|
||||
处理 /v1/messages 端点的 Claude Chat 格式请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.logger import logger
|
||||
from src.core.optimization_utils import TokenCounter
|
||||
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest
|
||||
|
||||
|
||||
class ClaudeCapabilityDetector:
|
||||
"""Claude API 能力检测器"""
|
||||
|
||||
@staticmethod
|
||||
def detect_from_headers(
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
从 Claude 请求头检测能力需求
|
||||
|
||||
检测规则:
|
||||
- anthropic-beta: context-1m-xxx -> context_1m: True
|
||||
|
||||
Args:
|
||||
headers: 请求头字典
|
||||
request_body: 请求体(Claude 不使用,保留用于接口统一)
|
||||
"""
|
||||
requirements: Dict[str, bool] = {}
|
||||
|
||||
# 检查 anthropic-beta 请求头(大小写不敏感)
|
||||
beta_header = None
|
||||
for key, value in headers.items():
|
||||
if key.lower() == "anthropic-beta":
|
||||
beta_header = value
|
||||
break
|
||||
|
||||
if beta_header:
|
||||
# 检查是否包含 context-1m 标识
|
||||
if "context-1m" in beta_header.lower():
|
||||
requirements["context_1m"] = True
|
||||
|
||||
return requirements
|
||||
|
||||
|
||||
@register_adapter
|
||||
class ClaudeChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
Claude Chat API 适配器
|
||||
|
||||
处理 Claude Chat 格式的请求(/v1/messages 端点,进行格式验证)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE"
|
||||
name = "claude.chat"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.claude.handler import ClaudeChatHandler
|
||||
|
||||
return ClaudeChatHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["CLAUDE"])
|
||||
logger.info(f"[{self.name}] 初始化Chat模式适配器 | API格式: {self.allowed_api_formats}")
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-api-key)"""
|
||||
return request.headers.get("x-api-key")
|
||||
|
||||
def detect_capability_requirements(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""检测 Claude 请求中隐含的能力需求"""
|
||||
return ClaudeCapabilityDetector.detect_from_headers(headers)
|
||||
|
||||
# =========================================================================
|
||||
# Claude 特定的计费逻辑
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算 Claude 的总输入上下文(用于阶梯计费判定)
|
||||
|
||||
Claude 的总输入 = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
"""
|
||||
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""验证请求体"""
|
||||
try:
|
||||
if not isinstance(original_request_body, dict):
|
||||
raise ValueError("Request body must be a JSON object")
|
||||
|
||||
required_fields = ["model", "messages", "max_tokens"]
|
||||
missing_fields = [f for f in required_fields if f not in original_request_body]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
|
||||
|
||||
request = ClaudeMessagesRequest.model_validate(
|
||||
original_request_body,
|
||||
strict=False,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"请求体基本验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
|
||||
request = ClaudeMessagesRequest.model_construct(
|
||||
model=original_request_body.get("model"),
|
||||
max_tokens=original_request_body.get("max_tokens"),
|
||||
messages=original_request_body.get("messages", []),
|
||||
stream=original_request_body.get("stream", False),
|
||||
)
|
||||
return request
|
||||
|
||||
def _build_audit_metadata(self, _payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""构建 Claude Chat 特定的审计元数据"""
|
||||
role_counts: dict[str, int] = {}
|
||||
for message in request_obj.messages:
|
||||
role_counts[message.role] = role_counts.get(message.role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "claude_messages",
|
||||
"model": request_obj.model,
|
||||
"stream": bool(request_obj.stream),
|
||||
"max_tokens": request_obj.max_tokens,
|
||||
"temperature": getattr(request_obj, "temperature", None),
|
||||
"top_p": getattr(request_obj, "top_p", None),
|
||||
"top_k": getattr(request_obj, "top_k", None),
|
||||
"messages_count": len(request_obj.messages),
|
||||
"message_roles": role_counts,
|
||||
"stop_sequences": len(request_obj.stop_sequences or []),
|
||||
"tools_count": len(request_obj.tools or []),
|
||||
"system_present": bool(request_obj.system),
|
||||
"metadata_present": bool(request_obj.metadata),
|
||||
"thinking_enabled": bool(request_obj.thinking),
|
||||
}
|
||||
|
||||
|
||||
def build_claude_adapter(x_app_header: Optional[str]):
|
||||
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
||||
if x_app_header and x_app_header.lower() == "cli":
|
||||
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
|
||||
|
||||
return ClaudeCliAdapter()
|
||||
return ClaudeChatAdapter()
|
||||
|
||||
|
||||
class ClaudeTokenCountAdapter(ApiAdapter):
|
||||
"""计算 Claude 请求 Token 数的轻量适配器。"""
|
||||
|
||||
name = "claude.token_count"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-api-key 或 Authorization: Bearer)"""
|
||||
# 优先检查 x-api-key
|
||||
api_key = request.headers.get("x-api-key")
|
||||
if api_key:
|
||||
return api_key
|
||||
# 降级到 Authorization: Bearer
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
request = ClaudeTokenCountRequest.model_validate(payload, strict=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Token count payload invalid: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid token count payload") from e
|
||||
|
||||
token_counter = TokenCounter()
|
||||
total_tokens = 0
|
||||
|
||||
if request.system:
|
||||
if isinstance(request.system, str):
|
||||
total_tokens += token_counter.count_tokens(request.system, request.model)
|
||||
elif isinstance(request.system, list):
|
||||
for block in request.system:
|
||||
if hasattr(block, "text"):
|
||||
total_tokens += token_counter.count_tokens(block.text, request.model)
|
||||
|
||||
messages_dict = [
|
||||
msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in request.messages
|
||||
]
|
||||
total_tokens += token_counter.count_messages_tokens(messages_dict, request.model)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="claude_token_count",
|
||||
model=request.model,
|
||||
messages_count=len(request.messages),
|
||||
system_present=bool(request.system),
|
||||
tools_count=len(request.tools or []),
|
||||
thinking_enabled=bool(request.thinking),
|
||||
input_tokens=total_tokens,
|
||||
)
|
||||
|
||||
return JSONResponse({"input_tokens": total_tokens})
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClaudeChatAdapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
"build_claude_adapter",
|
||||
]
|
||||
490
src/api/handlers/claude/converter.py
Normal file
490
src/api/handlers/claude/converter.py
Normal file
@@ -0,0 +1,490 @@
|
||||
"""
|
||||
OpenAI -> Claude 格式转换器
|
||||
|
||||
将 OpenAI Chat Completions API 格式转换为 Claude Messages API 格式。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OpenAIToClaudeConverter:
|
||||
"""
|
||||
OpenAI -> Claude 格式转换器
|
||||
|
||||
支持:
|
||||
- 请求转换:OpenAI Chat Request -> Claude Request
|
||||
- 响应转换:OpenAI Chat Response -> Claude Response
|
||||
- 流式转换:OpenAI SSE -> Claude SSE
|
||||
"""
|
||||
|
||||
# 内容类型常量
|
||||
CONTENT_TYPE_TEXT = "text"
|
||||
CONTENT_TYPE_IMAGE = "image"
|
||||
CONTENT_TYPE_TOOL_USE = "tool_use"
|
||||
CONTENT_TYPE_TOOL_RESULT = "tool_result"
|
||||
|
||||
# 停止原因映射(OpenAI -> Claude)
|
||||
FINISH_REASON_MAP = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
"function_call": "tool_use",
|
||||
"content_filter": "end_turn",
|
||||
}
|
||||
|
||||
def __init__(self, model_mapping: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
Args:
|
||||
model_mapping: OpenAI 模型到 Claude 模型的映射
|
||||
"""
|
||||
self._model_mapping = model_mapping or {}
|
||||
|
||||
# ==================== 请求转换 ====================
|
||||
|
||||
def convert_request(self, request: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 OpenAI 请求转换为 Claude 格式
|
||||
|
||||
Args:
|
||||
request: OpenAI 请求(Dict 或 Pydantic 模型)
|
||||
|
||||
Returns:
|
||||
Claude 格式的请求字典
|
||||
"""
|
||||
if hasattr(request, "model_dump"):
|
||||
data = request.model_dump(exclude_none=True)
|
||||
else:
|
||||
data = dict(request)
|
||||
|
||||
# 模型映射
|
||||
model = data.get("model", "")
|
||||
claude_model = self._model_mapping.get(model, model)
|
||||
|
||||
# 处理消息
|
||||
system_content: Optional[str] = None
|
||||
claude_messages: List[Dict[str, Any]] = []
|
||||
|
||||
for message in data.get("messages", []):
|
||||
role = message.get("role")
|
||||
|
||||
# 提取 system 消息
|
||||
if role == "system":
|
||||
system_content = self._collapse_content(message.get("content"))
|
||||
continue
|
||||
|
||||
# 转换其他消息
|
||||
converted = self._convert_message(message)
|
||||
if converted:
|
||||
claude_messages.append(converted)
|
||||
|
||||
# 构建 Claude 请求
|
||||
result: Dict[str, Any] = {
|
||||
"model": claude_model,
|
||||
"messages": claude_messages,
|
||||
"max_tokens": data.get("max_tokens") or 4096,
|
||||
}
|
||||
|
||||
# 可选参数
|
||||
if data.get("temperature") is not None:
|
||||
result["temperature"] = data["temperature"]
|
||||
if data.get("top_p") is not None:
|
||||
result["top_p"] = data["top_p"]
|
||||
if data.get("stream"):
|
||||
result["stream"] = data["stream"]
|
||||
if data.get("stop"):
|
||||
result["stop_sequences"] = self._convert_stop(data["stop"])
|
||||
if system_content:
|
||||
result["system"] = system_content
|
||||
|
||||
# 工具转换
|
||||
tools = self._convert_tools(data.get("tools"))
|
||||
if tools:
|
||||
result["tools"] = tools
|
||||
|
||||
tool_choice = self._convert_tool_choice(data.get("tool_choice"))
|
||||
if tool_choice:
|
||||
result["tool_choice"] = tool_choice
|
||||
|
||||
return result
|
||||
|
||||
def _convert_message(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""转换单条消息"""
|
||||
role = message.get("role")
|
||||
|
||||
if role == "user":
|
||||
return self._convert_user_message(message)
|
||||
if role == "assistant":
|
||||
return self._convert_assistant_message(message)
|
||||
if role == "tool":
|
||||
return self._convert_tool_message(message)
|
||||
|
||||
return None
|
||||
|
||||
def _convert_user_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换用户消息"""
|
||||
content = message.get("content")
|
||||
|
||||
if isinstance(content, str) or content is None:
|
||||
return {"role": "user", "content": content or ""}
|
||||
|
||||
# 转换内容数组
|
||||
claude_content: List[Dict[str, Any]] = []
|
||||
for item in content:
|
||||
item_type = item.get("type")
|
||||
|
||||
if item_type == "text":
|
||||
claude_content.append(
|
||||
{"type": self.CONTENT_TYPE_TEXT, "text": item.get("text", "")}
|
||||
)
|
||||
elif item_type == "image_url":
|
||||
image_url = (item.get("image_url") or {}).get("url", "")
|
||||
claude_content.append(self._convert_image_url(image_url))
|
||||
|
||||
return {"role": "user", "content": claude_content}
|
||||
|
||||
def _convert_assistant_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换助手消息"""
|
||||
content_blocks: List[Dict[str, Any]] = []
|
||||
|
||||
# 处理文本内容
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
content_blocks.append({"type": self.CONTENT_TYPE_TEXT, "text": content})
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if part.get("type") == "text":
|
||||
content_blocks.append(
|
||||
{"type": self.CONTENT_TYPE_TEXT, "text": part.get("text", "")}
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if tool_call.get("type") == "function":
|
||||
function = tool_call.get("function", {})
|
||||
arguments = function.get("arguments", "{}")
|
||||
try:
|
||||
input_data = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
input_data = {"raw": arguments}
|
||||
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": self.CONTENT_TYPE_TOOL_USE,
|
||||
"id": tool_call.get("id", ""),
|
||||
"name": function.get("name", ""),
|
||||
"input": input_data,
|
||||
}
|
||||
)
|
||||
|
||||
# 简化单文本内容
|
||||
if not content_blocks:
|
||||
return {"role": "assistant", "content": ""}
|
||||
if len(content_blocks) == 1 and content_blocks[0]["type"] == self.CONTENT_TYPE_TEXT:
|
||||
return {"role": "assistant", "content": content_blocks[0]["text"]}
|
||||
|
||||
return {"role": "assistant", "content": content_blocks}
|
||||
|
||||
def _convert_tool_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换工具结果消息"""
|
||||
tool_content = message.get("content", "")
|
||||
|
||||
# 尝试解析 JSON
|
||||
parsed_content = tool_content
|
||||
if isinstance(tool_content, str):
|
||||
try:
|
||||
parsed_content = json.loads(tool_content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
tool_block = {
|
||||
"type": self.CONTENT_TYPE_TOOL_RESULT,
|
||||
"tool_use_id": message.get("tool_call_id", ""),
|
||||
"content": parsed_content,
|
||||
}
|
||||
|
||||
return {"role": "user", "content": [tool_block]}
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: Optional[List[Dict[str, Any]]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""转换工具定义"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
result: List[Dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
if tool.get("type") != "function":
|
||||
continue
|
||||
|
||||
function = tool.get("function", {})
|
||||
result.append(
|
||||
{
|
||||
"name": function.get("name", ""),
|
||||
"description": function.get("description"),
|
||||
"input_schema": function.get("parameters") or {},
|
||||
}
|
||||
)
|
||||
|
||||
return result if result else None
|
||||
|
||||
def _convert_tool_choice(
|
||||
self, tool_choice: Optional[Union[str, Dict[str, Any]]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""转换工具选择"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
if tool_choice == "none":
|
||||
return {"type": "none"}
|
||||
if tool_choice == "auto":
|
||||
return {"type": "auto"}
|
||||
if tool_choice == "required":
|
||||
return {"type": "any"}
|
||||
if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
|
||||
function = tool_choice.get("function", {})
|
||||
return {"type": "tool_use", "name": function.get("name", "")}
|
||||
|
||||
return {"type": "auto"}
|
||||
|
||||
def _convert_image_url(self, image_url: str) -> Dict[str, Any]:
|
||||
"""转换图片 URL"""
|
||||
if image_url.startswith("data:"):
|
||||
header, _, data = image_url.partition(",")
|
||||
media_type = "image/jpeg"
|
||||
if ";" in header:
|
||||
media_type = header.split(";")[0].split(":")[-1]
|
||||
|
||||
return {
|
||||
"type": self.CONTENT_TYPE_IMAGE,
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
|
||||
return {"type": self.CONTENT_TYPE_TEXT, "text": f"[Image: {image_url}]"}
|
||||
|
||||
def _convert_stop(self, stop: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
|
||||
"""转换停止序列"""
|
||||
if stop is None:
|
||||
return None
|
||||
if isinstance(stop, str):
|
||||
return [stop]
|
||||
return stop
|
||||
|
||||
# ==================== 响应转换 ====================
|
||||
|
||||
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 OpenAI 响应转换为 Claude 格式
|
||||
|
||||
Args:
|
||||
response: OpenAI 响应字典
|
||||
|
||||
Returns:
|
||||
Claude 格式的响应字典
|
||||
"""
|
||||
choices = response.get("choices", [])
|
||||
if not choices:
|
||||
return self._empty_claude_response(response)
|
||||
|
||||
choice = choices[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
# 构建 content 数组
|
||||
content: List[Dict[str, Any]] = []
|
||||
|
||||
# 处理文本
|
||||
text_content = message.get("content")
|
||||
if text_content:
|
||||
content.append(
|
||||
{
|
||||
"type": self.CONTENT_TYPE_TEXT,
|
||||
"text": text_content,
|
||||
}
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if tool_call.get("type") == "function":
|
||||
function = tool_call.get("function", {})
|
||||
arguments = function.get("arguments", "{}")
|
||||
try:
|
||||
input_data = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
input_data = {"raw": arguments}
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": self.CONTENT_TYPE_TOOL_USE,
|
||||
"id": tool_call.get("id", ""),
|
||||
"name": function.get("name", ""),
|
||||
"input": input_data,
|
||||
}
|
||||
)
|
||||
|
||||
# 转换 finish_reason
|
||||
finish_reason = choice.get("finish_reason")
|
||||
stop_reason = self.FINISH_REASON_MAP.get(finish_reason, "end_turn")
|
||||
|
||||
# 转换 usage
|
||||
usage = response.get("usage", {})
|
||||
claude_usage = {
|
||||
"input_tokens": usage.get("prompt_tokens", 0),
|
||||
"output_tokens": usage.get("completion_tokens", 0),
|
||||
}
|
||||
|
||||
return {
|
||||
"id": f"msg_{response.get('id', uuid.uuid4().hex[:8])}",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": response.get("model", ""),
|
||||
"content": content,
|
||||
"stop_reason": stop_reason,
|
||||
"stop_sequence": None,
|
||||
"usage": claude_usage,
|
||||
}
|
||||
|
||||
def _empty_claude_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""构建空的 Claude 响应"""
|
||||
return {
|
||||
"id": f"msg_{response.get('id', uuid.uuid4().hex[:8])}",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": response.get("model", ""),
|
||||
"content": [],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0},
|
||||
}
|
||||
|
||||
# ==================== 流式转换 ====================
|
||||
|
||||
def convert_stream_chunk(
|
||||
self,
|
||||
chunk: Dict[str, Any],
|
||||
model: str = "",
|
||||
message_id: Optional[str] = None,
|
||||
message_started: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
将 OpenAI SSE chunk 转换为 Claude SSE 事件
|
||||
|
||||
Args:
|
||||
chunk: OpenAI SSE chunk
|
||||
model: 模型名称
|
||||
message_id: 消息 ID
|
||||
message_started: 是否已发送 message_start
|
||||
|
||||
Returns:
|
||||
Claude SSE 事件列表
|
||||
"""
|
||||
events: List[Dict[str, Any]] = []
|
||||
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return events
|
||||
|
||||
choice = choices[0]
|
||||
delta = choice.get("delta", {})
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
# 处理角色(第一个 chunk)
|
||||
role = delta.get("role")
|
||||
if role and not message_started:
|
||||
msg_id = message_id or f"msg_{uuid.uuid4().hex[:8]}"
|
||||
events.append(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": msg_id,
|
||||
"type": "message",
|
||||
"role": role,
|
||||
"model": model,
|
||||
"content": [],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 处理文本内容
|
||||
content_delta = delta.get("content")
|
||||
if isinstance(content_delta, str):
|
||||
events.append(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": content_delta},
|
||||
}
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
tool_calls = delta.get("tool_calls", [])
|
||||
for tool_call in tool_calls:
|
||||
index = tool_call.get("index", 0)
|
||||
|
||||
# 工具调用开始
|
||||
if "id" in tool_call:
|
||||
function = tool_call.get("function", {})
|
||||
events.append(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": index,
|
||||
"content_block": {
|
||||
"type": self.CONTENT_TYPE_TOOL_USE,
|
||||
"id": tool_call["id"],
|
||||
"name": function.get("name", ""),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 工具调用参数增量
|
||||
function = tool_call.get("function", {})
|
||||
if "arguments" in function:
|
||||
events.append(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": {
|
||||
"type": "input_json_delta",
|
||||
"partial_json": function.get("arguments", ""),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 处理结束
|
||||
if finish_reason:
|
||||
stop_reason = self.FINISH_REASON_MAP.get(finish_reason, "end_turn")
|
||||
events.append(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": stop_reason},
|
||||
}
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
# ==================== 工具方法 ====================
|
||||
|
||||
def _collapse_content(
|
||||
self, content: Optional[Union[str, List[Dict[str, Any]]]]
|
||||
) -> Optional[str]:
|
||||
"""折叠内容为字符串"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if not content:
|
||||
return None
|
||||
|
||||
text_parts = [part.get("text", "") for part in content if part.get("type") == "text"]
|
||||
return "\n\n".join(filter(None, text_parts)) or None
|
||||
|
||||
|
||||
__all__ = ["OpenAIToClaudeConverter"]
|
||||
150
src/api/handlers/claude/handler.py
Normal file
150
src/api/handlers/claude/handler.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
||||
|
||||
继承 ChatHandlerBase,只需覆盖格式特定的方法。
|
||||
代码量从原来的 ~1470 行减少到 ~120 行。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
|
||||
|
||||
class ClaudeChatHandler(ChatHandlerBase):
|
||||
"""
|
||||
Claude Chat Handler - 处理 Claude Chat/CLI API 格式的请求
|
||||
|
||||
格式特点:
|
||||
- 使用 input_tokens/output_tokens
|
||||
- 支持 cache_creation_input_tokens/cache_read_input_tokens
|
||||
- 请求格式:ClaudeMessagesRequest
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - Claude 格式实现
|
||||
|
||||
Claude API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(Claude 不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
model = request_body.get("model")
|
||||
return str(model) if model else "unknown"
|
||||
|
||||
def apply_mapped_model(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将映射后的模型名应用到请求体
|
||||
|
||||
Claude API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 原始请求体
|
||||
mapped_model: 映射后的模型名
|
||||
|
||||
Returns:
|
||||
更新了 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
async def _convert_request(self, request):
|
||||
"""
|
||||
将请求转换为 Claude 格式
|
||||
|
||||
Args:
|
||||
request: 原始请求对象
|
||||
|
||||
Returns:
|
||||
ClaudeMessagesRequest 对象
|
||||
"""
|
||||
from src.api.handlers.claude.converter import OpenAIToClaudeConverter
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.openai import OpenAIRequest
|
||||
|
||||
# 如果已经是 Claude 格式,直接返回
|
||||
if isinstance(request, ClaudeMessagesRequest):
|
||||
return request
|
||||
|
||||
# 如果是 OpenAI 格式,转换为 Claude 格式
|
||||
if isinstance(request, OpenAIRequest):
|
||||
converter = OpenAIToClaudeConverter()
|
||||
claude_dict = converter.convert_request(request.dict())
|
||||
return ClaudeMessagesRequest(**claude_dict)
|
||||
|
||||
# 如果是字典,根据内容判断格式
|
||||
if isinstance(request, dict):
|
||||
if "messages" in request and len(request["messages"]) > 0:
|
||||
first_msg = request["messages"][0]
|
||||
if "role" in first_msg and "content" in first_msg:
|
||||
# 可能是 OpenAI 格式
|
||||
converter = OpenAIToClaudeConverter()
|
||||
claude_dict = converter.convert_request(request)
|
||||
return ClaudeMessagesRequest(**claude_dict)
|
||||
|
||||
# 否则假设已经是 Claude 格式
|
||||
return ClaudeMessagesRequest(**request)
|
||||
|
||||
return request
|
||||
|
||||
def _extract_usage(self, response: Dict) -> Dict[str, int]:
|
||||
"""
|
||||
从 Claude 响应中提取 token 使用情况
|
||||
|
||||
Claude 格式使用:
|
||||
- input_tokens / output_tokens
|
||||
- cache_creation_input_tokens / cache_read_input_tokens
|
||||
"""
|
||||
usage = response.get("usage", {})
|
||||
|
||||
input_tokens = usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("output_tokens", 0)
|
||||
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 处理新的 cache_creation 格式
|
||||
if "cache_creation" in usage:
|
||||
cache_creation_data = usage.get("cache_creation", {})
|
||||
if not cache_creation_input_tokens:
|
||||
cache_creation_input_tokens = cache_creation_data.get(
|
||||
"ephemeral_5m_input_tokens", 0
|
||||
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
|
||||
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cache_creation_input_tokens": cache_creation_input_tokens,
|
||||
"cache_read_input_tokens": cache_read_input_tokens,
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
"""
|
||||
规范化 Claude 响应
|
||||
|
||||
Args:
|
||||
response: 原始响应
|
||||
|
||||
Returns:
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
return self.response_normalizer.normalize_claude_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
)
|
||||
return response
|
||||
241
src/api/handlers/claude/stream_parser.py
Normal file
241
src/api/handlers/claude/stream_parser.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Claude SSE 流解析器
|
||||
|
||||
解析 Claude Messages API 的 Server-Sent Events 流。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class ClaudeStreamParser:
|
||||
"""
|
||||
Claude SSE 流解析器
|
||||
|
||||
解析 Claude Messages API 的 SSE 事件流。
|
||||
|
||||
事件类型:
|
||||
- message_start: 消息开始,包含初始 message 对象
|
||||
- content_block_start: 内容块开始
|
||||
- content_block_delta: 内容块增量(文本、工具输入等)
|
||||
- content_block_stop: 内容块结束
|
||||
- message_delta: 消息增量,包含 stop_reason 和最终 usage
|
||||
- message_stop: 消息结束
|
||||
- ping: 心跳事件
|
||||
- error: 错误事件
|
||||
"""
|
||||
|
||||
# Claude SSE 事件类型
|
||||
EVENT_MESSAGE_START = "message_start"
|
||||
EVENT_MESSAGE_STOP = "message_stop"
|
||||
EVENT_MESSAGE_DELTA = "message_delta"
|
||||
EVENT_CONTENT_BLOCK_START = "content_block_start"
|
||||
EVENT_CONTENT_BLOCK_STOP = "content_block_stop"
|
||||
EVENT_CONTENT_BLOCK_DELTA = "content_block_delta"
|
||||
EVENT_PING = "ping"
|
||||
EVENT_ERROR = "error"
|
||||
|
||||
# Delta 类型
|
||||
DELTA_TEXT = "text_delta"
|
||||
DELTA_INPUT_JSON = "input_json_delta"
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析 SSE 数据块
|
||||
|
||||
Args:
|
||||
chunk: 原始 SSE 数据(bytes 或 str)
|
||||
|
||||
Returns:
|
||||
解析后的事件列表
|
||||
"""
|
||||
if isinstance(chunk, bytes):
|
||||
text = chunk.decode("utf-8")
|
||||
else:
|
||||
text = chunk
|
||||
|
||||
events: List[Dict[str, Any]] = []
|
||||
lines = text.strip().split("\n")
|
||||
|
||||
current_event_type: Optional[str] = None
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 解析事件类型行
|
||||
if line.startswith("event: "):
|
||||
current_event_type = line[7:]
|
||||
continue
|
||||
|
||||
# 解析数据行
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
|
||||
# 处理 [DONE] 标记
|
||||
if data_str == "[DONE]":
|
||||
events.append({"type": "__done__", "raw": "[DONE]"})
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
# 如果数据中没有 type,使用事件行的类型
|
||||
if "type" not in data and current_event_type:
|
||||
data["type"] = current_event_type
|
||||
events.append(data)
|
||||
except json.JSONDecodeError:
|
||||
# 无法解析的数据,跳过
|
||||
pass
|
||||
|
||||
current_event_type = None
|
||||
|
||||
return events
|
||||
|
||||
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析单行 SSE 数据
|
||||
|
||||
Args:
|
||||
line: SSE 数据行(已去除 "data: " 前缀)
|
||||
|
||||
Returns:
|
||||
解析后的事件字典,如果无法解析返回 None
|
||||
"""
|
||||
if not line or line == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def is_done_event(self, event: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为结束事件
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
True 如果是结束事件
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
return event_type in (self.EVENT_MESSAGE_STOP, "__done__")
|
||||
|
||||
def is_error_event(self, event: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为错误事件
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
True 如果是错误事件
|
||||
"""
|
||||
return event.get("type") == self.EVENT_ERROR
|
||||
|
||||
def get_event_type(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
获取事件类型
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
事件类型字符串
|
||||
"""
|
||||
return event.get("type")
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 content_block_delta 事件中提取文本增量
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
文本增量,如果不是文本 delta 返回 None
|
||||
"""
|
||||
if event.get("type") != self.EVENT_CONTENT_BLOCK_DELTA:
|
||||
return None
|
||||
|
||||
delta = event.get("delta", {})
|
||||
if delta.get("type") == self.DELTA_TEXT:
|
||||
return delta.get("text")
|
||||
|
||||
return None
|
||||
|
||||
def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
|
||||
"""
|
||||
从事件中提取 token 使用量
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
使用量字典,如果没有使用量信息返回 None
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
|
||||
# message_start 事件包含初始 usage
|
||||
if event_type == self.EVENT_MESSAGE_START:
|
||||
message = event.get("message", {})
|
||||
usage = message.get("usage", {})
|
||||
if usage:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
# message_delta 事件包含最终 usage
|
||||
if event_type == self.EVENT_MESSAGE_DELTA:
|
||||
usage = event.get("usage", {})
|
||||
if usage:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def extract_message_id(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 message_start 事件中提取消息 ID
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
消息 ID,如果不是 message_start 返回 None
|
||||
"""
|
||||
if event.get("type") != self.EVENT_MESSAGE_START:
|
||||
return None
|
||||
|
||||
message = event.get("message", {})
|
||||
return message.get("id")
|
||||
|
||||
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 message_delta 事件中提取停止原因
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
停止原因,如果没有返回 None
|
||||
"""
|
||||
if event.get("type") != self.EVENT_MESSAGE_DELTA:
|
||||
return None
|
||||
|
||||
delta = event.get("delta", {})
|
||||
return delta.get("stop_reason")
|
||||
|
||||
|
||||
__all__ = ["ClaudeStreamParser"]
|
||||
11
src/api/handlers/claude_cli/__init__.py
Normal file
11
src/api/handlers/claude_cli/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Claude CLI 透传处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
|
||||
from src.api.handlers.claude_cli.handler import ClaudeCliMessageHandler
|
||||
|
||||
__all__ = [
|
||||
"ClaudeCliAdapter",
|
||||
"ClaudeCliMessageHandler",
|
||||
]
|
||||
103
src/api/handlers/claude_cli/adapter.py
Normal file
103
src/api/handlers/claude_cli/adapter.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
||||
|
||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
from src.api.handlers.claude.adapter import ClaudeCapabilityDetector
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
class ClaudeCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
Claude CLI API 适配器
|
||||
|
||||
处理 Claude CLI 格式的请求(/v1/messages 端点,使用 Bearer 认证)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
name = "claude.cli"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.claude_cli.handler import ClaudeCliMessageHandler
|
||||
|
||||
return ClaudeCliMessageHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["CLAUDE_CLI"])
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
def detect_capability_requirements(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""检测 Claude CLI 请求中隐含的能力需求"""
|
||||
return ClaudeCapabilityDetector.detect_from_headers(headers)
|
||||
|
||||
# =========================================================================
|
||||
# Claude CLI 特定的计费逻辑
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算 Claude CLI 的总输入上下文(用于阶梯计费判定)
|
||||
|
||||
Claude 的总输入 = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
"""
|
||||
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
|
||||
"""Claude CLI 使用 messages 字段"""
|
||||
messages = payload.get("messages", [])
|
||||
return len(messages) if isinstance(messages, list) else 0
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> Dict[str, Any]:
|
||||
"""Claude CLI 特定的审计元数据"""
|
||||
model = payload.get("model", "unknown")
|
||||
stream = payload.get("stream", False)
|
||||
messages = payload.get("messages", [])
|
||||
|
||||
role_counts = {}
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
role_counts[role] = role_counts.get(role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "claude_cli_request",
|
||||
"model": model,
|
||||
"stream": bool(stream),
|
||||
"max_tokens": payload.get("max_tokens"),
|
||||
"messages_count": len(messages),
|
||||
"message_roles": role_counts,
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"tool_count": len(payload.get("tools") or []),
|
||||
"system_present": bool(payload.get("system")),
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["ClaudeCliAdapter"]
|
||||
195
src/api/handlers/claude_cli/handler.py
Normal file
195
src/api/handlers/claude_cli/handler.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Claude CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
|
||||
|
||||
继承 CliMessageHandlerBase,只需覆盖格式特定的配置和事件处理逻辑。
|
||||
验证新架构的有效性:代码量从数百行减少到 ~80 行。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
|
||||
|
||||
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
"""
|
||||
Claude CLI Message Handler - 处理 Claude CLI API 格式
|
||||
|
||||
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
|
||||
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
|
||||
|
||||
响应格式特点:
|
||||
- 使用 content[] 数组
|
||||
- 使用 text 类型
|
||||
- 流式事件:message_start, content_block_delta, message_delta, message_stop
|
||||
- 支持 cache_creation_input_tokens 和 cache_read_input_tokens
|
||||
|
||||
模型字段:请求体顶级 model 字段
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - Claude 格式实现
|
||||
|
||||
Claude API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(Claude 不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
model = request_body.get("model")
|
||||
return str(model) if model else "unknown"
|
||||
|
||||
def apply_mapped_model(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Claude API 的 model 在请求体顶级
|
||||
|
||||
Args:
|
||||
request_body: 原始请求体
|
||||
mapped_model: 映射后的模型名
|
||||
|
||||
Returns:
|
||||
更新了 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
def _process_event_data(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
event_type: str,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
处理 Claude CLI 格式的 SSE 事件
|
||||
|
||||
事件类型:
|
||||
- message_start: 消息开始,包含初始 usage(含缓存 tokens)
|
||||
- content_block_delta: 文本增量
|
||||
- message_delta: 消息增量,包含最终 usage
|
||||
- message_stop: 消息结束
|
||||
"""
|
||||
# 处理 message_start 事件
|
||||
if event_type == "message_start":
|
||||
message = data.get("message", {})
|
||||
if message.get("id"):
|
||||
ctx.response_id = message["id"]
|
||||
|
||||
# 提取初始 usage(包含缓存 tokens)
|
||||
usage = message.get("usage", {})
|
||||
if usage:
|
||||
ctx.input_tokens = usage.get("input_tokens", 0)
|
||||
# Claude 的缓存 tokens 使用不同的字段名
|
||||
cache_read = usage.get("cache_read_input_tokens", 0)
|
||||
if cache_read:
|
||||
ctx.cached_tokens = cache_read
|
||||
cache_creation = usage.get("cache_creation_input_tokens", 0)
|
||||
if cache_creation:
|
||||
ctx.cache_creation_tokens = cache_creation
|
||||
|
||||
# 处理文本增量
|
||||
elif event_type == "content_block_delta":
|
||||
delta = data.get("delta", {})
|
||||
if delta.get("type") == "text_delta":
|
||||
text = delta.get("text", "")
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
|
||||
# 处理消息增量(包含最终 usage)
|
||||
elif event_type == "message_delta":
|
||||
usage = data.get("usage", {})
|
||||
if usage:
|
||||
if "input_tokens" in usage:
|
||||
ctx.input_tokens = usage["input_tokens"]
|
||||
if "output_tokens" in usage:
|
||||
ctx.output_tokens = usage["output_tokens"]
|
||||
# 更新缓存 tokens
|
||||
if "cache_read_input_tokens" in usage:
|
||||
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
||||
if "cache_creation_input_tokens" in usage:
|
||||
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
|
||||
|
||||
# 检查是否结束
|
||||
delta = data.get("delta", {})
|
||||
if delta.get("stop_reason"):
|
||||
ctx.has_completion = True
|
||||
ctx.final_response = data
|
||||
|
||||
# 处理消息结束
|
||||
elif event_type == "message_stop":
|
||||
ctx.has_completion = True
|
||||
|
||||
def _extract_response_metadata(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从 Claude 响应中提取元数据
|
||||
|
||||
提取 model、stop_reason 等字段作为元数据。
|
||||
|
||||
Args:
|
||||
response: Claude API 响应
|
||||
|
||||
Returns:
|
||||
提取的元数据字典
|
||||
"""
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
# 提取模型名称(实际使用的模型)
|
||||
if "model" in response:
|
||||
metadata["model"] = response["model"]
|
||||
|
||||
# 提取停止原因
|
||||
if "stop_reason" in response:
|
||||
metadata["stop_reason"] = response["stop_reason"]
|
||||
|
||||
# 提取消息 ID
|
||||
if "id" in response:
|
||||
metadata["message_id"] = response["id"]
|
||||
|
||||
# 提取消息类型
|
||||
if "type" in response:
|
||||
metadata["type"] = response["type"]
|
||||
|
||||
return metadata
|
||||
|
||||
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
|
||||
"""
|
||||
从流上下文中提取最终元数据
|
||||
|
||||
在流传输完成后调用,从收集的事件中提取元数据。
|
||||
|
||||
Args:
|
||||
ctx: 流上下文
|
||||
"""
|
||||
# 从 response_id 提取消息 ID
|
||||
if ctx.response_id:
|
||||
ctx.response_metadata["message_id"] = ctx.response_id
|
||||
|
||||
# 从 final_response 提取停止原因(message_delta 事件中的 delta.stop_reason)
|
||||
if ctx.final_response:
|
||||
delta = ctx.final_response.get("delta", {})
|
||||
if "stop_reason" in delta:
|
||||
ctx.response_metadata["stop_reason"] = delta["stop_reason"]
|
||||
|
||||
# 记录模型名称
|
||||
if ctx.model:
|
||||
ctx.response_metadata["model"] = ctx.model
|
||||
|
||||
26
src/api/handlers/gemini/__init__.py
Normal file
26
src/api/handlers/gemini/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Gemini API Handler 模块
|
||||
|
||||
提供 Gemini API 格式的请求处理
|
||||
"""
|
||||
|
||||
from src.api.handlers.gemini.adapter import GeminiChatAdapter, build_gemini_adapter
|
||||
from src.api.handlers.gemini.converter import (
|
||||
ClaudeToGeminiConverter,
|
||||
GeminiToClaudeConverter,
|
||||
GeminiToOpenAIConverter,
|
||||
OpenAIToGeminiConverter,
|
||||
)
|
||||
from src.api.handlers.gemini.handler import GeminiChatHandler
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
__all__ = [
|
||||
"GeminiChatAdapter",
|
||||
"GeminiChatHandler",
|
||||
"GeminiStreamParser",
|
||||
"ClaudeToGeminiConverter",
|
||||
"GeminiToClaudeConverter",
|
||||
"OpenAIToGeminiConverter",
|
||||
"GeminiToOpenAIConverter",
|
||||
"build_gemini_adapter",
|
||||
]
|
||||
170
src/api/handlers/gemini/adapter.py
Normal file
170
src/api/handlers/gemini/adapter.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Gemini Chat Adapter
|
||||
|
||||
处理 Gemini API 格式的请求适配
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.logger import logger
|
||||
from src.models.gemini import GeminiRequest
|
||||
|
||||
|
||||
@register_adapter
|
||||
class GeminiChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
Gemini Chat API 适配器
|
||||
|
||||
处理 Gemini Chat 格式的请求
|
||||
端点: /v1beta/models/{model}:generateContent
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI"
|
||||
name = "gemini.chat"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.gemini.handler import GeminiChatHandler
|
||||
|
||||
return GeminiChatHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["GEMINI"])
|
||||
logger.info(f"[{self.name}] 初始化 Gemini Chat 适配器 | API格式: {self.allowed_api_formats}")
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-goog-api-key)"""
|
||||
return request.headers.get("x-goog-api-key")
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any] # noqa: ARG002
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - Gemini 特化版本
|
||||
|
||||
Gemini API 特点:
|
||||
- model 不合并到请求体(通过 extract_model_from_request 从 path_params 获取)
|
||||
- stream 不合并到请求体(Gemini API 通过 URL 端点区分流式/非流式)
|
||||
|
||||
Handler 层的 extract_model_from_request 会从 path_params 获取 model,
|
||||
prepare_provider_request_body 会确保发送给 Gemini API 的请求体不含 model。
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典(不使用)
|
||||
|
||||
Returns:
|
||||
原始请求体(不合并任何 path_params)
|
||||
"""
|
||||
return original_request_body.copy()
|
||||
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""验证请求体"""
|
||||
path_params = path_params or {}
|
||||
is_stream = path_params.get("stream", False)
|
||||
model = path_params.get("model", "unknown")
|
||||
|
||||
try:
|
||||
if not isinstance(original_request_body, dict):
|
||||
raise ValueError("Request body must be a JSON object")
|
||||
|
||||
# Gemini 必需字段: contents
|
||||
if "contents" not in original_request_body:
|
||||
raise ValueError("Missing required field: contents")
|
||||
|
||||
request = GeminiRequest.model_validate(
|
||||
original_request_body,
|
||||
strict=False,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"请求体基本验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
|
||||
request = GeminiRequest.model_construct(
|
||||
contents=original_request_body.get("contents", []),
|
||||
)
|
||||
|
||||
# 设置 model(从 path_params 获取,用于日志和审计)
|
||||
request.model = model
|
||||
# 设置 stream 属性(用于 ChatAdapterBase 判断流式模式)
|
||||
request.stream = is_stream
|
||||
return request
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
|
||||
"""提取消息数量"""
|
||||
contents = payload.get("contents", [])
|
||||
if hasattr(request_obj, "contents"):
|
||||
contents = request_obj.contents
|
||||
return len(contents) if isinstance(contents, list) else 0
|
||||
|
||||
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""构建 Gemini Chat 特定的审计元数据"""
|
||||
role_counts: dict[str, int] = {}
|
||||
|
||||
contents = getattr(request_obj, "contents", []) or []
|
||||
for content in contents:
|
||||
role = getattr(content, "role", None) or content.get("role", "unknown")
|
||||
role_counts[role] = role_counts.get(role, 0) + 1
|
||||
|
||||
generation_config = getattr(request_obj, "generation_config", None) or {}
|
||||
if hasattr(generation_config, "dict"):
|
||||
generation_config = generation_config.dict()
|
||||
elif not isinstance(generation_config, dict):
|
||||
generation_config = {}
|
||||
|
||||
# 判断流式模式
|
||||
stream = getattr(request_obj, "stream", False)
|
||||
|
||||
return {
|
||||
"action": "gemini_generate_content",
|
||||
"model": getattr(request_obj, "model", payload.get("model", "unknown")),
|
||||
"stream": bool(stream),
|
||||
"max_output_tokens": generation_config.get("max_output_tokens"),
|
||||
"temperature": generation_config.get("temperature"),
|
||||
"top_p": generation_config.get("top_p"),
|
||||
"top_k": generation_config.get("top_k"),
|
||||
"contents_count": len(contents),
|
||||
"content_roles": role_counts,
|
||||
"tools_count": len(getattr(request_obj, "tools", None) or []),
|
||||
"system_instruction_present": bool(getattr(request_obj, "system_instruction", None)),
|
||||
"safety_settings_count": len(getattr(request_obj, "safety_settings", None) or []),
|
||||
}
|
||||
|
||||
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
||||
"""生成 Gemini 格式的错误响应"""
|
||||
# Gemini 错误响应格式
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"code": status_code,
|
||||
"message": message,
|
||||
"status": error_type.upper(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
||||
"""
|
||||
根据请求头构建适当的 Gemini 适配器
|
||||
|
||||
Args:
|
||||
x_app_header: X-App 请求头值
|
||||
|
||||
Returns:
|
||||
GeminiChatAdapter 实例
|
||||
"""
|
||||
# 目前只有一种 Gemini 适配器
|
||||
# 未来可以根据 x_app_header 返回不同的适配器(如 CLI 模式)
|
||||
return GeminiChatAdapter()
|
||||
|
||||
|
||||
__all__ = ["GeminiChatAdapter", "build_gemini_adapter"]
|
||||
544
src/api/handlers/gemini/converter.py
Normal file
544
src/api/handlers/gemini/converter.py
Normal file
@@ -0,0 +1,544 @@
|
||||
"""
|
||||
Gemini 格式转换器
|
||||
|
||||
提供 Gemini 与其他 API 格式(Claude、OpenAI)之间的转换
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class ClaudeToGeminiConverter:
|
||||
"""
|
||||
Claude -> Gemini 请求转换器
|
||||
|
||||
将 Claude Messages API 格式转换为 Gemini generateContent 格式
|
||||
"""
|
||||
|
||||
def convert_request(self, claude_request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Claude 请求转换为 Gemini 请求
|
||||
|
||||
Args:
|
||||
claude_request: Claude 格式的请求字典
|
||||
|
||||
Returns:
|
||||
Gemini 格式的请求字典
|
||||
"""
|
||||
gemini_request: Dict[str, Any] = {
|
||||
"contents": self._convert_messages(claude_request.get("messages", [])),
|
||||
}
|
||||
|
||||
# 转换 system prompt
|
||||
system = claude_request.get("system")
|
||||
if system:
|
||||
gemini_request["system_instruction"] = self._convert_system(system)
|
||||
|
||||
# 转换生成配置
|
||||
generation_config = self._build_generation_config(claude_request)
|
||||
if generation_config:
|
||||
gemini_request["generation_config"] = generation_config
|
||||
|
||||
# 转换工具
|
||||
tools = claude_request.get("tools")
|
||||
if tools:
|
||||
gemini_request["tools"] = self._convert_tools(tools)
|
||||
|
||||
return gemini_request
|
||||
|
||||
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""转换消息列表"""
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
# Gemini 使用 "model" 而不是 "assistant"
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
|
||||
content = msg.get("content", "")
|
||||
parts = self._convert_content_to_parts(content)
|
||||
|
||||
contents.append(
|
||||
{
|
||||
"role": gemini_role,
|
||||
"parts": parts,
|
||||
}
|
||||
)
|
||||
return contents
|
||||
|
||||
def _convert_content_to_parts(self, content: Any) -> List[Dict[str, Any]]:
|
||||
"""将 Claude 内容转换为 Gemini parts"""
|
||||
if isinstance(content, str):
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append({"text": block})
|
||||
elif isinstance(block, dict):
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
parts.append({"text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
# 转换图片
|
||||
source = block.get("source", {})
|
||||
if source.get("type") == "base64":
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": source.get("media_type", "image/png"),
|
||||
"data": source.get("data", ""),
|
||||
}
|
||||
}
|
||||
)
|
||||
elif block_type == "tool_use":
|
||||
# 转换工具调用
|
||||
parts.append(
|
||||
{
|
||||
"function_call": {
|
||||
"name": block.get("name", ""),
|
||||
"args": block.get("input", {}),
|
||||
}
|
||||
}
|
||||
)
|
||||
elif block_type == "tool_result":
|
||||
# 转换工具结果
|
||||
parts.append(
|
||||
{
|
||||
"function_response": {
|
||||
"name": block.get("tool_use_id", ""),
|
||||
"response": {"result": block.get("content", "")},
|
||||
}
|
||||
}
|
||||
)
|
||||
return parts
|
||||
|
||||
return [{"text": str(content)}]
|
||||
|
||||
def _convert_system(self, system: Any) -> Dict[str, Any]:
|
||||
"""转换 system prompt"""
|
||||
if isinstance(system, str):
|
||||
return {"parts": [{"text": system}]}
|
||||
|
||||
if isinstance(system, list):
|
||||
parts = []
|
||||
for item in system:
|
||||
if isinstance(item, str):
|
||||
parts.append({"text": item})
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append({"text": item.get("text", "")})
|
||||
return {"parts": parts}
|
||||
|
||||
return {"parts": [{"text": str(system)}]}
|
||||
|
||||
def _build_generation_config(self, claude_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""构建生成配置"""
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
if "max_tokens" in claude_request:
|
||||
config["max_output_tokens"] = claude_request["max_tokens"]
|
||||
if "temperature" in claude_request:
|
||||
config["temperature"] = claude_request["temperature"]
|
||||
if "top_p" in claude_request:
|
||||
config["top_p"] = claude_request["top_p"]
|
||||
if "top_k" in claude_request:
|
||||
config["top_k"] = claude_request["top_k"]
|
||||
if "stop_sequences" in claude_request:
|
||||
config["stop_sequences"] = claude_request["stop_sequences"]
|
||||
|
||||
return config if config else None
|
||||
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""转换工具定义"""
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
func_decl = {
|
||||
"name": tool.get("name", ""),
|
||||
}
|
||||
if "description" in tool:
|
||||
func_decl["description"] = tool["description"]
|
||||
if "input_schema" in tool:
|
||||
func_decl["parameters"] = tool["input_schema"]
|
||||
function_declarations.append(func_decl)
|
||||
|
||||
return [{"function_declarations": function_declarations}]
|
||||
|
||||
|
||||
class GeminiToClaudeConverter:
|
||||
"""
|
||||
Gemini -> Claude 响应转换器
|
||||
|
||||
将 Gemini generateContent 响应转换为 Claude Messages API 格式
|
||||
"""
|
||||
|
||||
def convert_response(self, gemini_response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Gemini 响应转换为 Claude 响应
|
||||
|
||||
Args:
|
||||
gemini_response: Gemini 格式的响应字典
|
||||
|
||||
Returns:
|
||||
Claude 格式的响应字典
|
||||
"""
|
||||
candidates = gemini_response.get("candidates", [])
|
||||
if not candidates:
|
||||
return self._create_empty_response()
|
||||
|
||||
candidate = candidates[0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
# 转换内容块
|
||||
claude_content = self._convert_parts_to_content(parts)
|
||||
|
||||
# 转换使用量
|
||||
usage = self._convert_usage(gemini_response.get("usageMetadata", {}))
|
||||
|
||||
# 转换停止原因
|
||||
stop_reason = self._convert_finish_reason(candidate.get("finishReason"))
|
||||
|
||||
return {
|
||||
"id": f"msg_{gemini_response.get('modelVersion', 'gemini')}",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": claude_content,
|
||||
"model": gemini_response.get("modelVersion", "gemini"),
|
||||
"stop_reason": stop_reason,
|
||||
"stop_sequence": None,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
def _convert_parts_to_content(self, parts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将 Gemini parts 转换为 Claude content blocks"""
|
||||
content = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": part["text"],
|
||||
}
|
||||
)
|
||||
elif "functionCall" in part:
|
||||
func_call = part["functionCall"]
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": f"toolu_{func_call.get('name', '')}",
|
||||
"name": func_call.get("name", ""),
|
||||
"input": func_call.get("args", {}),
|
||||
}
|
||||
)
|
||||
return content
|
||||
|
||||
def _convert_usage(self, usage_metadata: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""转换使用量信息"""
|
||||
return {
|
||||
"input_tokens": usage_metadata.get("promptTokenCount", 0),
|
||||
"output_tokens": usage_metadata.get("candidatesTokenCount", 0),
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": usage_metadata.get("cachedContentTokenCount", 0),
|
||||
}
|
||||
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
|
||||
"""转换停止原因"""
|
||||
mapping = {
|
||||
"STOP": "end_turn",
|
||||
"MAX_TOKENS": "max_tokens",
|
||||
"SAFETY": "content_filtered",
|
||||
"RECITATION": "content_filtered",
|
||||
"OTHER": "stop_sequence",
|
||||
}
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
def _create_empty_response(self) -> Dict[str, Any]:
|
||||
"""创建空响应"""
|
||||
return {
|
||||
"id": "msg_empty",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": "gemini",
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
"usage": {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OpenAIToGeminiConverter:
|
||||
"""
|
||||
OpenAI -> Gemini 请求转换器
|
||||
|
||||
将 OpenAI Chat Completions API 格式转换为 Gemini generateContent 格式
|
||||
"""
|
||||
|
||||
def convert_request(self, openai_request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 OpenAI 请求转换为 Gemini 请求
|
||||
|
||||
Args:
|
||||
openai_request: OpenAI 格式的请求字典
|
||||
|
||||
Returns:
|
||||
Gemini 格式的请求字典
|
||||
"""
|
||||
messages = openai_request.get("messages", [])
|
||||
|
||||
# 分离 system 消息和其他消息
|
||||
system_messages = []
|
||||
other_messages = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
system_messages.append(msg)
|
||||
else:
|
||||
other_messages.append(msg)
|
||||
|
||||
gemini_request: Dict[str, Any] = {
|
||||
"contents": self._convert_messages(other_messages),
|
||||
}
|
||||
|
||||
# 转换 system messages
|
||||
if system_messages:
|
||||
system_text = "\n".join(msg.get("content", "") for msg in system_messages)
|
||||
gemini_request["system_instruction"] = {"parts": [{"text": system_text}]}
|
||||
|
||||
# 转换生成配置
|
||||
generation_config = self._build_generation_config(openai_request)
|
||||
if generation_config:
|
||||
gemini_request["generation_config"] = generation_config
|
||||
|
||||
# 转换工具
|
||||
tools = openai_request.get("tools")
|
||||
if tools:
|
||||
gemini_request["tools"] = self._convert_tools(tools)
|
||||
|
||||
return gemini_request
|
||||
|
||||
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""转换消息列表"""
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
|
||||
content = msg.get("content", "")
|
||||
parts = self._convert_content_to_parts(content)
|
||||
|
||||
# 处理工具调用
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
for tc in tool_calls:
|
||||
if tc.get("type") == "function":
|
||||
func = tc.get("function", {})
|
||||
import json
|
||||
|
||||
try:
|
||||
args = json.loads(func.get("arguments", "{}"))
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
parts.append(
|
||||
{
|
||||
"function_call": {
|
||||
"name": func.get("name", ""),
|
||||
"args": args,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if parts:
|
||||
contents.append(
|
||||
{
|
||||
"role": gemini_role,
|
||||
"parts": parts,
|
||||
}
|
||||
)
|
||||
return contents
|
||||
|
||||
def _convert_content_to_parts(self, content: Any) -> List[Dict[str, Any]]:
|
||||
"""将 OpenAI 内容转换为 Gemini parts"""
|
||||
if content is None:
|
||||
return []
|
||||
|
||||
if isinstance(content, str):
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append({"text": item})
|
||||
elif isinstance(item, dict):
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
parts.append({"text": item.get("text", "")})
|
||||
elif item_type == "image_url":
|
||||
# OpenAI 图片 URL 格式
|
||||
image_url = item.get("image_url", {})
|
||||
url = image_url.get("url", "")
|
||||
if url.startswith("data:"):
|
||||
# base64 数据 URL
|
||||
# 格式: 
|
||||
try:
|
||||
header, data = url.split(",", 1)
|
||||
mime_type = header.split(":")[1].split(";")[0]
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": mime_type,
|
||||
"data": data,
|
||||
}
|
||||
}
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
return parts
|
||||
|
||||
return [{"text": str(content)}]
|
||||
|
||||
def _build_generation_config(self, openai_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""构建生成配置"""
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
if "max_tokens" in openai_request:
|
||||
config["max_output_tokens"] = openai_request["max_tokens"]
|
||||
if "temperature" in openai_request:
|
||||
config["temperature"] = openai_request["temperature"]
|
||||
if "top_p" in openai_request:
|
||||
config["top_p"] = openai_request["top_p"]
|
||||
if "stop" in openai_request:
|
||||
stop = openai_request["stop"]
|
||||
if isinstance(stop, str):
|
||||
config["stop_sequences"] = [stop]
|
||||
elif isinstance(stop, list):
|
||||
config["stop_sequences"] = stop
|
||||
if "n" in openai_request:
|
||||
config["candidate_count"] = openai_request["n"]
|
||||
|
||||
return config if config else None
|
||||
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""转换工具定义"""
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool.get("function", {})
|
||||
func_decl = {
|
||||
"name": func.get("name", ""),
|
||||
}
|
||||
if "description" in func:
|
||||
func_decl["description"] = func["description"]
|
||||
if "parameters" in func:
|
||||
func_decl["parameters"] = func["parameters"]
|
||||
function_declarations.append(func_decl)
|
||||
|
||||
return [{"function_declarations": function_declarations}]
|
||||
|
||||
|
||||
class GeminiToOpenAIConverter:
|
||||
"""
|
||||
Gemini -> OpenAI 响应转换器
|
||||
|
||||
将 Gemini generateContent 响应转换为 OpenAI Chat Completions API 格式
|
||||
"""
|
||||
|
||||
def convert_response(self, gemini_response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Gemini 响应转换为 OpenAI 响应
|
||||
|
||||
Args:
|
||||
gemini_response: Gemini 格式的响应字典
|
||||
|
||||
Returns:
|
||||
OpenAI 格式的响应字典
|
||||
"""
|
||||
import time
|
||||
|
||||
candidates = gemini_response.get("candidates", [])
|
||||
choices = []
|
||||
|
||||
for i, candidate in enumerate(candidates):
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
# 提取文本内容
|
||||
text_parts = []
|
||||
tool_calls = []
|
||||
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
elif "functionCall" in part:
|
||||
func_call = part["functionCall"]
|
||||
import json
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{func_call.get('name', '')}_{i}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_call.get("name", ""),
|
||||
"arguments": json.dumps(func_call.get("args", {})),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
message: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": "".join(text_parts) if text_parts else None,
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
message["tool_calls"] = tool_calls
|
||||
|
||||
finish_reason = self._convert_finish_reason(candidate.get("finishReason"))
|
||||
|
||||
choices.append(
|
||||
{
|
||||
"index": i,
|
||||
"message": message,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
)
|
||||
|
||||
# 转换使用量
|
||||
usage = self._convert_usage(gemini_response.get("usageMetadata", {}))
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{gemini_response.get('modelVersion', 'gemini')}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": gemini_response.get("modelVersion", "gemini"),
|
||||
"choices": choices,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
def _convert_usage(self, usage_metadata: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""转换使用量信息"""
|
||||
prompt_tokens = usage_metadata.get("promptTokenCount", 0)
|
||||
completion_tokens = usage_metadata.get("candidatesTokenCount", 0)
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
|
||||
"""转换停止原因"""
|
||||
mapping = {
|
||||
"STOP": "stop",
|
||||
"MAX_TOKENS": "length",
|
||||
"SAFETY": "content_filter",
|
||||
"RECITATION": "content_filter",
|
||||
"OTHER": "stop",
|
||||
}
|
||||
return mapping.get(finish_reason, "stop")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClaudeToGeminiConverter",
|
||||
"GeminiToClaudeConverter",
|
||||
"OpenAIToGeminiConverter",
|
||||
"GeminiToOpenAIConverter",
|
||||
]
|
||||
164
src/api/handlers/gemini/handler.py
Normal file
164
src/api/handlers/gemini/handler.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Gemini Chat Handler
|
||||
|
||||
处理 Gemini API 格式的请求
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
|
||||
|
||||
class GeminiChatHandler(ChatHandlerBase):
|
||||
"""
|
||||
Gemini Chat Handler - 处理 Google Gemini API 格式的请求
|
||||
|
||||
格式特点:
|
||||
- 使用 promptTokenCount / candidatesTokenCount
|
||||
- 支持 cachedContentTokenCount
|
||||
- 请求格式: GeminiRequest
|
||||
- 响应格式: JSON 数组流(非 SSE)
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - Gemini Chat 格式实现
|
||||
|
||||
Gemini Chat 模式下,model 在请求体中(经过转换后的 GeminiRequest)。
|
||||
与 Gemini CLI 不同,CLI 模式的 model 在 URL 路径中。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(Chat 模式通常不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
# 优先从请求体获取,其次从 path_params
|
||||
model = request_body.get("model")
|
||||
if model:
|
||||
return str(model)
|
||||
if path_params and "model" in path_params:
|
||||
return str(path_params["model"])
|
||||
return "unknown"
|
||||
|
||||
async def _convert_request(self, request):
|
||||
"""
|
||||
将请求转换为 Gemini 格式
|
||||
|
||||
支持自动转换:
|
||||
- Claude 格式 → Gemini 格式
|
||||
- OpenAI 格式 → Gemini 格式
|
||||
|
||||
Args:
|
||||
request: 原始请求对象(可能是 Gemini/Claude/OpenAI 格式)
|
||||
|
||||
Returns:
|
||||
GeminiRequest 对象
|
||||
"""
|
||||
from src.api.handlers.gemini.converter import (
|
||||
ClaudeToGeminiConverter,
|
||||
OpenAIToGeminiConverter,
|
||||
)
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.gemini import GeminiRequest
|
||||
from src.models.openai import OpenAIRequest
|
||||
|
||||
# 如果已经是 Gemini 格式,直接返回
|
||||
if isinstance(request, GeminiRequest):
|
||||
return request
|
||||
|
||||
# 如果是 Claude 格式,转换为 Gemini 格式
|
||||
if isinstance(request, ClaudeMessagesRequest):
|
||||
converter = ClaudeToGeminiConverter()
|
||||
gemini_dict = converter.convert_request(request.model_dump())
|
||||
return GeminiRequest(**gemini_dict)
|
||||
|
||||
# 如果是 OpenAI 格式,转换为 Gemini 格式
|
||||
if isinstance(request, OpenAIRequest):
|
||||
converter = OpenAIToGeminiConverter()
|
||||
gemini_dict = converter.convert_request(request.model_dump())
|
||||
return GeminiRequest(**gemini_dict)
|
||||
|
||||
# 如果是字典,根据内容判断格式并转换
|
||||
if isinstance(request, dict):
|
||||
# 检测 Gemini 格式特征: contents 字段
|
||||
if "contents" in request:
|
||||
return GeminiRequest(**request)
|
||||
|
||||
# 检测 Claude 格式特征: messages + 没有 choices
|
||||
if "messages" in request and "choices" not in request:
|
||||
# 进一步区分 Claude 和 OpenAI
|
||||
# Claude 使用 max_tokens,OpenAI 也可能有
|
||||
# Claude 的 messages[].content 可以是数组,OpenAI 通常是字符串
|
||||
messages = request.get("messages", [])
|
||||
if messages and isinstance(messages[0].get("content"), list):
|
||||
# 可能是 Claude 格式
|
||||
converter = ClaudeToGeminiConverter()
|
||||
gemini_dict = converter.convert_request(request)
|
||||
return GeminiRequest(**gemini_dict)
|
||||
else:
|
||||
# 可能是 OpenAI 格式
|
||||
converter = OpenAIToGeminiConverter()
|
||||
gemini_dict = converter.convert_request(request)
|
||||
return GeminiRequest(**gemini_dict)
|
||||
|
||||
# 默认尝试作为 Gemini 格式
|
||||
return GeminiRequest(**request)
|
||||
|
||||
return request
|
||||
|
||||
def _extract_usage(self, response: Dict) -> Dict[str, int]:
|
||||
"""
|
||||
从 Gemini 响应中提取 token 使用情况
|
||||
|
||||
调用 GeminiStreamParser.extract_usage 作为单一实现源
|
||||
"""
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
usage = GeminiStreamParser().extract_usage(response)
|
||||
|
||||
if not usage:
|
||||
return {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_input_tokens": 0, # Gemini 不区分缓存创建
|
||||
"cache_read_input_tokens": usage.get("cached_tokens", 0),
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
"""
|
||||
规范化 Gemini 响应
|
||||
|
||||
Args:
|
||||
response: 原始响应
|
||||
|
||||
Returns:
|
||||
规范化后的响应
|
||||
|
||||
TODO: 如果需要,实现响应规范化逻辑
|
||||
"""
|
||||
# 可选:使用 response_normalizer 进行规范化
|
||||
# if (
|
||||
# self.response_normalizer
|
||||
# and self.response_normalizer.should_normalize(response)
|
||||
# ):
|
||||
# return self.response_normalizer.normalize_gemini_response(
|
||||
# response_data=response,
|
||||
# request_id=self.request_id,
|
||||
# strict=False,
|
||||
# )
|
||||
return response
|
||||
307
src/api/handlers/gemini/stream_parser.py
Normal file
307
src/api/handlers/gemini/stream_parser.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Gemini SSE/JSON 流解析器
|
||||
|
||||
Gemini API 的流式响应格式与 Claude/OpenAI 不同:
|
||||
- 使用 JSON 数组格式 (不是 SSE)
|
||||
- 每个块是一个完整的 JSON 对象
|
||||
- 响应以 [ 开始,以 ] 结束,块之间用 , 分隔
|
||||
|
||||
参考: https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class GeminiStreamParser:
|
||||
"""
|
||||
Gemini 流解析器
|
||||
|
||||
解析 Gemini streamGenerateContent API 的响应流。
|
||||
|
||||
Gemini 流式响应特点:
|
||||
- 返回 JSON 数组格式: [{chunk1}, {chunk2}, ...]
|
||||
- 每个 chunk 包含 candidates、usageMetadata 等字段
|
||||
- finish_reason 可能值: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
|
||||
"""
|
||||
|
||||
# 停止原因
|
||||
FINISH_REASON_STOP = "STOP"
|
||||
FINISH_REASON_MAX_TOKENS = "MAX_TOKENS"
|
||||
FINISH_REASON_SAFETY = "SAFETY"
|
||||
FINISH_REASON_RECITATION = "RECITATION"
|
||||
FINISH_REASON_OTHER = "OTHER"
|
||||
|
||||
def __init__(self):
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def reset(self):
|
||||
"""重置解析器状态"""
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析流式数据块
|
||||
|
||||
Args:
|
||||
chunk: 原始数据(bytes 或 str)
|
||||
|
||||
Returns:
|
||||
解析后的事件列表
|
||||
"""
|
||||
if isinstance(chunk, bytes):
|
||||
text = chunk.decode("utf-8")
|
||||
else:
|
||||
text = chunk
|
||||
|
||||
events: List[Dict[str, Any]] = []
|
||||
|
||||
for char in text:
|
||||
if char == "[" and not self._in_array:
|
||||
self._in_array = True
|
||||
continue
|
||||
|
||||
if char == "]" and self._in_array and self._brace_depth == 0:
|
||||
# 数组结束
|
||||
self._in_array = False
|
||||
if self._buffer.strip():
|
||||
try:
|
||||
obj = json.loads(self._buffer.strip().rstrip(","))
|
||||
events.append(obj)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
self._buffer = ""
|
||||
continue
|
||||
|
||||
if self._in_array:
|
||||
if char == "{":
|
||||
self._brace_depth += 1
|
||||
elif char == "}":
|
||||
self._brace_depth -= 1
|
||||
|
||||
self._buffer += char
|
||||
|
||||
# 当 brace_depth 回到 0 时,说明一个完整的 JSON 对象结束
|
||||
if self._brace_depth == 0 and self._buffer.strip():
|
||||
try:
|
||||
obj = json.loads(self._buffer.strip().rstrip(","))
|
||||
events.append(obj)
|
||||
self._buffer = ""
|
||||
except json.JSONDecodeError:
|
||||
# 可能还不完整,继续累积
|
||||
pass
|
||||
|
||||
return events
|
||||
|
||||
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析单行 JSON 数据
|
||||
|
||||
Args:
|
||||
line: JSON 数据行
|
||||
|
||||
Returns:
|
||||
解析后的事件字典,如果无法解析返回 None
|
||||
"""
|
||||
if not line or line.strip() in ["[", "]", ","]:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line.strip().rstrip(","))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def is_done_event(self, event: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为结束事件
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
True 如果是结束事件
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if candidates:
|
||||
for candidate in candidates:
|
||||
finish_reason = candidate.get("finishReason")
|
||||
if finish_reason in (
|
||||
self.FINISH_REASON_STOP,
|
||||
self.FINISH_REASON_MAX_TOKENS,
|
||||
self.FINISH_REASON_SAFETY,
|
||||
self.FINISH_REASON_RECITATION,
|
||||
self.FINISH_REASON_OTHER,
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_error_event(self, event: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为错误事件
|
||||
|
||||
检测多种 Gemini 错误格式:
|
||||
1. 顶层 error: {"error": {...}}
|
||||
2. chunks 内嵌套 error: {"chunks": [{"error": {...}}]}
|
||||
3. candidates 内的错误状态
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
True 如果是错误事件
|
||||
"""
|
||||
# 顶层 error
|
||||
if "error" in event:
|
||||
return True
|
||||
|
||||
# chunks 内嵌套 error (某些 Gemini 响应格式)
|
||||
chunks = event.get("chunks", [])
|
||||
if chunks and isinstance(chunks, list):
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk, dict) and "error" in chunk:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def extract_error_info(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
从事件中提取错误信息
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
错误信息字典 {"code": int, "message": str, "status": str},无错误返回 None
|
||||
"""
|
||||
# 顶层 error
|
||||
if "error" in event:
|
||||
error = event["error"]
|
||||
if isinstance(error, dict):
|
||||
return {
|
||||
"code": error.get("code"),
|
||||
"message": error.get("message", str(error)),
|
||||
"status": error.get("status"),
|
||||
}
|
||||
return {"code": None, "message": str(error), "status": None}
|
||||
|
||||
# chunks 内嵌套 error
|
||||
chunks = event.get("chunks", [])
|
||||
if chunks and isinstance(chunks, list):
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk, dict) and "error" in chunk:
|
||||
error = chunk["error"]
|
||||
if isinstance(error, dict):
|
||||
return {
|
||||
"code": error.get("code"),
|
||||
"message": error.get("message", str(error)),
|
||||
"status": error.get("status"),
|
||||
}
|
||||
return {"code": None, "message": str(error), "status": None}
|
||||
|
||||
return None
|
||||
|
||||
def get_finish_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
获取结束原因
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
结束原因字符串
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if candidates:
|
||||
return candidates[0].get("finishReason")
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从响应中提取文本内容
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
文本内容,如果没有文本返回 None
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
content = candidates[0].get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
text_parts = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
|
||||
return "".join(text_parts) if text_parts else None
|
||||
|
||||
def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
|
||||
"""
|
||||
从事件中提取 token 使用量
|
||||
|
||||
这是 Gemini token 提取的单一实现源,其他地方都应该调用此方法。
|
||||
|
||||
Args:
|
||||
event: 事件字典(包含 usageMetadata)
|
||||
|
||||
Returns:
|
||||
使用量字典,如果没有完整的使用量信息返回 None
|
||||
|
||||
注意:
|
||||
- 只有当 totalTokenCount 存在时才提取(确保是完整的 usage 数据)
|
||||
- 输出 token = thoughtsTokenCount + candidatesTokenCount
|
||||
"""
|
||||
usage_metadata = event.get("usageMetadata", {})
|
||||
if not usage_metadata or "totalTokenCount" not in usage_metadata:
|
||||
return None
|
||||
|
||||
# 输出 token = thoughtsTokenCount + candidatesTokenCount
|
||||
thoughts_tokens = usage_metadata.get("thoughtsTokenCount", 0)
|
||||
candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
|
||||
output_tokens = thoughts_tokens + candidates_tokens
|
||||
|
||||
return {
|
||||
"input_tokens": usage_metadata.get("promptTokenCount", 0),
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": usage_metadata.get("totalTokenCount", 0),
|
||||
"cached_tokens": usage_metadata.get("cachedContentTokenCount", 0),
|
||||
}
|
||||
|
||||
def extract_model_version(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从响应中提取模型版本
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
模型版本,如果没有返回 None
|
||||
"""
|
||||
return event.get("modelVersion")
|
||||
|
||||
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
从响应中提取安全评级
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
安全评级列表,如果没有返回 None
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
return candidates[0].get("safetyRatings")
|
||||
|
||||
|
||||
__all__ = ["GeminiStreamParser"]
|
||||
12
src/api/handlers/gemini_cli/__init__.py
Normal file
12
src/api/handlers/gemini_cli/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Gemini CLI 透传处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.gemini_cli.adapter import GeminiCliAdapter, build_gemini_cli_adapter
|
||||
from src.api.handlers.gemini_cli.handler import GeminiCliMessageHandler
|
||||
|
||||
__all__ = [
|
||||
"GeminiCliAdapter",
|
||||
"GeminiCliMessageHandler",
|
||||
"build_gemini_cli_adapter",
|
||||
]
|
||||
112
src/api/handlers/gemini_cli/adapter.py
Normal file
112
src/api/handlers/gemini_cli/adapter.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
|
||||
|
||||
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
class GeminiCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
Gemini CLI API 适配器
|
||||
|
||||
处理 Gemini CLI 格式的请求(透传模式,最小验证)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI_CLI"
|
||||
name = "gemini.cli"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.gemini_cli.handler import GeminiCliMessageHandler
|
||||
|
||||
return GeminiCliMessageHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["GEMINI_CLI"])
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-goog-api-key)"""
|
||||
return request.headers.get("x-goog-api-key")
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any] # noqa: ARG002
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - Gemini CLI 特化版本
|
||||
|
||||
Gemini API 特点:
|
||||
- model 不合并到请求体(Gemini 原生请求体不含 model,通过 URL 路径传递)
|
||||
- stream 不合并到请求体(Gemini API 通过 URL 端点区分流式/非流式)
|
||||
|
||||
基类已经从 path_params 获取 model 和 stream 用于日志和路由判断。
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典(包含 model、stream 等)
|
||||
|
||||
Returns:
|
||||
原始请求体(不合并任何 path_params)
|
||||
"""
|
||||
# Gemini: 不合并任何 path_params 到请求体
|
||||
return original_request_body.copy()
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
|
||||
"""Gemini CLI 使用 contents 字段"""
|
||||
contents = payload.get("contents", [])
|
||||
return len(contents) if isinstance(contents, list) else 0
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Gemini CLI 特定的审计元数据"""
|
||||
# 从 path_params 获取 model(Gemini 请求体不含 model)
|
||||
model = path_params.get("model", "unknown") if path_params else "unknown"
|
||||
contents = payload.get("contents", [])
|
||||
generation_config = payload.get("generation_config", {}) or {}
|
||||
|
||||
role_counts: Dict[str, int] = {}
|
||||
for content in contents:
|
||||
role = content.get("role", "unknown") if isinstance(content, dict) else "unknown"
|
||||
role_counts[role] = role_counts.get(role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "gemini_cli_request",
|
||||
"model": model,
|
||||
"stream": bool(payload.get("stream", False)),
|
||||
"max_output_tokens": generation_config.get("max_output_tokens"),
|
||||
"contents_count": len(contents),
|
||||
"content_roles": role_counts,
|
||||
"temperature": generation_config.get("temperature"),
|
||||
"top_p": generation_config.get("top_p"),
|
||||
"top_k": generation_config.get("top_k"),
|
||||
"tools_count": len(payload.get("tools") or []),
|
||||
"system_instruction_present": bool(payload.get("system_instruction")),
|
||||
"safety_settings_count": len(payload.get("safety_settings") or []),
|
||||
}
|
||||
|
||||
|
||||
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
||||
"""
|
||||
构建 Gemini CLI 适配器
|
||||
|
||||
Args:
|
||||
x_app_header: X-App 请求头值(预留扩展)
|
||||
|
||||
Returns:
|
||||
GeminiCliAdapter 实例
|
||||
"""
|
||||
return GeminiCliAdapter()
|
||||
|
||||
|
||||
__all__ = ["GeminiCliAdapter", "build_gemini_cli_adapter"]
|
||||
210
src/api/handlers/gemini_cli/handler.py
Normal file
210
src/api/handlers/gemini_cli/handler.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Gemini CLI Message Handler - 基于通用 CLI Handler 基类的实现
|
||||
|
||||
继承 CliMessageHandlerBase,处理 Gemini CLI API 格式的请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
|
||||
|
||||
class GeminiCliMessageHandler(CliMessageHandlerBase):
|
||||
"""
|
||||
Gemini CLI Message Handler - 处理 Gemini CLI API 格式
|
||||
|
||||
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
|
||||
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
|
||||
|
||||
响应格式特点:
|
||||
- Gemini 使用 JSON 数组格式流式响应(非 SSE)
|
||||
- 每个 chunk 包含 candidates、usageMetadata 等字段
|
||||
- finish_reason: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
|
||||
- Token 使用: promptTokenCount (输入), thoughtsTokenCount + candidatesTokenCount (输出), cachedContentTokenCount (缓存)
|
||||
|
||||
Gemini API 特殊处理:
|
||||
- model 在 URL 路径中而非请求体,如 /v1beta/models/{model}:generateContent
|
||||
- 请求体中的 model 字段用于内部路由,不发送给 API
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI_CLI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any], # noqa: ARG002 - 基类签名要求
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - Gemini 格式实现
|
||||
|
||||
Gemini API 的 model 在 URL 路径中而非请求体:
|
||||
/v1beta/models/{model}:generateContent
|
||||
|
||||
Args:
|
||||
request_body: 请求体(Gemini 不包含 model)
|
||||
path_params: URL 路径参数(包含 model)
|
||||
|
||||
Returns:
|
||||
模型名,如果无法提取则返回 "unknown"
|
||||
"""
|
||||
# Gemini: model 从 URL 路径参数获取
|
||||
if path_params and "model" in path_params:
|
||||
return str(path_params["model"])
|
||||
return "unknown"
|
||||
|
||||
def prepare_provider_request_body(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
准备发送给 Gemini API 的请求体 - 移除 model 字段
|
||||
|
||||
Gemini API 要求 model 只在 URL 路径中,请求体中的 model 字段
|
||||
会导致某些代理返回 404 错误。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
|
||||
Returns:
|
||||
不含 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result.pop("model", None)
|
||||
return result
|
||||
|
||||
def get_model_for_url(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Gemini 需要将 model 放入 URL 路径中
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
mapped_model: 映射后的模型名(如果有)
|
||||
|
||||
Returns:
|
||||
用于 URL 路径的模型名
|
||||
"""
|
||||
# 优先使用映射后的模型名,否则使用请求体中的
|
||||
return mapped_model or request_body.get("model")
|
||||
|
||||
def _extract_usage_from_event(self, event: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""
|
||||
从 Gemini 事件中提取 token 使用情况
|
||||
|
||||
调用 GeminiStreamParser.extract_usage 作为单一实现源
|
||||
|
||||
Args:
|
||||
event: Gemini 流式响应事件
|
||||
|
||||
Returns:
|
||||
包含 input_tokens, output_tokens, cached_tokens 的字典
|
||||
"""
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
usage = GeminiStreamParser().extract_usage(event)
|
||||
|
||||
if not usage:
|
||||
return {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cached_tokens": 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cached_tokens": usage.get("cached_tokens", 0),
|
||||
}
|
||||
|
||||
def _process_event_data(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
_event_type: str,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
处理 Gemini CLI 格式的流式事件
|
||||
|
||||
Gemini 的流式响应是 JSON 数组格式,每个元素结构如下:
|
||||
{
|
||||
"candidates": [{
|
||||
"content": {"parts": [{"text": "..."}], "role": "model"},
|
||||
"finishReason": "STOP",
|
||||
"safetyRatings": [...]
|
||||
}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 10,
|
||||
"candidatesTokenCount": 20,
|
||||
"totalTokenCount": 30,
|
||||
"cachedContentTokenCount": 5
|
||||
},
|
||||
"modelVersion": "gemini-1.5-pro"
|
||||
}
|
||||
|
||||
注意: Gemini 流解析器会将每个 JSON 对象作为一个"事件"传递
|
||||
event_type 在这里可能为空或是自定义的标记
|
||||
"""
|
||||
# 提取候选响应
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates:
|
||||
candidate = candidates[0]
|
||||
content = candidate.get("content", {})
|
||||
|
||||
# 提取文本内容
|
||||
parts = content.get("parts", [])
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
ctx.collected_text += part["text"]
|
||||
|
||||
# 检查结束原因
|
||||
finish_reason = candidate.get("finishReason")
|
||||
if finish_reason in ("STOP", "MAX_TOKENS", "SAFETY", "RECITATION", "OTHER"):
|
||||
ctx.has_completion = True
|
||||
ctx.final_response = data
|
||||
|
||||
# 提取使用量信息(复用 GeminiStreamParser.extract_usage)
|
||||
usage = self._extract_usage_from_event(data)
|
||||
if usage["input_tokens"] > 0 or usage["output_tokens"] > 0:
|
||||
ctx.input_tokens = usage["input_tokens"]
|
||||
ctx.output_tokens = usage["output_tokens"]
|
||||
ctx.cached_tokens = usage["cached_tokens"]
|
||||
|
||||
# 提取模型版本作为响应 ID
|
||||
model_version = data.get("modelVersion")
|
||||
if model_version:
|
||||
if not ctx.response_id:
|
||||
ctx.response_id = f"gemini-{model_version}"
|
||||
# 存储到 response_metadata 供 Usage 记录使用
|
||||
ctx.response_metadata["model_version"] = model_version
|
||||
|
||||
# 检查错误
|
||||
if "error" in data:
|
||||
ctx.has_completion = True
|
||||
ctx.final_response = data
|
||||
|
||||
def _extract_response_metadata(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从 Gemini 响应中提取元数据
|
||||
|
||||
提取 modelVersion 字段,记录实际使用的模型版本。
|
||||
|
||||
Args:
|
||||
response: Gemini API 响应
|
||||
|
||||
Returns:
|
||||
包含 model_version 的元数据字典
|
||||
"""
|
||||
metadata: Dict[str, Any] = {}
|
||||
model_version = response.get("modelVersion")
|
||||
if model_version:
|
||||
metadata["model_version"] = model_version
|
||||
return metadata
|
||||
11
src/api/handlers/openai/__init__.py
Normal file
11
src/api/handlers/openai/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
OpenAI Chat API 处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.openai.adapter import OpenAIChatAdapter
|
||||
from src.api.handlers.openai.handler import OpenAIChatHandler
|
||||
|
||||
__all__ = [
|
||||
"OpenAIChatAdapter",
|
||||
"OpenAIChatHandler",
|
||||
]
|
||||
109
src/api/handlers/openai/adapter.py
Normal file
109
src/api/handlers/openai/adapter.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
|
||||
|
||||
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.logger import logger
|
||||
from src.models.openai import OpenAIRequest
|
||||
|
||||
|
||||
@register_adapter
|
||||
class OpenAIChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
OpenAI Chat Completions API 适配器
|
||||
|
||||
处理 OpenAI Chat 格式的请求(/v1/chat/completions 端点)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI"
|
||||
name = "openai.chat"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.openai.handler import OpenAIChatHandler
|
||||
|
||||
return OpenAIChatHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["OPENAI"])
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""验证请求体"""
|
||||
if not isinstance(original_request_body, dict):
|
||||
return self._error_response(
|
||||
400, "Request body must be a JSON object", "invalid_request_error"
|
||||
)
|
||||
|
||||
required_fields = ["model", "messages"]
|
||||
missing = [f for f in required_fields if f not in original_request_body]
|
||||
if missing:
|
||||
return self._error_response(
|
||||
400,
|
||||
f"Missing required fields: {', '.join(missing)}",
|
||||
"invalid_request_error",
|
||||
)
|
||||
|
||||
try:
|
||||
return OpenAIRequest.model_validate(original_request_body, strict=False)
|
||||
except ValueError as e:
|
||||
return self._error_response(400, str(e), "invalid_request_error")
|
||||
except Exception as e:
|
||||
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
|
||||
return OpenAIRequest.model_construct(
|
||||
model=original_request_body.get("model"),
|
||||
messages=original_request_body.get("messages", []),
|
||||
stream=original_request_body.get("stream", False),
|
||||
max_tokens=original_request_body.get("max_tokens"),
|
||||
)
|
||||
|
||||
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""构建 OpenAI Chat 特定的审计元数据"""
|
||||
role_counts = {}
|
||||
for message in request_obj.messages:
|
||||
role_counts[message.role] = role_counts.get(message.role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "openai_chat_completion",
|
||||
"model": request_obj.model,
|
||||
"stream": bool(request_obj.stream),
|
||||
"max_tokens": request_obj.max_tokens,
|
||||
"temperature": request_obj.temperature,
|
||||
"top_p": request_obj.top_p,
|
||||
"messages_count": len(request_obj.messages),
|
||||
"message_roles": role_counts,
|
||||
"tools_count": len(request_obj.tools or []),
|
||||
"response_format": bool(request_obj.response_format),
|
||||
"user_identifier": request_obj.user,
|
||||
}
|
||||
|
||||
def _error_response(self, status_code: int, message: str, error_type: str) -> JSONResponse:
|
||||
"""生成 OpenAI 格式的错误响应"""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
"code": status_code,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["OpenAIChatAdapter"]
|
||||
424
src/api/handlers/openai/converter.py
Normal file
424
src/api/handlers/openai/converter.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Claude -> OpenAI 格式转换器
|
||||
|
||||
将 Claude Messages API 格式转换为 OpenAI Chat Completions API 格式。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ClaudeToOpenAIConverter:
|
||||
"""
|
||||
Claude -> OpenAI 格式转换器
|
||||
|
||||
支持:
|
||||
- 请求转换:Claude Request -> OpenAI Chat Request
|
||||
- 响应转换:Claude Response -> OpenAI Chat Response
|
||||
- 流式转换:Claude SSE -> OpenAI SSE
|
||||
"""
|
||||
|
||||
# 内容类型常量
|
||||
CONTENT_TYPE_TEXT = "text"
|
||||
CONTENT_TYPE_IMAGE = "image"
|
||||
CONTENT_TYPE_TOOL_USE = "tool_use"
|
||||
CONTENT_TYPE_TOOL_RESULT = "tool_result"
|
||||
|
||||
# 停止原因映射
|
||||
STOP_REASON_MAP = {
|
||||
"end_turn": "stop",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"tool_use": "tool_calls",
|
||||
}
|
||||
|
||||
def __init__(self, model_mapping: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
Args:
|
||||
model_mapping: Claude 模型到 OpenAI 模型的映射
|
||||
"""
|
||||
self._model_mapping = model_mapping or {}
|
||||
|
||||
# ==================== 请求转换 ====================
|
||||
|
||||
def convert_request(self, request: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Claude 请求转换为 OpenAI 格式
|
||||
|
||||
Args:
|
||||
request: Claude 请求(Dict 或 Pydantic 模型)
|
||||
|
||||
Returns:
|
||||
OpenAI 格式的请求字典
|
||||
"""
|
||||
if hasattr(request, "model_dump"):
|
||||
data = request.model_dump(exclude_none=True)
|
||||
else:
|
||||
data = dict(request)
|
||||
|
||||
# 模型映射
|
||||
model = data.get("model", "")
|
||||
openai_model = self._model_mapping.get(model, model)
|
||||
|
||||
# 构建消息列表
|
||||
messages: List[Dict[str, Any]] = []
|
||||
|
||||
# 处理 system 消息
|
||||
system_content = self._extract_text_content(data.get("system"))
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
|
||||
# 处理对话消息
|
||||
for message in data.get("messages", []):
|
||||
converted = self._convert_message(message)
|
||||
if converted:
|
||||
messages.append(converted)
|
||||
|
||||
# 构建 OpenAI 请求
|
||||
result: Dict[str, Any] = {
|
||||
"model": openai_model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
# 可选参数
|
||||
if data.get("max_tokens"):
|
||||
result["max_tokens"] = data["max_tokens"]
|
||||
if data.get("temperature") is not None:
|
||||
result["temperature"] = data["temperature"]
|
||||
if data.get("top_p") is not None:
|
||||
result["top_p"] = data["top_p"]
|
||||
if data.get("stream"):
|
||||
result["stream"] = data["stream"]
|
||||
if data.get("stop_sequences"):
|
||||
result["stop"] = data["stop_sequences"]
|
||||
|
||||
# 工具转换
|
||||
tools = self._convert_tools(data.get("tools"))
|
||||
if tools:
|
||||
result["tools"] = tools
|
||||
|
||||
tool_choice = self._convert_tool_choice(data.get("tool_choice"))
|
||||
if tool_choice:
|
||||
result["tool_choice"] = tool_choice
|
||||
|
||||
return result
|
||||
|
||||
def _convert_message(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""转换单条消息"""
|
||||
role = message.get("role")
|
||||
|
||||
if role == "user":
|
||||
return self._convert_user_message(message)
|
||||
if role == "assistant":
|
||||
return self._convert_assistant_message(message)
|
||||
|
||||
return None
|
||||
|
||||
def _convert_user_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换用户消息"""
|
||||
content = message.get("content")
|
||||
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
openai_content: List[Dict[str, Any]] = []
|
||||
for block in content or []:
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == self.CONTENT_TYPE_TEXT:
|
||||
openai_content.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == self.CONTENT_TYPE_IMAGE:
|
||||
source = block.get("source", {})
|
||||
media_type = source.get("media_type", "image/jpeg")
|
||||
data = source.get("data", "")
|
||||
openai_content.append(
|
||||
{"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{data}"}}
|
||||
)
|
||||
elif block_type == self.CONTENT_TYPE_TOOL_RESULT:
|
||||
tool_content = block.get("content", "")
|
||||
rendered = self._render_tool_content(tool_content)
|
||||
openai_content.append({"type": "text", "text": f"Tool result: {rendered}"})
|
||||
|
||||
# 简化单文本内容
|
||||
if len(openai_content) == 1 and openai_content[0]["type"] == "text":
|
||||
return {"role": "user", "content": openai_content[0]["text"]}
|
||||
|
||||
return {"role": "user", "content": openai_content or ""}
|
||||
|
||||
def _convert_assistant_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换助手消息"""
|
||||
content = message.get("content")
|
||||
text_parts: List[str] = []
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
|
||||
if isinstance(content, str):
|
||||
text_parts.append(content)
|
||||
else:
|
||||
for idx, block in enumerate(content or []):
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == self.CONTENT_TYPE_TEXT:
|
||||
text_parts.append(block.get("text", ""))
|
||||
elif block_type == self.CONTENT_TYPE_TOOL_USE:
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": block.get("id", f"call_{idx}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.get("name", ""),
|
||||
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result: Dict[str, Any] = {"role": "assistant"}
|
||||
|
||||
message_content = "\n".join([p for p in text_parts if p]) or None
|
||||
if message_content:
|
||||
result["content"] = message_content
|
||||
|
||||
if tool_calls:
|
||||
result["tool_calls"] = tool_calls
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: Optional[List[Dict[str, Any]]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""转换工具定义"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
result: List[Dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
result.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description"),
|
||||
"parameters": tool.get("input_schema", {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def _convert_tool_choice(
|
||||
self, tool_choice: Optional[Dict[str, Any]]
|
||||
) -> Optional[Union[str, Dict[str, Any]]]:
|
||||
"""转换工具选择"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
choice_type = tool_choice.get("type")
|
||||
if choice_type in ("tool", "tool_use"):
|
||||
return {"type": "function", "function": {"name": tool_choice.get("name", "")}}
|
||||
if choice_type == "any":
|
||||
return "required"
|
||||
if choice_type == "auto":
|
||||
return "auto"
|
||||
|
||||
return tool_choice
|
||||
|
||||
# ==================== 响应转换 ====================
|
||||
|
||||
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Claude 响应转换为 OpenAI 格式
|
||||
|
||||
Args:
|
||||
response: Claude 响应字典
|
||||
|
||||
Returns:
|
||||
OpenAI 格式的响应字典
|
||||
"""
|
||||
# 提取内容
|
||||
content_parts: List[str] = []
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
|
||||
for idx, block in enumerate(response.get("content", [])):
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == self.CONTENT_TYPE_TEXT:
|
||||
content_parts.append(block.get("text", ""))
|
||||
elif block_type == self.CONTENT_TYPE_TOOL_USE:
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": block.get("id", f"call_{idx}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.get("name", ""),
|
||||
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 构建消息
|
||||
message: Dict[str, Any] = {"role": "assistant"}
|
||||
text_content = "\n".join([p for p in content_parts if p]) or None
|
||||
if text_content:
|
||||
message["content"] = text_content
|
||||
if tool_calls:
|
||||
message["tool_calls"] = tool_calls
|
||||
|
||||
# 转换停止原因
|
||||
stop_reason = response.get("stop_reason")
|
||||
finish_reason = self.STOP_REASON_MAP.get(stop_reason, stop_reason)
|
||||
|
||||
# 转换 usage
|
||||
usage = response.get("usage", {})
|
||||
openai_usage = {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": (usage.get("input_tokens", 0) + usage.get("output_tokens", 0)),
|
||||
}
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{response.get('id', uuid.uuid4().hex[:8])}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": response.get("model", ""),
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": message,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": openai_usage,
|
||||
}
|
||||
|
||||
# ==================== 流式转换 ====================
|
||||
|
||||
def convert_stream_event(
|
||||
self,
|
||||
event: Dict[str, Any],
|
||||
model: str = "",
|
||||
message_id: Optional[str] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
将 Claude SSE 事件转换为 OpenAI 格式
|
||||
|
||||
Args:
|
||||
event: Claude SSE 事件
|
||||
model: 模型名称
|
||||
message_id: 消息 ID
|
||||
|
||||
Returns:
|
||||
OpenAI 格式的 SSE chunk,如果无法转换返回 None
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
chunk_id = f"chatcmpl-{(message_id or 'stream')[-8:]}"
|
||||
|
||||
if event_type == "message_start":
|
||||
message = event.get("message", {})
|
||||
return self._base_chunk(
|
||||
chunk_id,
|
||||
model or message.get("model", ""),
|
||||
{"role": "assistant"},
|
||||
)
|
||||
|
||||
if event_type == "content_block_start":
|
||||
content_block = event.get("content_block", {})
|
||||
if content_block.get("type") == self.CONTENT_TYPE_TOOL_USE:
|
||||
delta = {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": event.get("index", 0),
|
||||
"id": content_block.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": content_block.get("name", ""),
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
return self._base_chunk(chunk_id, model, delta)
|
||||
return None
|
||||
|
||||
if event_type == "content_block_delta":
|
||||
delta_payload = event.get("delta", {})
|
||||
delta_type = delta_payload.get("type")
|
||||
|
||||
if delta_type == "text_delta":
|
||||
delta = {"content": delta_payload.get("text", "")}
|
||||
return self._base_chunk(chunk_id, model, delta)
|
||||
|
||||
if delta_type == "input_json_delta":
|
||||
delta = {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": event.get("index", 0),
|
||||
"function": {"arguments": delta_payload.get("partial_json", "")},
|
||||
}
|
||||
]
|
||||
}
|
||||
return self._base_chunk(chunk_id, model, delta)
|
||||
return None
|
||||
|
||||
if event_type == "message_delta":
|
||||
delta = event.get("delta", {})
|
||||
stop_reason = delta.get("stop_reason")
|
||||
finish_reason = self.STOP_REASON_MAP.get(stop_reason, stop_reason)
|
||||
return self._base_chunk(chunk_id, model, {}, finish_reason=finish_reason)
|
||||
|
||||
if event_type == "message_stop":
|
||||
return self._base_chunk(chunk_id, model, {}, finish_reason="stop")
|
||||
|
||||
return None
|
||||
|
||||
def _base_chunk(
|
||||
self,
|
||||
chunk_id: str,
|
||||
model: str,
|
||||
delta: Dict[str, Any],
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建基础 OpenAI chunk"""
|
||||
return {
|
||||
"id": chunk_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"system_fingerprint": None,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# ==================== 工具方法 ====================
|
||||
|
||||
def _extract_text_content(
|
||||
self, content: Optional[Union[str, List[Dict[str, Any]]]]
|
||||
) -> Optional[str]:
|
||||
"""提取文本内容"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
block.get("text", "")
|
||||
for block in content
|
||||
if block.get("type") == self.CONTENT_TYPE_TEXT
|
||||
]
|
||||
return "\n\n".join(filter(None, parts)) or None
|
||||
return None
|
||||
|
||||
def _render_tool_content(self, tool_content: Any) -> str:
|
||||
"""渲染工具内容"""
|
||||
if isinstance(tool_content, list):
|
||||
return json.dumps(tool_content, ensure_ascii=False)
|
||||
return str(tool_content)
|
||||
|
||||
|
||||
__all__ = ["ClaudeToOpenAIConverter"]
|
||||
137
src/api/handlers/openai/handler.py
Normal file
137
src/api/handlers/openai/handler.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
OpenAI Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
||||
|
||||
继承 ChatHandlerBase,只需覆盖格式特定的方法。
|
||||
代码量从原来的 ~1315 行减少到 ~100 行。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
|
||||
|
||||
class OpenAIChatHandler(ChatHandlerBase):
|
||||
"""
|
||||
OpenAI Chat Handler - 处理 OpenAI Chat Completions API 格式的请求
|
||||
|
||||
格式特点:
|
||||
- 使用 prompt_tokens/completion_tokens
|
||||
- 不支持 cache tokens
|
||||
- 请求格式:OpenAIRequest
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - OpenAI 格式实现
|
||||
|
||||
OpenAI API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(OpenAI 不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
model = request_body.get("model")
|
||||
return str(model) if model else "unknown"
|
||||
|
||||
def apply_mapped_model(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将映射后的模型名应用到请求体
|
||||
|
||||
OpenAI API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 原始请求体
|
||||
mapped_model: 映射后的模型名
|
||||
|
||||
Returns:
|
||||
更新了 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
async def _convert_request(self, request):
|
||||
"""
|
||||
将请求转换为 OpenAI 格式
|
||||
|
||||
Args:
|
||||
request: 原始请求对象
|
||||
|
||||
Returns:
|
||||
OpenAIRequest 对象
|
||||
"""
|
||||
from src.api.handlers.openai.converter import ClaudeToOpenAIConverter
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.openai import OpenAIRequest
|
||||
|
||||
# 如果已经是 OpenAI 格式,直接返回
|
||||
if isinstance(request, OpenAIRequest):
|
||||
return request
|
||||
|
||||
# 如果是 Claude 格式,转换为 OpenAI 格式
|
||||
if isinstance(request, ClaudeMessagesRequest):
|
||||
converter = ClaudeToOpenAIConverter()
|
||||
openai_dict = converter.convert_request(request.dict())
|
||||
return OpenAIRequest(**openai_dict)
|
||||
|
||||
# 如果是字典,尝试判断格式
|
||||
if isinstance(request, dict):
|
||||
try:
|
||||
return OpenAIRequest(**request)
|
||||
except Exception:
|
||||
try:
|
||||
converter = ClaudeToOpenAIConverter()
|
||||
openai_dict = converter.convert_request(request)
|
||||
return OpenAIRequest(**openai_dict)
|
||||
except Exception:
|
||||
return OpenAIRequest(**request)
|
||||
|
||||
return request
|
||||
|
||||
def _extract_usage(self, response: Dict) -> Dict[str, int]:
|
||||
"""
|
||||
从 OpenAI 响应中提取 token 使用情况
|
||||
|
||||
OpenAI 格式使用:
|
||||
- prompt_tokens / completion_tokens
|
||||
- 不支持 cache tokens
|
||||
"""
|
||||
usage = response.get("usage", {})
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("prompt_tokens", 0),
|
||||
"output_tokens": usage.get("completion_tokens", 0),
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
"""
|
||||
规范化 OpenAI 响应
|
||||
|
||||
Args:
|
||||
response: 原始响应
|
||||
|
||||
Returns:
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
return self.response_normalizer.normalize_openai_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
strict=False,
|
||||
)
|
||||
return response
|
||||
181
src/api/handlers/openai/stream_parser.py
Normal file
181
src/api/handlers/openai/stream_parser.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
OpenAI SSE 流解析器
|
||||
|
||||
解析 OpenAI Chat Completions API 的 Server-Sent Events 流。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class OpenAIStreamParser:
|
||||
"""
|
||||
OpenAI SSE 流解析器
|
||||
|
||||
解析 OpenAI Chat Completions API 的 SSE 事件流。
|
||||
|
||||
OpenAI 流格式:
|
||||
- 每个 chunk 是一个 JSON 对象,包含 choices 数组
|
||||
- choices[0].delta 包含增量内容
|
||||
- choices[0].finish_reason 表示结束原因
|
||||
- 流结束时发送 data: [DONE]
|
||||
"""
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析 SSE 数据块
|
||||
|
||||
Args:
|
||||
chunk: 原始 SSE 数据(bytes 或 str)
|
||||
|
||||
Returns:
|
||||
解析后的 chunk 列表
|
||||
"""
|
||||
if isinstance(chunk, bytes):
|
||||
text = chunk.decode("utf-8")
|
||||
else:
|
||||
text = chunk
|
||||
|
||||
chunks: List[Dict[str, Any]] = []
|
||||
lines = text.strip().split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 解析数据行
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
|
||||
# 处理 [DONE] 标记
|
||||
if data_str == "[DONE]":
|
||||
chunks.append({"__done__": True})
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
chunks.append(data)
|
||||
except json.JSONDecodeError:
|
||||
# 无法解析的数据,跳过
|
||||
pass
|
||||
|
||||
return chunks
|
||||
|
||||
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析单行 SSE 数据
|
||||
|
||||
Args:
|
||||
line: SSE 数据行(已去除 "data: " 前缀)
|
||||
|
||||
Returns:
|
||||
解析后的 chunk 字典,如果无法解析返回 None
|
||||
"""
|
||||
if not line or line == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def is_done_chunk(self, chunk: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为结束 chunk
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
True 如果是结束 chunk
|
||||
"""
|
||||
# 内部标记
|
||||
if chunk.get("__done__"):
|
||||
return True
|
||||
|
||||
# 检查 finish_reason
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
finish_reason = choices[0].get("finish_reason")
|
||||
return finish_reason is not None
|
||||
|
||||
return False
|
||||
|
||||
def get_finish_reason(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
获取结束原因
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
结束原因字符串
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("finish_reason")
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 chunk 中提取文本增量
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
文本增量,如果没有返回 None
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content")
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
def extract_tool_calls_delta(self, chunk: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
从 chunk 中提取工具调用增量
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
工具调用列表,如果没有返回 None
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("tool_calls")
|
||||
|
||||
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 chunk 中提取角色
|
||||
|
||||
通常只在第一个 chunk 中出现。
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
角色字符串
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("role")
|
||||
|
||||
|
||||
__all__ = ["OpenAIStreamParser"]
|
||||
11
src/api/handlers/openai_cli/__init__.py
Normal file
11
src/api/handlers/openai_cli/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
OpenAI CLI 透传处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.openai_cli.adapter import OpenAICliAdapter
|
||||
from src.api.handlers.openai_cli.handler import OpenAICliMessageHandler
|
||||
|
||||
__all__ = [
|
||||
"OpenAICliAdapter",
|
||||
"OpenAICliMessageHandler",
|
||||
]
|
||||
44
src/api/handlers/openai_cli/adapter.py
Normal file
44
src/api/handlers/openai_cli/adapter.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
||||
|
||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||
"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
class OpenAICliAdapter(CliAdapterBase):
|
||||
"""
|
||||
OpenAI CLI API 适配器
|
||||
|
||||
处理 /v1/responses 端点的请求。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI_CLI"
|
||||
name = "openai.cli"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.openai_cli.handler import OpenAICliMessageHandler
|
||||
|
||||
return OpenAICliMessageHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["OPENAI_CLI"])
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["OpenAICliAdapter"]
|
||||
211
src/api/handlers/openai_cli/handler.py
Normal file
211
src/api/handlers/openai_cli/handler.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
OpenAI CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
|
||||
|
||||
继承 CliMessageHandlerBase,只需覆盖格式特定的配置和事件处理逻辑。
|
||||
代码量从原来的 900+ 行减少到 ~100 行。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
|
||||
|
||||
class OpenAICliMessageHandler(CliMessageHandlerBase):
|
||||
"""
|
||||
OpenAI CLI Message Handler - 处理 OpenAI CLI Responses API 格式
|
||||
|
||||
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
|
||||
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
|
||||
|
||||
响应格式特点:
|
||||
- 使用 output[] 数组而非 content[]
|
||||
- 使用 output_text 类型而非普通 text
|
||||
- 流式事件:response.output_text.delta, response.output_text.done
|
||||
|
||||
模型字段:请求体顶级 model 字段
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI_CLI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - OpenAI 格式实现
|
||||
|
||||
OpenAI API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(OpenAI 不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
model = request_body.get("model")
|
||||
return str(model) if model else "unknown"
|
||||
|
||||
def apply_mapped_model(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI CLI (Responses API) 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 原始请求体
|
||||
mapped_model: 映射后的模型名
|
||||
|
||||
Returns:
|
||||
更新了 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
def _process_event_data(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
event_type: str,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
处理 OpenAI CLI 格式的 SSE 事件
|
||||
|
||||
事件类型:
|
||||
- response.output_text.delta: 文本增量
|
||||
- response.completed: 响应完成(包含 usage)
|
||||
"""
|
||||
# 提取 response_id
|
||||
if not ctx.response_id:
|
||||
response_obj = data.get("response")
|
||||
if isinstance(response_obj, dict) and response_obj.get("id"):
|
||||
ctx.response_id = response_obj["id"]
|
||||
elif "id" in data:
|
||||
ctx.response_id = data["id"]
|
||||
|
||||
# 处理文本增量
|
||||
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
|
||||
delta = data.get("delta")
|
||||
if isinstance(delta, str):
|
||||
ctx.collected_text += delta
|
||||
elif isinstance(delta, dict) and "text" in delta:
|
||||
ctx.collected_text += delta["text"]
|
||||
|
||||
# 处理完成事件
|
||||
elif event_type == "response.completed":
|
||||
ctx.has_completion = True
|
||||
response_obj = data.get("response")
|
||||
if isinstance(response_obj, dict):
|
||||
ctx.final_response = response_obj
|
||||
|
||||
usage_obj = response_obj.get("usage")
|
||||
if isinstance(usage_obj, dict):
|
||||
ctx.final_usage = usage_obj
|
||||
ctx.input_tokens = usage_obj.get("input_tokens", 0)
|
||||
ctx.output_tokens = usage_obj.get("output_tokens", 0)
|
||||
|
||||
details = usage_obj.get("input_tokens_details")
|
||||
if isinstance(details, dict):
|
||||
ctx.cached_tokens = details.get("cached_tokens", 0)
|
||||
|
||||
# 如果没有收集到文本,从 output 中提取
|
||||
if not ctx.collected_text and "output" in response_obj:
|
||||
for output_item in response_obj.get("output", []):
|
||||
if output_item.get("type") != "message":
|
||||
continue
|
||||
for content_item in output_item.get("content", []):
|
||||
if content_item.get("type") == "output_text":
|
||||
text = content_item.get("text", "")
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
|
||||
# 备用:从顶层 usage 提取
|
||||
usage_obj = data.get("usage")
|
||||
if isinstance(usage_obj, dict) and not ctx.final_usage:
|
||||
ctx.final_usage = usage_obj
|
||||
ctx.input_tokens = usage_obj.get("input_tokens", 0)
|
||||
ctx.output_tokens = usage_obj.get("output_tokens", 0)
|
||||
|
||||
details = usage_obj.get("input_tokens_details")
|
||||
if isinstance(details, dict):
|
||||
ctx.cached_tokens = details.get("cached_tokens", 0)
|
||||
|
||||
# 备用:从 response 字段提取
|
||||
response_obj = data.get("response")
|
||||
if isinstance(response_obj, dict) and not ctx.final_response:
|
||||
ctx.final_response = response_obj
|
||||
|
||||
def _extract_response_metadata(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从 OpenAI 响应中提取元数据
|
||||
|
||||
提取 model、status、response_id 等字段作为元数据。
|
||||
|
||||
Args:
|
||||
response: OpenAI API 响应
|
||||
|
||||
Returns:
|
||||
提取的元数据字典
|
||||
"""
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
# 提取模型名称(实际使用的模型)
|
||||
if "model" in response:
|
||||
metadata["model"] = response["model"]
|
||||
|
||||
# 提取响应 ID
|
||||
if "id" in response:
|
||||
metadata["response_id"] = response["id"]
|
||||
|
||||
# 提取状态
|
||||
if "status" in response:
|
||||
metadata["status"] = response["status"]
|
||||
|
||||
# 提取对象类型
|
||||
if "object" in response:
|
||||
metadata["object"] = response["object"]
|
||||
|
||||
# 提取系统指纹(如果存在)
|
||||
if "system_fingerprint" in response:
|
||||
metadata["system_fingerprint"] = response["system_fingerprint"]
|
||||
|
||||
return metadata
|
||||
|
||||
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
|
||||
"""
|
||||
从流上下文中提取最终元数据
|
||||
|
||||
在流传输完成后调用,从收集的事件中提取元数据。
|
||||
|
||||
Args:
|
||||
ctx: 流上下文
|
||||
"""
|
||||
# 从 response_id 提取响应 ID
|
||||
if ctx.response_id:
|
||||
ctx.response_metadata["response_id"] = ctx.response_id
|
||||
|
||||
# 从 final_response 提取更多元数据
|
||||
if ctx.final_response and isinstance(ctx.final_response, dict):
|
||||
if "model" in ctx.final_response:
|
||||
ctx.response_metadata["model"] = ctx.final_response["model"]
|
||||
if "status" in ctx.final_response:
|
||||
ctx.response_metadata["status"] = ctx.final_response["status"]
|
||||
if "object" in ctx.final_response:
|
||||
ctx.response_metadata["object"] = ctx.final_response["object"]
|
||||
if "system_fingerprint" in ctx.final_response:
|
||||
ctx.response_metadata["system_fingerprint"] = ctx.final_response["system_fingerprint"]
|
||||
|
||||
# 如果没有从响应中获取到 model,使用上下文中的
|
||||
if "model" not in ctx.response_metadata and ctx.model:
|
||||
ctx.response_metadata["model"] = ctx.model
|
||||
|
||||
10
src/api/monitoring/__init__.py
Normal file
10
src/api/monitoring/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""User monitoring routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .user import router as monitoring_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(monitoring_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
148
src/api/monitoring/user.py
Normal file
148
src/api/monitoring/user.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""普通用户可访问的监控与审计端点。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_query
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, AuditLog
|
||||
from src.plugins.manager import get_plugin_manager
|
||||
|
||||
router = APIRouter(prefix="/api/monitoring", tags=["Monitoring"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/my-audit-logs")
|
||||
async def get_my_audit_logs(
|
||||
request: Request,
|
||||
event_type: Optional[str] = Query(None, description="事件类型筛选"),
|
||||
days: int = Query(30, description="查询天数"),
|
||||
limit: int = Query(50, description="返回数量限制"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = UserAuditLogsAdapter(event_type=event_type, days=days, limit=limit, offset=offset)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/rate-limit-status")
|
||||
async def get_rate_limit_status(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = UserRateLimitStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AuthenticatedApiAdapter(ApiAdapter):
|
||||
"""需要用户登录的适配器基类。"""
|
||||
|
||||
mode = ApiMode.USER
|
||||
|
||||
def authorize(self, context): # type: ignore[override]
|
||||
if not context.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserAuditLogsAdapter(AuthenticatedApiAdapter):
|
||||
event_type: Optional[str]
|
||||
days: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
query = db.query(AuditLog).filter(AuditLog.user_id == user.id)
|
||||
if self.event_type:
|
||||
query = query.filter(AuditLog.event_type == self.event_type)
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(days=self.days)
|
||||
query = query.filter(AuditLog.created_at >= cutoff_time)
|
||||
query = query.order_by(AuditLog.created_at.desc())
|
||||
|
||||
total, logs = paginate_query(query, self.limit, self.offset)
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": log.id,
|
||||
"event_type": log.event_type,
|
||||
"description": log.description,
|
||||
"ip_address": log.ip_address,
|
||||
"status_code": log.status_code,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
|
||||
meta = PaginationMeta(
|
||||
total=total,
|
||||
limit=self.limit,
|
||||
offset=self.offset,
|
||||
count=len(items),
|
||||
)
|
||||
|
||||
return build_pagination_payload(
|
||||
items,
|
||||
meta,
|
||||
filters={
|
||||
"event_type": self.event_type,
|
||||
"days": self.days,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class UserRateLimitStatusAdapter(AuthenticatedApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
rate_limiter = _get_rate_limit_plugin()
|
||||
if not rate_limiter or not hasattr(rate_limiter, "get_rate_limit_headers"):
|
||||
raise HTTPException(status_code=503, detail="速率限制插件未启用或不支持状态查询")
|
||||
|
||||
api_keys = (
|
||||
db.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user.id, ApiKey.is_active.is_(True))
|
||||
.order_by(ApiKey.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
rate_limit_info = []
|
||||
for key in api_keys:
|
||||
try:
|
||||
headers = rate_limiter.get_rate_limit_headers(key)
|
||||
except Exception as exc:
|
||||
logger.warning(f"无法获取Key {key.id} 的限流信息: {exc}")
|
||||
headers = {}
|
||||
|
||||
rate_limit_info.append(
|
||||
{
|
||||
"api_key_name": key.name or f"Key-{key.id}",
|
||||
"limit": headers.get("X-RateLimit-Limit"),
|
||||
"remaining": headers.get("X-RateLimit-Remaining"),
|
||||
"reset_time": headers.get("X-RateLimit-Reset"),
|
||||
"window": headers.get("X-RateLimit-Window"),
|
||||
}
|
||||
)
|
||||
|
||||
return {"user_id": user.id, "api_keys": rate_limit_info}
|
||||
|
||||
|
||||
def _get_rate_limit_plugin():
|
||||
try:
|
||||
plugin_manager = get_plugin_manager()
|
||||
return plugin_manager.get_plugin("rate_limit")
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取速率限制插件失败: {exc}")
|
||||
return None
|
||||
20
src/api/public/__init__.py
Normal file
20
src/api/public/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Public-facing API routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .capabilities import router as capabilities_router
|
||||
from .catalog import router as catalog_router
|
||||
from .claude import router as claude_router
|
||||
from .gemini import router as gemini_router
|
||||
from .openai import router as openai_router
|
||||
from .system_catalog import router as system_catalog_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(claude_router, tags=["Claude API"])
|
||||
router.include_router(openai_router)
|
||||
router.include_router(gemini_router, tags=["Gemini API"])
|
||||
router.include_router(system_catalog_router, tags=["System Catalog"])
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(capabilities_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
104
src/api/public/capabilities.py
Normal file
104
src/api/public/capabilities.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
能力配置公共 API
|
||||
|
||||
提供系统支持的能力列表,供前端展示和配置使用。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.key_capabilities import (
|
||||
get_all_capabilities,
|
||||
get_user_configurable_capabilities,
|
||||
)
|
||||
from src.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/capabilities", tags=["Capabilities"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_capabilities():
|
||||
"""获取所有能力定义"""
|
||||
return {
|
||||
"capabilities": [
|
||||
{
|
||||
"name": cap.name,
|
||||
"display_name": cap.display_name,
|
||||
"short_name": cap.short_name,
|
||||
"description": cap.description,
|
||||
"match_mode": cap.match_mode.value,
|
||||
"config_mode": cap.config_mode.value,
|
||||
}
|
||||
for cap in get_all_capabilities()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/user-configurable")
|
||||
async def list_user_configurable_capabilities():
|
||||
"""获取用户可配置的能力列表(用于前端展示配置选项)"""
|
||||
return {
|
||||
"capabilities": [
|
||||
{
|
||||
"name": cap.name,
|
||||
"display_name": cap.display_name,
|
||||
"short_name": cap.short_name,
|
||||
"description": cap.description,
|
||||
"match_mode": cap.match_mode.value,
|
||||
"config_mode": cap.config_mode.value,
|
||||
}
|
||||
for cap in get_user_configurable_capabilities()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/model/{model_name}")
|
||||
async def get_model_supported_capabilities(
|
||||
model_name: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定模型支持的能力列表
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514)
|
||||
|
||||
Returns:
|
||||
模型支持的能力列表,以及每个能力的详细定义
|
||||
"""
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
||||
|
||||
if not global_model:
|
||||
return {
|
||||
"model": model_name,
|
||||
"supported_capabilities": [],
|
||||
"capability_details": [],
|
||||
"error": "模型不存在",
|
||||
}
|
||||
|
||||
supported_caps = global_model.supported_capabilities or []
|
||||
|
||||
# 获取支持的能力详情
|
||||
all_caps = {cap.name: cap for cap in get_all_capabilities()}
|
||||
capability_details = []
|
||||
for cap_name in supported_caps:
|
||||
if cap_name in all_caps:
|
||||
cap = all_caps[cap_name]
|
||||
capability_details.append({
|
||||
"name": cap.name,
|
||||
"display_name": cap.display_name,
|
||||
"description": cap.description,
|
||||
"match_mode": cap.match_mode.value,
|
||||
"config_mode": cap.config_mode.value,
|
||||
})
|
||||
|
||||
return {
|
||||
"model": model_name,
|
||||
"global_model_id": str(global_model.id),
|
||||
"global_model_name": global_model.name,
|
||||
"supported_capabilities": supported_caps,
|
||||
"capability_details": capability_details,
|
||||
}
|
||||
643
src/api/public/catalog.py
Normal file
643
src/api/public/catalog.py
Normal file
@@ -0,0 +1,643 @@
|
||||
"""
|
||||
公开API端点 - 用户可查看的提供商和模型信息
|
||||
不包含敏感信息,普通用户可访问
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ProviderStatsResponse,
|
||||
PublicGlobalModelListResponse,
|
||||
PublicGlobalModelResponse,
|
||||
PublicModelMappingResponse,
|
||||
PublicModelResponse,
|
||||
PublicProviderResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
)
|
||||
from src.models.endpoint_models import (
|
||||
PublicApiFormatHealthMonitor,
|
||||
PublicApiFormatHealthMonitorResponse,
|
||||
PublicHealthEvent,
|
||||
)
|
||||
from src.services.health.endpoint import EndpointHealthService
|
||||
|
||||
router = APIRouter(prefix="/api/public", tags=["Public Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/providers", response_model=List[PublicProviderResponse])
|
||||
async def get_public_providers(
|
||||
request: Request,
|
||||
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
|
||||
skip: int = Query(0, description="跳过记录数"),
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取提供商列表(用户视图)。"""
|
||||
|
||||
adapter = PublicProvidersAdapter(is_active=is_active, skip=skip, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/models", response_model=List[PublicModelResponse])
|
||||
async def get_public_models(
|
||||
request: Request,
|
||||
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
|
||||
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
|
||||
skip: int = Query(0, description="跳过记录数"),
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = PublicModelsAdapter(
|
||||
provider_id=provider_id, is_active=is_active, skip=skip, limit=limit
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
|
||||
async def get_public_model_mappings(
|
||||
request: Request,
|
||||
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
|
||||
alias: Optional[str] = Query(None, description="别名过滤(原source_model)"),
|
||||
skip: int = Query(0, description="跳过记录数"),
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = PublicModelMappingsAdapter(
|
||||
provider_id=provider_id,
|
||||
alias=alias,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ProviderStatsResponse)
|
||||
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = PublicStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/search/models")
|
||||
async def search_models(
|
||||
request: Request,
|
||||
q: str = Query(..., description="搜索关键词"),
|
||||
provider_id: Optional[int] = Query(None, description="提供商ID过滤"),
|
||||
limit: int = Query(20, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = PublicSearchModelsAdapter(query=q, provider_id=provider_id, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/health/api-formats", response_model=PublicApiFormatHealthMonitorResponse)
|
||||
async def get_public_api_format_health(
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=168, description="回溯小时数"),
|
||||
per_format_limit: int = Query(100, ge=10, le=500, description="每个格式的事件数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取各 API 格式的健康监控数据(公开版,不含敏感信息)"""
|
||||
adapter = PublicApiFormatHealthMonitorAdapter(
|
||||
lookback_hours=lookback_hours,
|
||||
per_format_limit=per_format_limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/global-models", response_model=PublicGlobalModelListResponse)
|
||||
async def get_public_global_models(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回记录数限制"),
|
||||
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取 GlobalModel 列表(用户视图,只读)"""
|
||||
adapter = PublicGlobalModelsAdapter(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
search=search,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
# -------- 公共适配器 --------
|
||||
|
||||
|
||||
class PublicApiAdapter(ApiAdapter):
|
||||
mode = ApiMode.PUBLIC
|
||||
|
||||
def authorize(self, context): # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicProvidersAdapter(PublicApiAdapter):
|
||||
is_active: Optional[bool]
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求提供商列表")
|
||||
query = db.query(Provider)
|
||||
if self.is_active is not None:
|
||||
query = query.filter(Provider.is_active == self.is_active)
|
||||
else:
|
||||
query = query.filter(Provider.is_active.is_(True))
|
||||
|
||||
providers = query.offset(self.skip).limit(self.limit).all()
|
||||
result = []
|
||||
for provider in providers:
|
||||
models_count = db.query(Model).filter(Model.provider_id == provider.id).count()
|
||||
active_models_count = (
|
||||
db.query(Model)
|
||||
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
mappings_count = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True))
|
||||
)
|
||||
.count()
|
||||
)
|
||||
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
|
||||
active_endpoints_count = (
|
||||
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
|
||||
)
|
||||
provider_data = PublicProviderResponse(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
display_name=provider.display_name,
|
||||
description=provider.description,
|
||||
is_active=provider.is_active,
|
||||
provider_priority=provider.provider_priority,
|
||||
models_count=models_count,
|
||||
active_models_count=active_models_count,
|
||||
mappings_count=mappings_count,
|
||||
endpoints_count=endpoints_count,
|
||||
active_endpoints_count=active_endpoints_count,
|
||||
)
|
||||
result.append(provider_data.model_dump())
|
||||
|
||||
logger.debug(f"返回 {len(result)} 个提供商信息")
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicModelsAdapter(PublicApiAdapter):
|
||||
provider_id: Optional[str]
|
||||
is_active: Optional[bool]
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求模型列表")
|
||||
query = (
|
||||
db.query(Model, Provider)
|
||||
.options(joinedload(Model.global_model))
|
||||
.join(Provider)
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
)
|
||||
if self.provider_id is not None:
|
||||
query = query.filter(Model.provider_id == self.provider_id)
|
||||
results = query.offset(self.skip).limit(self.limit).all()
|
||||
|
||||
response = []
|
||||
for model, provider in results:
|
||||
global_model = model.global_model
|
||||
display_name = global_model.display_name if global_model else model.provider_model_name
|
||||
unified_name = global_model.name if global_model else model.provider_model_name
|
||||
model_data = PublicModelResponse(
|
||||
id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
provider_name=provider.name,
|
||||
provider_display_name=provider.display_name,
|
||||
name=unified_name,
|
||||
display_name=display_name,
|
||||
description=global_model.description if global_model else None,
|
||||
tags=None,
|
||||
icon_url=global_model.icon_url if global_model else None,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
cache_read_price_per_1m=model.get_effective_cache_read_price(),
|
||||
supports_vision=model.get_effective_supports_vision(),
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=model.is_active,
|
||||
)
|
||||
response.append(model_data.model_dump())
|
||||
|
||||
logger.debug(f"返回 {len(response)} 个模型信息")
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicModelMappingsAdapter(PublicApiAdapter):
|
||||
provider_id: Optional[str]
|
||||
alias: Optional[str] # 原 source_model,改为 alias
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求模型映射列表")
|
||||
|
||||
query = (
|
||||
db.query(ModelMapping, GlobalModel, Provider)
|
||||
.join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.outerjoin(Provider, ModelMapping.provider_id == Provider.id)
|
||||
.filter(
|
||||
and_(
|
||||
ModelMapping.is_active.is_(True),
|
||||
GlobalModel.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if self.provider_id is not None:
|
||||
provider_global_model_ids = (
|
||||
db.query(Model.global_model_id)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.filter(
|
||||
Provider.id == self.provider_id,
|
||||
Model.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
Model.global_model_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelMapping.provider_id == self.provider_id,
|
||||
and_(
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.target_global_model_id.in_(provider_global_model_ids),
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
|
||||
if self.alias is not None:
|
||||
query = query.filter(ModelMapping.source_model.ilike(f"%{self.alias}%"))
|
||||
|
||||
results = query.offset(self.skip).limit(self.limit).all()
|
||||
response = []
|
||||
for mapping, global_model, provider in results:
|
||||
scope = "provider" if mapping.provider_id else "global"
|
||||
mapping_data = PublicModelMappingResponse(
|
||||
id=mapping.id,
|
||||
source_model=mapping.source_model,
|
||||
target_global_model_id=mapping.target_global_model_id,
|
||||
target_global_model_name=global_model.name if global_model else None,
|
||||
target_global_model_display_name=(
|
||||
global_model.display_name if global_model else None
|
||||
),
|
||||
provider_id=mapping.provider_id,
|
||||
scope=scope,
|
||||
is_active=mapping.is_active,
|
||||
)
|
||||
response.append(mapping_data.model_dump())
|
||||
|
||||
logger.debug(f"返回 {len(response)} 个模型映射")
|
||||
return response
|
||||
|
||||
|
||||
class PublicStatsAdapter(PublicApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求系统统计信息")
|
||||
active_providers = db.query(Provider).filter(Provider.is_active.is_(True)).count()
|
||||
active_models = (
|
||||
db.query(Model)
|
||||
.join(Provider)
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
from ...models.database import ModelMapping
|
||||
|
||||
active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
|
||||
formats = (
|
||||
db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
|
||||
)
|
||||
supported_formats = [f.api_format for f in formats if f.api_format]
|
||||
stats = ProviderStatsResponse(
|
||||
total_providers=active_providers,
|
||||
active_providers=active_providers,
|
||||
total_models=active_models,
|
||||
active_models=active_models,
|
||||
total_mappings=active_mappings,
|
||||
supported_formats=supported_formats,
|
||||
)
|
||||
logger.debug("返回系统统计信息")
|
||||
return stats.model_dump()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicSearchModelsAdapter(PublicApiAdapter):
|
||||
query: str
|
||||
provider_id: Optional[int]
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug(f"公共API搜索模型: {self.query}")
|
||||
query_stmt = (
|
||||
db.query(Model, Provider)
|
||||
.options(joinedload(Model.global_model))
|
||||
.join(Provider)
|
||||
.outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
)
|
||||
search_filter = (
|
||||
Model.provider_model_name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.display_name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.description.ilike(f"%{self.query}%")
|
||||
)
|
||||
query_stmt = query_stmt.filter(search_filter)
|
||||
if self.provider_id is not None:
|
||||
query_stmt = query_stmt.filter(Model.provider_id == self.provider_id)
|
||||
results = query_stmt.limit(self.limit).all()
|
||||
|
||||
response = []
|
||||
for model, provider in results:
|
||||
global_model = model.global_model
|
||||
display_name = global_model.display_name if global_model else model.provider_model_name
|
||||
unified_name = global_model.name if global_model else model.provider_model_name
|
||||
model_data = PublicModelResponse(
|
||||
id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
provider_name=provider.name,
|
||||
provider_display_name=provider.display_name,
|
||||
name=unified_name,
|
||||
display_name=display_name,
|
||||
description=global_model.description if global_model else None,
|
||||
tags=None,
|
||||
icon_url=global_model.icon_url if global_model else None,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
cache_read_price_per_1m=model.get_effective_cache_read_price(),
|
||||
supports_vision=model.get_effective_supports_vision(),
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=model.is_active,
|
||||
)
|
||||
response.append(model_data.model_dump())
|
||||
|
||||
logger.debug(f"搜索 '{self.query}' 返回 {len(response)} 个结果")
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicApiFormatHealthMonitorAdapter(PublicApiAdapter):
|
||||
"""公开版 API 格式健康监控适配器(返回 events 数组,前端复用 EndpointHealthTimeline 组件)"""
|
||||
|
||||
lookback_hours: int
|
||||
per_format_limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
now = datetime.now(timezone.utc)
|
||||
since = now - timedelta(hours=self.lookback_hours)
|
||||
|
||||
# 1. 获取所有活跃的 API 格式
|
||||
active_formats = (
|
||||
db.query(ProviderEndpoint.api_format)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
|
||||
all_formats: List[str] = []
|
||||
for (api_format_enum,) in active_formats:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
all_formats.append(api_format)
|
||||
|
||||
# API 格式 -> Endpoint ID 映射(用于 Usage 时间线)
|
||||
endpoint_rows = (
|
||||
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
endpoint_map: Dict[str, List[str]] = defaultdict(list)
|
||||
for api_format_enum, endpoint_id in endpoint_rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
endpoint_map[api_format].append(endpoint_id)
|
||||
|
||||
# 2. 获取最近一段时间的 RequestCandidate(限制数量)
|
||||
# 只查询最终状态的记录:success, failed, skipped
|
||||
final_statuses = ["success", "failed", "skipped"]
|
||||
limit_rows = max(500, self.per_format_limit * 10)
|
||||
rows = (
|
||||
db.query(
|
||||
RequestCandidate,
|
||||
ProviderEndpoint.api_format,
|
||||
)
|
||||
.join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
|
||||
.filter(
|
||||
RequestCandidate.created_at >= since,
|
||||
RequestCandidate.status.in_(final_statuses),
|
||||
)
|
||||
.order_by(RequestCandidate.created_at.desc())
|
||||
.limit(limit_rows)
|
||||
.all()
|
||||
)
|
||||
|
||||
grouped_candidates: Dict[str, List[RequestCandidate]] = {}
|
||||
|
||||
for candidate, api_format_enum in rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
if api_format not in grouped_candidates:
|
||||
grouped_candidates[api_format] = []
|
||||
|
||||
if len(grouped_candidates[api_format]) < self.per_format_limit:
|
||||
grouped_candidates[api_format].append(candidate)
|
||||
|
||||
# 3. 为所有活跃格式生成监控数据
|
||||
monitors: List[PublicApiFormatHealthMonitor] = []
|
||||
for api_format in all_formats:
|
||||
candidates = grouped_candidates.get(api_format, [])
|
||||
|
||||
# 统计
|
||||
success_count = sum(1 for c in candidates if c.status == "success")
|
||||
failed_count = sum(1 for c in candidates if c.status == "failed")
|
||||
skipped_count = sum(1 for c in candidates if c.status == "skipped")
|
||||
total_attempts = len(candidates)
|
||||
|
||||
# 计算成功率 = success / (success + failed)
|
||||
actual_completed = success_count + failed_count
|
||||
success_rate = success_count / actual_completed if actual_completed > 0 else 1.0
|
||||
|
||||
# 转换为公开版事件列表(不含敏感信息如 provider_id, key_id)
|
||||
events: List[PublicHealthEvent] = []
|
||||
for c in candidates:
|
||||
event_time = c.finished_at or c.started_at or c.created_at
|
||||
events.append(
|
||||
PublicHealthEvent(
|
||||
timestamp=event_time,
|
||||
status=c.status,
|
||||
status_code=c.status_code,
|
||||
latency_ms=c.latency_ms,
|
||||
error_type=c.error_type,
|
||||
)
|
||||
)
|
||||
|
||||
# 最后事件时间
|
||||
last_event_at = None
|
||||
if candidates:
|
||||
last_event_at = (
|
||||
candidates[0].finished_at
|
||||
or candidates[0].started_at
|
||||
or candidates[0].created_at
|
||||
)
|
||||
|
||||
timeline_data = EndpointHealthService._generate_timeline_from_usage(
|
||||
db=db,
|
||||
endpoint_ids=endpoint_map.get(api_format, []),
|
||||
now=now,
|
||||
lookback_hours=self.lookback_hours,
|
||||
)
|
||||
|
||||
# 获取本站入口路径
|
||||
from src.core.api_format_metadata import get_local_path
|
||||
from src.core.enums import APIFormat
|
||||
|
||||
try:
|
||||
api_format_enum = APIFormat(api_format)
|
||||
local_path = get_local_path(api_format_enum)
|
||||
except ValueError:
|
||||
local_path = "/"
|
||||
|
||||
monitors.append(
|
||||
PublicApiFormatHealthMonitor(
|
||||
api_format=api_format,
|
||||
api_path=local_path,
|
||||
total_attempts=total_attempts,
|
||||
success_count=success_count,
|
||||
failed_count=failed_count,
|
||||
skipped_count=skipped_count,
|
||||
success_rate=success_rate,
|
||||
last_event_at=last_event_at,
|
||||
events=events,
|
||||
timeline=timeline_data.get("timeline", []),
|
||||
time_range_start=timeline_data.get("time_range_start"),
|
||||
time_range_end=timeline_data.get("time_range_end"),
|
||||
)
|
||||
)
|
||||
|
||||
response = PublicApiFormatHealthMonitorResponse(
|
||||
generated_at=now,
|
||||
formats=monitors,
|
||||
)
|
||||
|
||||
logger.debug(f"公开健康监控: 返回 {len(monitors)} 个 API 格式的健康数据")
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicGlobalModelsAdapter(PublicApiAdapter):
|
||||
"""公开的 GlobalModel 列表适配器"""
|
||||
|
||||
skip: int
|
||||
limit: int
|
||||
is_active: Optional[bool]
|
||||
search: Optional[str]
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求 GlobalModel 列表")
|
||||
|
||||
query = db.query(GlobalModel)
|
||||
|
||||
# 默认只返回活跃的模型
|
||||
if self.is_active is not None:
|
||||
query = query.filter(GlobalModel.is_active == self.is_active)
|
||||
else:
|
||||
query = query.filter(GlobalModel.is_active.is_(True))
|
||||
|
||||
# 搜索过滤
|
||||
if self.search:
|
||||
search_term = f"%{self.search}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
GlobalModel.name.ilike(search_term),
|
||||
GlobalModel.display_name.ilike(search_term),
|
||||
GlobalModel.description.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
# 统计总数
|
||||
total = query.count()
|
||||
|
||||
# 分页
|
||||
models = query.order_by(GlobalModel.name).offset(self.skip).limit(self.limit).all()
|
||||
|
||||
# 转换为响应格式
|
||||
model_responses = []
|
||||
for gm in models:
|
||||
model_responses.append(
|
||||
PublicGlobalModelResponse(
|
||||
id=gm.id,
|
||||
name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
icon_url=gm.icon_url,
|
||||
is_active=gm.is_active,
|
||||
default_price_per_request=gm.default_price_per_request,
|
||||
default_tiered_pricing=gm.default_tiered_pricing,
|
||||
default_supports_vision=gm.default_supports_vision or False,
|
||||
default_supports_function_calling=gm.default_supports_function_calling or False,
|
||||
default_supports_streaming=(
|
||||
gm.default_supports_streaming
|
||||
if gm.default_supports_streaming is not None
|
||||
else True
|
||||
),
|
||||
default_supports_extended_thinking=gm.default_supports_extended_thinking
|
||||
or False,
|
||||
supported_capabilities=gm.supported_capabilities,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"返回 {len(model_responses)} 个 GlobalModel")
|
||||
return PublicGlobalModelListResponse(models=model_responses, total=total)
|
||||
52
src/api/public/claude.py
Normal file
52
src/api/public/claude.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Claude API 端点
|
||||
|
||||
- /v1/messages - Claude Messages API
|
||||
- /v1/messages/count_tokens - Token Count API
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.api.handlers.claude import (
|
||||
ClaudeTokenCountAdapter,
|
||||
build_claude_adapter,
|
||||
)
|
||||
from src.core.api_format_metadata import get_api_format_definition
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
|
||||
_claude_def = get_api_format_definition(APIFormat.CLAUDE)
|
||||
router = APIRouter(tags=["Claude API"], prefix=_claude_def.path_prefix)
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.post("/v1/messages")
|
||||
async def create_message(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""统一入口:根据 x-app 自动在标准/Claude Code 之间切换。"""
|
||||
adapter = build_claude_adapter(http_request.headers.get("x-app", ""))
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/messages/count_tokens")
|
||||
async def count_tokens(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = ClaudeTokenCountAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
)
|
||||
130
src/api/public/gemini.py
Normal file
130
src/api/public/gemini.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Gemini API 专属端点
|
||||
|
||||
托管 Gemini API 相关路由:
|
||||
- /v1beta/models/{model}:generateContent
|
||||
- /v1beta/models/{model}:streamGenerateContent
|
||||
|
||||
注意: Gemini API 的 model 在 URL 路径中,而不是请求体中
|
||||
|
||||
路径配置来源: src.core.api_format_metadata.APIFormat.GEMINI
|
||||
- path_prefix: 本站路径前缀(如 /gemini),通过 router prefix 配置
|
||||
- default_path: 标准 API 路径模板
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.api.handlers.gemini import build_gemini_adapter
|
||||
from src.api.handlers.gemini_cli import build_gemini_cli_adapter
|
||||
from src.core.api_format_metadata import get_api_format_definition
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
|
||||
# 从配置获取路径前缀
|
||||
_gemini_def = get_api_format_definition(APIFormat.GEMINI)
|
||||
|
||||
router = APIRouter(tags=["Gemini API"], prefix=_gemini_def.path_prefix)
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
def _is_cli_request(request: Request) -> bool:
|
||||
"""
|
||||
判断是否为 CLI 请求
|
||||
|
||||
检查顺序:
|
||||
1. x-app header 包含 "cli"
|
||||
2. user-agent 包含 "GeminiCLI" 或 "gemini-cli"
|
||||
"""
|
||||
# 检查 x-app header
|
||||
x_app = request.headers.get("x-app", "")
|
||||
if "cli" in x_app.lower():
|
||||
return True
|
||||
|
||||
# 检查 user-agent
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
user_agent_lower = user_agent.lower()
|
||||
if "geminicli" in user_agent_lower or "gemini-cli" in user_agent_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/v1beta/models/{model}:generateContent")
|
||||
async def generate_content(
|
||||
model: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Gemini generateContent 端点
|
||||
|
||||
非流式生成内容请求
|
||||
"""
|
||||
# 根据 user-agent 或 x-app header 选择适配器
|
||||
if _is_cli_request(http_request):
|
||||
adapter = build_gemini_cli_adapter()
|
||||
else:
|
||||
adapter = build_gemini_adapter()
|
||||
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
# 将 model 注入到请求体中,stream 用于内部判断流式模式
|
||||
path_params={"model": model, "stream": False},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1beta/models/{model}:streamGenerateContent")
|
||||
async def stream_generate_content(
|
||||
model: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Gemini streamGenerateContent 端点
|
||||
|
||||
流式生成内容请求
|
||||
|
||||
注意: Gemini API 通过 URL 端点区分流式/非流式,不需要在请求体中添加 stream 字段
|
||||
"""
|
||||
# 根据 user-agent 或 x-app header 选择适配器
|
||||
if _is_cli_request(http_request):
|
||||
adapter = build_gemini_cli_adapter()
|
||||
else:
|
||||
adapter = build_gemini_adapter()
|
||||
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
# model 注入到请求体,stream 用于内部判断流式模式(不发送到 API)
|
||||
path_params={"model": model, "stream": True},
|
||||
)
|
||||
|
||||
|
||||
# 兼容 v1 路径(部分 SDK 可能使用)
|
||||
@router.post("/v1/models/{model}:generateContent")
|
||||
async def generate_content_v1(
|
||||
model: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""v1 兼容端点"""
|
||||
return await generate_content(model, http_request, db)
|
||||
|
||||
|
||||
@router.post("/v1/models/{model}:streamGenerateContent")
|
||||
async def stream_generate_content_v1(
|
||||
model: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""v1 兼容端点"""
|
||||
return await stream_generate_content(model, http_request, db)
|
||||
50
src/api/public/openai.py
Normal file
50
src/api/public/openai.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
OpenAI API 端点
|
||||
|
||||
- /v1/chat/completions - OpenAI Chat API
|
||||
- /v1/responses - OpenAI Responses API (CLI)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.api.handlers.openai import OpenAIChatAdapter
|
||||
from src.api.handlers.openai_cli import OpenAICliAdapter
|
||||
from src.core.api_format_metadata import get_api_format_definition
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
|
||||
_openai_def = get_api_format_definition(APIFormat.OPENAI)
|
||||
router = APIRouter(tags=["OpenAI API"], prefix=_openai_def.path_prefix)
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def create_chat_completion(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = OpenAIChatAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/responses")
|
||||
async def create_responses(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = OpenAICliAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
)
|
||||
306
src/api/public/system_catalog.py
Normal file
306
src/api/public/system_catalog.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
System Catalog / 健康检查相关端点
|
||||
|
||||
这些是系统工具端点,不需要复杂的 Adapter 抽象。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||
from src.clients.redis_client import get_redis_client, get_redis_client_sync
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.database.database import get_pool_status
|
||||
from src.models.database import Model, Provider
|
||||
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
|
||||
from src.services.provider.transport import build_provider_url
|
||||
|
||||
router = APIRouter(tags=["System Catalog"])
|
||||
|
||||
|
||||
# ============== 辅助函数 ==============
|
||||
|
||||
|
||||
def _as_bool(value: Optional[str], default: bool) -> bool:
|
||||
"""将字符串转换为布尔值"""
|
||||
if value is None:
|
||||
return default
|
||||
return value.lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _serialize_provider(
|
||||
provider: Provider,
|
||||
include_models: bool,
|
||||
include_endpoints: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""序列化 Provider 对象"""
|
||||
provider_data: Dict[str, Any] = {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"is_active": provider.is_active,
|
||||
"provider_priority": provider.provider_priority,
|
||||
}
|
||||
|
||||
if include_endpoints:
|
||||
provider_data["endpoints"] = [
|
||||
{
|
||||
"id": endpoint.id,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
"is_active": endpoint.is_active,
|
||||
}
|
||||
for endpoint in provider.endpoints or []
|
||||
]
|
||||
|
||||
if include_models:
|
||||
provider_data["models"] = [
|
||||
{
|
||||
"id": model.id,
|
||||
"name": (
|
||||
model.global_model.name if model.global_model else model.provider_model_name
|
||||
),
|
||||
"display_name": (
|
||||
model.global_model.display_name
|
||||
if model.global_model
|
||||
else model.provider_model_name
|
||||
),
|
||||
"is_active": model.is_active,
|
||||
"supports_streaming": model.supports_streaming,
|
||||
}
|
||||
for model in provider.models or []
|
||||
if model.is_active
|
||||
]
|
||||
|
||||
return provider_data
|
||||
|
||||
|
||||
def _select_provider(db: Session, provider_name: Optional[str]) -> Optional[Provider]:
|
||||
"""选择 Provider(按 provider_priority 优先级选择)"""
|
||||
query = db.query(Provider).filter(Provider.is_active == True)
|
||||
if provider_name:
|
||||
provider = query.filter(Provider.name == provider_name).first()
|
||||
if provider:
|
||||
return provider
|
||||
|
||||
# 按优先级选择(provider_priority 最小的优先)
|
||||
return query.order_by(Provider.provider_priority.asc()).first()
|
||||
|
||||
|
||||
# ============== 端点 ==============
|
||||
|
||||
|
||||
@router.get("/v1/health")
|
||||
async def service_health(db: Session = Depends(get_db)):
|
||||
"""返回服务健康状态与依赖信息"""
|
||||
active_providers = (
|
||||
db.query(func.count(Provider.id)).filter(Provider.is_active == True).scalar() or 0
|
||||
)
|
||||
active_models = db.query(func.count(Model.id)).filter(Model.is_active == True).scalar() or 0
|
||||
|
||||
redis_info: Dict[str, Any] = {"status": "unknown"}
|
||||
try:
|
||||
redis = await get_redis_client()
|
||||
if redis:
|
||||
await redis.ping()
|
||||
redis_info = {"status": "ok"}
|
||||
else:
|
||||
redis_info = {"status": "degraded", "message": "Redis client not initialized"}
|
||||
except Exception as exc:
|
||||
redis_info = {"status": "error", "message": str(exc)}
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"stats": {
|
||||
"active_providers": active_providers,
|
||||
"active_models": active_models,
|
||||
},
|
||||
"dependencies": {
|
||||
"database": {"status": "ok"},
|
||||
"redis": redis_info,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""简单健康检查端点(无需认证)"""
|
||||
try:
|
||||
pool_status = get_pool_status()
|
||||
pool_health = {
|
||||
"checked_out": pool_status["checked_out"],
|
||||
"pool_size": pool_status["pool_size"],
|
||||
"overflow": pool_status["overflow"],
|
||||
"max_capacity": pool_status["max_capacity"],
|
||||
"usage_rate": (
|
||||
f"{(pool_status['checked_out'] / pool_status['max_capacity'] * 100):.1f}%"
|
||||
if pool_status["max_capacity"] > 0
|
||||
else "0.0%"
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
pool_health = {"error": str(e)}
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"database_pool": pool_health,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root(db: Session = Depends(get_db)):
|
||||
"""Root endpoint - 服务信息概览"""
|
||||
# 按优先级选择最高优先级的提供商
|
||||
top_provider = (
|
||||
db.query(Provider)
|
||||
.filter(Provider.is_active == True)
|
||||
.order_by(Provider.provider_priority.asc())
|
||||
.first()
|
||||
)
|
||||
active_providers = db.query(Provider).filter(Provider.is_active == True).count()
|
||||
|
||||
return {
|
||||
"message": "AI Proxy with Modular Architecture v4.0.0",
|
||||
"status": "running",
|
||||
"current_provider": top_provider.name if top_provider else "None",
|
||||
"available_providers": active_providers,
|
||||
"config": {},
|
||||
"endpoints": {
|
||||
"messages": "/v1/messages",
|
||||
"count_tokens": "/v1/messages/count_tokens",
|
||||
"health": "/v1/health",
|
||||
"providers": "/v1/providers",
|
||||
"test_connection": "/v1/test-connection",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/providers")
|
||||
async def list_providers(
|
||||
db: Session = Depends(get_db),
|
||||
include_models: bool = Query(False),
|
||||
include_endpoints: bool = Query(False),
|
||||
active_only: bool = Query(True),
|
||||
):
|
||||
"""列出所有 Provider"""
|
||||
load_options = []
|
||||
if include_models:
|
||||
load_options.append(selectinload(Provider.models).selectinload(Model.global_model))
|
||||
if include_endpoints:
|
||||
load_options.append(selectinload(Provider.endpoints))
|
||||
|
||||
base_query = db.query(Provider)
|
||||
if load_options:
|
||||
base_query = base_query.options(*load_options)
|
||||
if active_only:
|
||||
base_query = base_query.filter(Provider.is_active == True)
|
||||
base_query = base_query.order_by(Provider.provider_priority.asc(), Provider.name.asc())
|
||||
|
||||
providers = base_query.all()
|
||||
return {
|
||||
"providers": [
|
||||
_serialize_provider(provider, include_models, include_endpoints)
|
||||
for provider in providers
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/providers/{provider_identifier}")
|
||||
async def provider_detail(
|
||||
provider_identifier: str,
|
||||
db: Session = Depends(get_db),
|
||||
include_models: bool = Query(False),
|
||||
include_endpoints: bool = Query(False),
|
||||
):
|
||||
"""获取单个 Provider 详情"""
|
||||
load_options = []
|
||||
if include_models:
|
||||
load_options.append(selectinload(Provider.models).selectinload(Model.global_model))
|
||||
if include_endpoints:
|
||||
load_options.append(selectinload(Provider.endpoints))
|
||||
|
||||
base_query = db.query(Provider)
|
||||
if load_options:
|
||||
base_query = base_query.options(*load_options)
|
||||
|
||||
provider = base_query.filter(
|
||||
(Provider.id == provider_identifier) | (Provider.name == provider_identifier)
|
||||
).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
return _serialize_provider(provider, include_models, include_endpoints)
|
||||
|
||||
|
||||
@router.get("/v1/test-connection")
|
||||
@router.get("/test-connection")
|
||||
async def test_connection(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
provider: Optional[str] = Query(None),
|
||||
model: str = Query("claude-3-haiku-20240307"),
|
||||
api_format: Optional[str] = Query(None),
|
||||
):
|
||||
"""测试 Provider 连接"""
|
||||
selected_provider = _select_provider(db, provider)
|
||||
if not selected_provider:
|
||||
raise HTTPException(status_code=503, detail="No active provider available")
|
||||
|
||||
# 构建测试请求体
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Health check"}],
|
||||
"max_tokens": 5,
|
||||
}
|
||||
|
||||
# 确定 API 格式
|
||||
format_value = api_format or "CLAUDE"
|
||||
|
||||
# 创建 FallbackOrchestrator
|
||||
redis_client = get_redis_client_sync()
|
||||
orchestrator = FallbackOrchestrator(db, redis_client)
|
||||
|
||||
# 定义请求函数
|
||||
async def test_request_func(_prov, endpoint, key):
|
||||
request_builder = PassthroughRequestBuilder()
|
||||
provider_payload, provider_headers = request_builder.build(
|
||||
payload, {}, endpoint, key, is_stream=False
|
||||
)
|
||||
|
||||
url = build_provider_url(
|
||||
endpoint,
|
||||
query_params=dict(request.query_params),
|
||||
path_params={"model": model},
|
||||
is_stream=False,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(url, json=provider_payload, headers=provider_headers)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
try:
|
||||
response, actual_provider, *_ = await orchestrator.execute_with_fallback(
|
||||
api_format=format_value,
|
||||
model_name=model,
|
||||
user_api_key=None,
|
||||
request_func=test_request_func,
|
||||
request_id=None,
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"provider": actual_provider,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"response_id": response.get("id", "unknown"),
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error(f"API connectivity test failed: {exc}")
|
||||
raise HTTPException(status_code=503, detail=str(exc))
|
||||
10
src/api/user_me/__init__.py
Normal file
10
src/api/user_me/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Routes for authenticated user self-service APIs."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import router as me_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(me_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
1127
src/api/user_me/routes.py
Normal file
1127
src/api/user_me/routes.py
Normal file
File diff suppressed because it is too large
Load Diff
11
src/clients/__init__.py
Normal file
11
src/clients/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .http_client import HTTPClientPool, close_http_clients, get_http_client
|
||||
from .redis_client import close_redis_client, get_redis_client, get_redis_client_sync
|
||||
|
||||
__all__ = [
|
||||
"HTTPClientPool",
|
||||
"get_http_client",
|
||||
"close_http_clients",
|
||||
"get_redis_client",
|
||||
"get_redis_client_sync",
|
||||
"close_redis_client",
|
||||
]
|
||||
133
src/clients/http_client.py
Normal file
133
src/clients/http_client.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
全局HTTP客户端池管理
|
||||
避免每次请求都创建新的AsyncClient,提高性能
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class HTTPClientPool:
|
||||
"""
|
||||
全局HTTP客户端池单例
|
||||
|
||||
管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接
|
||||
"""
|
||||
|
||||
_instance: Optional["HTTPClientPool"] = None
|
||||
_default_client: Optional[httpx.AsyncClient] = None
|
||||
_clients: Dict[str, httpx.AsyncClient] = {}
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_default_client(cls) -> httpx.AsyncClient:
|
||||
"""
|
||||
获取默认的HTTP客户端
|
||||
|
||||
用于大多数HTTP请求,具有合理的默认配置
|
||||
"""
|
||||
if cls._default_client is None:
|
||||
cls._default_client = httpx.AsyncClient(
|
||||
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
||||
verify=True, # 启用SSL验证
|
||||
timeout=httpx.Timeout(
|
||||
connect=10.0, # 连接超时
|
||||
read=300.0, # 读取超时(5分钟,适合流式响应)
|
||||
write=10.0, # 写入超时
|
||||
pool=5.0, # 连接池超时
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_connections=100, # 最大连接数
|
||||
max_keepalive_connections=20, # 最大保活连接数
|
||||
keepalive_expiry=30.0, # 保活过期时间(秒)
|
||||
),
|
||||
follow_redirects=True, # 跟随重定向
|
||||
)
|
||||
logger.info("全局HTTP客户端池已初始化")
|
||||
return cls._default_client
|
||||
|
||||
@classmethod
|
||||
def get_client(cls, name: str, **kwargs: Any) -> httpx.AsyncClient:
|
||||
"""
|
||||
获取或创建命名的HTTP客户端
|
||||
|
||||
用于需要特定配置的场景(如不同的超时设置、代理等)
|
||||
|
||||
Args:
|
||||
name: 客户端标识符
|
||||
**kwargs: httpx.AsyncClient的配置参数
|
||||
"""
|
||||
if name not in cls._clients:
|
||||
# 合并默认配置和自定义配置
|
||||
config = {
|
||||
"http2": False,
|
||||
"verify": True,
|
||||
"timeout": httpx.Timeout(10.0, read=300.0),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
config.update(kwargs)
|
||||
|
||||
cls._clients[name] = httpx.AsyncClient(**config)
|
||||
logger.debug(f"创建命名HTTP客户端: {name}")
|
||||
|
||||
return cls._clients[name]
|
||||
|
||||
@classmethod
|
||||
async def close_all(cls):
|
||||
"""关闭所有HTTP客户端"""
|
||||
if cls._default_client is not None:
|
||||
await cls._default_client.aclose()
|
||||
cls._default_client = None
|
||||
logger.info("默认HTTP客户端已关闭")
|
||||
|
||||
for name, client in cls._clients.items():
|
||||
await client.aclose()
|
||||
logger.debug(f"命名HTTP客户端已关闭: {name}")
|
||||
|
||||
cls._clients.clear()
|
||||
logger.info("所有HTTP客户端已关闭")
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def get_temp_client(cls, **kwargs: Any):
|
||||
"""
|
||||
获取临时HTTP客户端(上下文管理器)
|
||||
|
||||
用于一次性请求,使用后自动关闭
|
||||
|
||||
用法:
|
||||
async with HTTPClientPool.get_temp_client() as client:
|
||||
response = await client.get('https://example.com')
|
||||
"""
|
||||
config = {
|
||||
"http2": False,
|
||||
"verify": True,
|
||||
"timeout": httpx.Timeout(10.0),
|
||||
}
|
||||
config.update(kwargs)
|
||||
|
||||
client = httpx.AsyncClient(**config)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
# 便捷访问函数
|
||||
def get_http_client() -> httpx.AsyncClient:
|
||||
"""获取默认HTTP客户端的便捷函数"""
|
||||
return HTTPClientPool.get_default_client()
|
||||
|
||||
|
||||
async def close_http_clients():
|
||||
"""关闭所有HTTP客户端的便捷函数"""
|
||||
await HTTPClientPool.close_all()
|
||||
346
src/clients/redis_client.py
Normal file
346
src/clients/redis_client.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
全局Redis客户端管理
|
||||
|
||||
提供统一的Redis客户端访问,确保所有服务使用同一个连接池
|
||||
|
||||
熔断器说明:
|
||||
- 连续失败达到阈值后开启熔断
|
||||
- 熔断期间返回明确的状态而非静默失败
|
||||
- 调用方可以根据状态决定降级策略
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from src.core.logger import logger
|
||||
from redis.asyncio import sentinel as redis_sentinel
|
||||
|
||||
|
||||
class RedisState(Enum):
|
||||
"""Redis 连接状态"""
|
||||
|
||||
NOT_INITIALIZED = "not_initialized" # 未初始化
|
||||
CONNECTED = "connected" # 已连接
|
||||
CIRCUIT_OPEN = "circuit_open" # 熔断中
|
||||
DISCONNECTED = "disconnected" # 断开连接
|
||||
|
||||
|
||||
class RedisClientManager:
|
||||
"""
|
||||
Redis客户端管理器(单例)
|
||||
|
||||
提供 Redis 连接管理、熔断器保护和状态监控。
|
||||
"""
|
||||
|
||||
_instance: Optional["RedisClientManager"] = None
|
||||
_redis: Optional[aioredis.Redis] = None
|
||||
|
||||
def __new__(cls):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# 避免重复初始化
|
||||
if getattr(self, "_initialized", False):
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self._circuit_open_until: Optional[float] = None
|
||||
self._consecutive_failures: int = 0
|
||||
self._circuit_threshold = int(os.getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD", "3"))
|
||||
self._circuit_reset_seconds = int(os.getenv("REDIS_CIRCUIT_BREAKER_RESET_SECONDS", "60"))
|
||||
self._last_error: Optional[str] = None # 记录最后一次错误
|
||||
|
||||
def get_state(self) -> RedisState:
|
||||
"""
|
||||
获取 Redis 连接状态
|
||||
|
||||
Returns:
|
||||
当前连接状态枚举值
|
||||
"""
|
||||
if self._redis is not None:
|
||||
return RedisState.CONNECTED
|
||||
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
||||
return RedisState.CIRCUIT_OPEN
|
||||
if self._last_error:
|
||||
return RedisState.DISCONNECTED
|
||||
return RedisState.NOT_INITIALIZED
|
||||
|
||||
def get_circuit_info(self) -> dict:
|
||||
"""
|
||||
获取熔断器详细信息
|
||||
|
||||
Returns:
|
||||
包含熔断器状态的字典
|
||||
"""
|
||||
state = self.get_state()
|
||||
info = {
|
||||
"state": state.value,
|
||||
"consecutive_failures": self._consecutive_failures,
|
||||
"circuit_threshold": self._circuit_threshold,
|
||||
"last_error": self._last_error,
|
||||
}
|
||||
|
||||
if state == RedisState.CIRCUIT_OPEN and self._circuit_open_until:
|
||||
info["circuit_remaining_seconds"] = max(0, self._circuit_open_until - time.time())
|
||||
|
||||
return info
|
||||
|
||||
def reset_circuit_breaker(self) -> None:
|
||||
"""
|
||||
手动重置熔断器(用于管理后台紧急恢复)
|
||||
"""
|
||||
logger.info("Redis 熔断器手动重置")
|
||||
self._circuit_open_until = None
|
||||
self._consecutive_failures = 0
|
||||
self._last_error = None
|
||||
|
||||
async def initialize(self, require_redis: bool = False) -> Optional[aioredis.Redis]:
|
||||
"""
|
||||
初始化Redis连接
|
||||
|
||||
Args:
|
||||
require_redis: 是否强制要求Redis连接成功,如果为True则连接失败时抛出异常
|
||||
|
||||
Returns:
|
||||
Redis客户端实例,如果连接失败返回None(当require_redis=False时)
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当require_redis=True且连接失败时
|
||||
"""
|
||||
if self._redis is not None:
|
||||
return self._redis
|
||||
|
||||
# 检查熔断状态
|
||||
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
||||
remaining = self._circuit_open_until - time.time()
|
||||
logger.warning(
|
||||
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
|
||||
remaining,
|
||||
self._last_error,
|
||||
)
|
||||
if require_redis:
|
||||
raise RuntimeError(
|
||||
f"Redis 处于熔断状态,剩余 {remaining:.1f} 秒。"
|
||||
f"最后错误: {self._last_error}。"
|
||||
"使用管理 API 重置熔断器或等待自动恢复。"
|
||||
)
|
||||
return None
|
||||
|
||||
# 优先使用 REDIS_URL,如果没有则根据密码构建 URL
|
||||
redis_url = os.getenv("REDIS_URL")
|
||||
redis_max_conn = int(os.getenv("REDIS_MAX_CONNECTIONS", "50"))
|
||||
sentinel_hosts = os.getenv("REDIS_SENTINEL_HOSTS")
|
||||
sentinel_service = os.getenv("REDIS_SENTINEL_SERVICE_NAME", "mymaster")
|
||||
redis_password = os.getenv("REDIS_PASSWORD")
|
||||
|
||||
if not redis_url and not sentinel_hosts:
|
||||
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
|
||||
if redis_password:
|
||||
redis_url = f"redis://:{redis_password}@localhost:6379/0"
|
||||
else:
|
||||
redis_url = "redis://localhost:6379/0"
|
||||
|
||||
try:
|
||||
if sentinel_hosts:
|
||||
sentinel_list = []
|
||||
for host in sentinel_hosts.split(","):
|
||||
host = host.strip()
|
||||
if not host:
|
||||
continue
|
||||
if ":" in host:
|
||||
hostname, port = host.split(":", 1)
|
||||
sentinel_list.append((hostname, int(port)))
|
||||
else:
|
||||
sentinel_list.append((host, 26379))
|
||||
|
||||
sentinel_kwargs = {
|
||||
"password": redis_password,
|
||||
"socket_timeout": 5.0,
|
||||
}
|
||||
sentinel = redis_sentinel.Sentinel(
|
||||
sentinel_list,
|
||||
**sentinel_kwargs,
|
||||
)
|
||||
self._redis = sentinel.master_for(
|
||||
service_name=sentinel_service,
|
||||
max_connections=redis_max_conn,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=5.0,
|
||||
)
|
||||
safe_url = f"sentinel://{sentinel_service}"
|
||||
else:
|
||||
self._redis = await aioredis.from_url(
|
||||
redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
socket_timeout=5.0,
|
||||
socket_connect_timeout=5.0,
|
||||
max_connections=redis_max_conn,
|
||||
)
|
||||
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
|
||||
|
||||
# 测试连接
|
||||
await self._redis.ping()
|
||||
logger.info(f"[OK] 全局Redis客户端初始化成功: {safe_url}")
|
||||
self._consecutive_failures = 0
|
||||
self._circuit_open_until = None
|
||||
return self._redis
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
self._last_error = error_msg
|
||||
logger.error(f"[ERROR] Redis连接失败: {error_msg}")
|
||||
|
||||
self._consecutive_failures += 1
|
||||
if self._consecutive_failures >= self._circuit_threshold:
|
||||
self._circuit_open_until = time.time() + self._circuit_reset_seconds
|
||||
logger.warning(
|
||||
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
|
||||
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
|
||||
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
|
||||
self._consecutive_failures,
|
||||
self._circuit_reset_seconds,
|
||||
)
|
||||
|
||||
if require_redis:
|
||||
# 强制要求Redis时,抛出异常拒绝启动
|
||||
raise RuntimeError(
|
||||
f"Redis连接失败: {error_msg}\n"
|
||||
"缓存亲和性功能需要Redis支持,请确保Redis服务正常运行。\n"
|
||||
"检查事项:\n"
|
||||
"1. Redis服务是否已启动(docker-compose up -d redis)\n"
|
||||
"2. 环境变量 REDIS_URL 或 REDIS_PASSWORD 是否配置正确\n"
|
||||
"3. Redis端口(默认6379)是否可访问"
|
||||
) from e
|
||||
|
||||
logger.warning(
|
||||
"[WARN] Redis 不可用,以下功能将降级运行(仅在单实例环境下安全):\n"
|
||||
" - 缓存亲和性: 禁用(每次请求随机选择 Endpoint)\n"
|
||||
" - 分布式并发控制: 降级为本地计数\n"
|
||||
" - RPM 限流: 降级为本地限流"
|
||||
)
|
||||
self._redis = None
|
||||
return None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭Redis连接"""
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
logger.info("全局Redis客户端已关闭")
|
||||
|
||||
def get_client(self) -> Optional[aioredis.Redis]:
|
||||
"""
|
||||
获取Redis客户端(非异步)
|
||||
|
||||
注意:必须先调用initialize()初始化
|
||||
|
||||
Returns:
|
||||
Redis客户端实例或None
|
||||
"""
|
||||
return self._redis
|
||||
|
||||
|
||||
# 全局单例
|
||||
_redis_manager: Optional[RedisClientManager] = None
|
||||
|
||||
|
||||
async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Redis]:
|
||||
"""
|
||||
获取全局Redis客户端
|
||||
|
||||
Args:
|
||||
require_redis: 是否强制要求Redis连接成功,如果为True则连接失败时抛出异常
|
||||
|
||||
Returns:
|
||||
Redis客户端实例,如果未初始化或连接失败返回None(当require_redis=False时)
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当require_redis=True且连接失败时
|
||||
"""
|
||||
global _redis_manager
|
||||
|
||||
if _redis_manager is None:
|
||||
_redis_manager = RedisClientManager()
|
||||
await _redis_manager.initialize(require_redis=require_redis)
|
||||
|
||||
return _redis_manager.get_client()
|
||||
|
||||
|
||||
def get_redis_client_sync() -> Optional[aioredis.Redis]:
|
||||
"""
|
||||
同步获取Redis客户端(不会初始化)
|
||||
|
||||
Returns:
|
||||
Redis客户端实例或None
|
||||
"""
|
||||
global _redis_manager
|
||||
|
||||
if _redis_manager is None:
|
||||
return None
|
||||
|
||||
return _redis_manager.get_client()
|
||||
|
||||
|
||||
async def close_redis_client() -> None:
|
||||
"""关闭全局Redis客户端"""
|
||||
global _redis_manager
|
||||
|
||||
if _redis_manager:
|
||||
await _redis_manager.close()
|
||||
|
||||
|
||||
def get_redis_state() -> RedisState:
|
||||
"""
|
||||
获取 Redis 连接状态(同步方法)
|
||||
|
||||
Returns:
|
||||
Redis 连接状态枚举
|
||||
"""
|
||||
global _redis_manager
|
||||
|
||||
if _redis_manager is None:
|
||||
return RedisState.NOT_INITIALIZED
|
||||
|
||||
return _redis_manager.get_state()
|
||||
|
||||
|
||||
def get_redis_circuit_info() -> dict:
|
||||
"""
|
||||
获取 Redis 熔断器详细信息(同步方法)
|
||||
|
||||
Returns:
|
||||
熔断器状态字典
|
||||
"""
|
||||
global _redis_manager
|
||||
|
||||
if _redis_manager is None:
|
||||
return {
|
||||
"state": RedisState.NOT_INITIALIZED.value,
|
||||
"consecutive_failures": 0,
|
||||
"circuit_threshold": 3,
|
||||
"last_error": None,
|
||||
}
|
||||
|
||||
return _redis_manager.get_circuit_info()
|
||||
|
||||
|
||||
def reset_redis_circuit_breaker() -> bool:
|
||||
"""
|
||||
手动重置 Redis 熔断器(同步方法)
|
||||
|
||||
Returns:
|
||||
是否成功重置
|
||||
"""
|
||||
global _redis_manager
|
||||
|
||||
if _redis_manager is None:
|
||||
return False
|
||||
|
||||
_redis_manager.reset_circuit_breaker()
|
||||
return True
|
||||
3
src/config/__init__.py
Normal file
3
src/config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .settings import Config, config
|
||||
|
||||
__all__ = ["Config", "config"]
|
||||
235
src/config/constants.py
Normal file
235
src/config/constants.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Constants for better maintainability
|
||||
# ==============================================================================
|
||||
# 缓存相关常量
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# 缓存 TTL(秒)
|
||||
class CacheTTL:
|
||||
"""缓存过期时间配置(秒)"""
|
||||
|
||||
# 用户缓存 - 用户信息变更较频繁
|
||||
USER = 60 # 1分钟
|
||||
|
||||
# Provider/Model 缓存 - 配置变更不频繁
|
||||
PROVIDER = 300 # 5分钟
|
||||
MODEL = 300 # 5分钟
|
||||
MODEL_MAPPING = 300 # 5分钟
|
||||
|
||||
# 缓存亲和性 - 对应 provider_api_key.cache_ttl_minutes 默认值
|
||||
CACHE_AFFINITY = 300 # 5分钟
|
||||
|
||||
# L1 本地缓存(用于减少 Redis 访问)
|
||||
L1_LOCAL = 3 # 3秒
|
||||
|
||||
# 并发锁 TTL - 防止死锁
|
||||
CONCURRENCY_LOCK = 600 # 10分钟
|
||||
|
||||
|
||||
# 缓存容量限制
|
||||
class CacheSize:
|
||||
"""缓存容量配置"""
|
||||
|
||||
# 默认 LRU 缓存大小
|
||||
DEFAULT = 1000
|
||||
|
||||
# ModelMapping 缓存(可能有较多别名)
|
||||
MODEL_MAPPING = 2000
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 并发和限流常量
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ConcurrencyDefaults:
|
||||
"""并发控制默认值"""
|
||||
|
||||
# 自适应并发初始限制(保守值)
|
||||
INITIAL_LIMIT = 3
|
||||
|
||||
# 429错误后的冷却时间(分钟)- 在此期间不会增加并发限制
|
||||
COOLDOWN_AFTER_429_MINUTES = 5
|
||||
|
||||
# 探测间隔上限(分钟)- 用于长期探测策略
|
||||
MAX_PROBE_INTERVAL_MINUTES = 60
|
||||
|
||||
# === 基于滑动窗口的扩容参数 ===
|
||||
# 滑动窗口大小(采样点数量)
|
||||
UTILIZATION_WINDOW_SIZE = 20
|
||||
|
||||
# 滑动窗口时间范围(秒)- 只保留最近这段时间内的采样
|
||||
UTILIZATION_WINDOW_SECONDS = 120 # 2分钟
|
||||
|
||||
# 利用率阈值 - 窗口内平均利用率 >= 此值时考虑扩容
|
||||
UTILIZATION_THRESHOLD = 0.7 # 70%
|
||||
|
||||
# 高利用率采样比例 - 窗口内超过阈值的采样点比例 >= 此值时触发扩容
|
||||
HIGH_UTILIZATION_RATIO = 0.6 # 60% 的采样点高于阈值
|
||||
|
||||
# 最小采样数 - 窗口内至少需要这么多采样才能做出扩容决策
|
||||
MIN_SAMPLES_FOR_DECISION = 5
|
||||
|
||||
# 扩容步长 - 每次扩容增加的并发数
|
||||
INCREASE_STEP = 1
|
||||
|
||||
# 缩容乘数 - 遇到 429 时的缩容比例
|
||||
DECREASE_MULTIPLIER = 0.7
|
||||
|
||||
# 最大并发限制上限
|
||||
MAX_CONCURRENT_LIMIT = 100
|
||||
|
||||
# 最小并发限制下限
|
||||
MIN_CONCURRENT_LIMIT = 1
|
||||
|
||||
# === 探测性扩容参数 ===
|
||||
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
||||
PROBE_INCREASE_INTERVAL_MINUTES = 30
|
||||
|
||||
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
||||
PROBE_INCREASE_MIN_REQUESTS = 10
|
||||
|
||||
|
||||
class CircuitBreakerDefaults:
|
||||
"""熔断器配置默认值(滑动窗口 + 半开状态模式)
|
||||
|
||||
新的熔断器基于滑动窗口错误率,而不是累计健康度。
|
||||
支持半开状态,允许少量请求验证服务是否恢复。
|
||||
"""
|
||||
|
||||
# === 滑动窗口配置 ===
|
||||
# 滑动窗口大小(最近 N 次请求)
|
||||
WINDOW_SIZE = 20
|
||||
|
||||
# 滑动窗口时间范围(秒)- 只保留最近这段时间内的请求记录
|
||||
WINDOW_SECONDS = 300 # 5分钟
|
||||
|
||||
# 最小请求数 - 窗口内至少需要这么多请求才能做出熔断决策
|
||||
MIN_REQUESTS_FOR_DECISION = 5
|
||||
|
||||
# 错误率阈值 - 窗口内错误率超过此值时触发熔断
|
||||
ERROR_RATE_THRESHOLD = 0.5 # 50%
|
||||
|
||||
# === 半开状态配置 ===
|
||||
# 半开状态持续时间(秒)- 在此期间允许少量请求通过
|
||||
HALF_OPEN_DURATION_SECONDS = 30
|
||||
|
||||
# 半开状态成功阈值 - 达到此成功次数则关闭熔断器
|
||||
HALF_OPEN_SUCCESS_THRESHOLD = 3
|
||||
|
||||
# 半开状态失败阈值 - 达到此失败次数则重新打开熔断器
|
||||
HALF_OPEN_FAILURE_THRESHOLD = 2
|
||||
|
||||
# === 熔断恢复配置 ===
|
||||
# 初始探测间隔(秒)- 熔断后多久进入半开状态
|
||||
INITIAL_RECOVERY_SECONDS = 30
|
||||
|
||||
# 探测间隔退避倍数
|
||||
RECOVERY_BACKOFF_MULTIPLIER = 2
|
||||
|
||||
# 最大探测间隔(秒)
|
||||
MAX_RECOVERY_SECONDS = 300 # 5分钟
|
||||
|
||||
# === 旧参数(向后兼容,仍用于展示健康度)===
|
||||
# 成功时健康度增量
|
||||
SUCCESS_INCREMENT = 0.15
|
||||
|
||||
# 失败时健康度减量
|
||||
FAILURE_DECREMENT = 0.03
|
||||
|
||||
# 探测成功后的快速恢复健康度
|
||||
PROBE_RECOVERY_SCORE = 0.5
|
||||
|
||||
|
||||
class AdaptiveReservationDefaults:
|
||||
"""动态预留比例配置默认值
|
||||
|
||||
动态预留机制根据学习置信度和负载自动调整缓存用户预留比例,
|
||||
解决固定 30% 预留在学习初期和负载变化时的不适应问题。
|
||||
"""
|
||||
|
||||
# 探测阶段配置
|
||||
PROBE_PHASE_REQUESTS = 100 # 探测阶段请求数阈值
|
||||
PROBE_RESERVATION = 0.1 # 探测阶段预留比例(10%)
|
||||
|
||||
# 稳定阶段配置
|
||||
STABLE_MIN_RESERVATION = 0.1 # 稳定阶段最小预留(10%)
|
||||
STABLE_MAX_RESERVATION = 0.35 # 稳定阶段最大预留(35%)
|
||||
|
||||
# 置信度计算参数
|
||||
SUCCESS_COUNT_FOR_FULL_CONFIDENCE = 50 # 连续成功多少次达到满置信
|
||||
COOLDOWN_HOURS_FOR_FULL_CONFIDENCE = 24 # 429后多少小时达到满置信
|
||||
|
||||
# 负载阈值
|
||||
LOW_LOAD_THRESHOLD = 0.5 # 低负载阈值(50%)
|
||||
HIGH_LOAD_THRESHOLD = 0.8 # 高负载阈值(80%)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 超时和重试常量
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TimeoutDefaults:
|
||||
"""超时配置默认值(秒)"""
|
||||
|
||||
# HTTP 请求默认超时
|
||||
HTTP_REQUEST = 300 # 5分钟
|
||||
|
||||
# 数据库连接池获取超时
|
||||
DB_POOL = 30
|
||||
|
||||
# Redis 操作超时
|
||||
REDIS_OPERATION = 5
|
||||
|
||||
|
||||
class RetryDefaults:
|
||||
"""重试配置默认值"""
|
||||
|
||||
# 最大重试次数
|
||||
MAX_RETRIES = 3
|
||||
|
||||
# 重试基础延迟(秒)
|
||||
BASE_DELAY = 1.0
|
||||
|
||||
# 重试延迟倍数(指数退避)
|
||||
DELAY_MULTIPLIER = 2.0
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 消息格式常量
|
||||
# ==============================================================================
|
||||
|
||||
# 角色常量
|
||||
ROLE_USER = "user"
|
||||
ROLE_ASSISTANT = "assistant"
|
||||
ROLE_SYSTEM = "system"
|
||||
ROLE_TOOL = "tool"
|
||||
|
||||
# 内容类型常量
|
||||
CONTENT_TEXT = "text"
|
||||
CONTENT_IMAGE = "image"
|
||||
CONTENT_TOOL_USE = "tool_use"
|
||||
CONTENT_TOOL_RESULT = "tool_result"
|
||||
|
||||
# 工具常量
|
||||
TOOL_FUNCTION = "function"
|
||||
|
||||
# 停止原因常量
|
||||
STOP_END_TURN = "end_turn"
|
||||
STOP_MAX_TOKENS = "max_tokens"
|
||||
STOP_TOOL_USE = "tool_use"
|
||||
STOP_ERROR = "error"
|
||||
|
||||
# 事件类型常量
|
||||
EVENT_MESSAGE_START = "message_start"
|
||||
EVENT_MESSAGE_STOP = "message_stop"
|
||||
EVENT_MESSAGE_DELTA = "message_delta"
|
||||
EVENT_CONTENT_BLOCK_START = "content_block_start"
|
||||
EVENT_CONTENT_BLOCK_STOP = "content_block_stop"
|
||||
EVENT_CONTENT_BLOCK_DELTA = "content_block_delta"
|
||||
EVENT_PING = "ping"
|
||||
|
||||
# Delta类型常量
|
||||
DELTA_TEXT = "text_delta"
|
||||
DELTA_INPUT_JSON = "input_json_delta"
|
||||
259
src/config/settings.py
Normal file
259
src/config/settings.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
服务器配置
|
||||
从环境变量或 .env 文件加载配置
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 尝试加载 .env 文件
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
|
||||
env_file = Path(".env")
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file)
|
||||
except ImportError:
|
||||
# 如果没有安装 python-dotenv,仍然可以从环境变量读取
|
||||
pass
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self) -> None:
|
||||
# 服务器配置
|
||||
self.host = os.getenv("HOST", "0.0.0.0")
|
||||
self.port = int(os.getenv("PORT", "8084"))
|
||||
self.log_level = os.getenv("LOG_LEVEL", "INFO")
|
||||
self.worker_processes = int(
|
||||
os.getenv("WEB_CONCURRENCY", os.getenv("GUNICORN_WORKERS", "4"))
|
||||
)
|
||||
|
||||
# PostgreSQL 连接池计算相关配置
|
||||
# PG_MAX_CONNECTIONS: PostgreSQL 的 max_connections 设置(默认 100)
|
||||
# PG_RESERVED_CONNECTIONS: 为其他应用/管理工具预留的连接数(默认 10)
|
||||
self.pg_max_connections = int(os.getenv("PG_MAX_CONNECTIONS", "100"))
|
||||
self.pg_reserved_connections = int(os.getenv("PG_RESERVED_CONNECTIONS", "10"))
|
||||
|
||||
# 数据库配置 - 延迟验证,支持测试环境覆盖
|
||||
self._database_url = os.getenv("DATABASE_URL")
|
||||
|
||||
# JWT配置
|
||||
self.jwt_secret_key = os.getenv("JWT_SECRET_KEY", None)
|
||||
self.jwt_algorithm = os.getenv("JWT_ALGORITHM", "HS256")
|
||||
self.jwt_expiration_hours = int(os.getenv("JWT_EXPIRATION_HOURS", "24"))
|
||||
|
||||
# 加密密钥配置(独立于JWT密钥,用于敏感数据加密)
|
||||
self.encryption_key = os.getenv("ENCRYPTION_KEY", None)
|
||||
|
||||
# 环境配置 - 智能检测
|
||||
# Docker 部署默认为生产环境,本地开发默认为开发环境
|
||||
is_docker = (
|
||||
os.path.exists("/.dockerenv")
|
||||
or os.environ.get("DOCKER_CONTAINER", "false").lower() == "true"
|
||||
)
|
||||
default_env = "production" if is_docker else "development"
|
||||
self.environment = os.getenv("ENVIRONMENT", default_env)
|
||||
|
||||
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
|
||||
redis_required_env = os.getenv("REDIS_REQUIRED")
|
||||
if redis_required_env is None:
|
||||
self.require_redis = self.environment not in {"development", "test", "testing"}
|
||||
else:
|
||||
self.require_redis = redis_required_env.lower() == "true"
|
||||
|
||||
# CORS配置 - 使用环境变量配置允许的源
|
||||
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
|
||||
cors_origins = os.getenv("CORS_ORIGINS", "")
|
||||
if cors_origins:
|
||||
self.cors_origins = [
|
||||
origin.strip() for origin in cors_origins.split(",") if origin.strip()
|
||||
]
|
||||
else:
|
||||
# 默认: 开发环境允许本地前端,生产环境不允许任何跨域
|
||||
if self.environment == "development":
|
||||
self.cors_origins = [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173", # Vite 默认端口
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:5173",
|
||||
]
|
||||
else:
|
||||
# 生产环境默认不允许跨域,必须显式配置
|
||||
self.cors_origins = []
|
||||
|
||||
# CORS是否允许凭证(Cookie/Authorization header)
|
||||
# 注意: allow_credentials=True 时不能使用 allow_origins=["*"]
|
||||
self.cors_allow_credentials = os.getenv("CORS_ALLOW_CREDENTIALS", "true").lower() == "true"
|
||||
|
||||
# 管理员账户配置(用于初始化)
|
||||
self.admin_email = os.getenv("ADMIN_EMAIL", "admin@localhost")
|
||||
self.admin_username = os.getenv("ADMIN_USERNAME", "admin")
|
||||
|
||||
# 管理员密码 - 必须在环境变量中设置
|
||||
admin_password_env = os.getenv("ADMIN_PASSWORD")
|
||||
if admin_password_env:
|
||||
self.admin_password = admin_password_env
|
||||
else:
|
||||
# 未设置密码,启动时会报错
|
||||
self.admin_password = ""
|
||||
self._missing_admin_password = True
|
||||
|
||||
# API Key 配置
|
||||
self.api_key_prefix = os.getenv("API_KEY_PREFIX", "sk")
|
||||
|
||||
# LLM API 速率限制配置(每分钟请求数)
|
||||
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
||||
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
||||
|
||||
# 异常处理配置
|
||||
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
||||
# 设置为 False 时,使用全局异常处理器统一处理
|
||||
self.propagate_provider_exceptions = os.getenv(
|
||||
"PROPAGATE_PROVIDER_EXCEPTIONS", "true"
|
||||
).lower() == "true"
|
||||
|
||||
# 数据库连接池配置 - 智能自动调整
|
||||
# 系统会根据 Worker 数量和 PostgreSQL 限制自动计算安全值
|
||||
self.db_pool_size = int(os.getenv("DB_POOL_SIZE") or self._auto_pool_size())
|
||||
self.db_max_overflow = int(os.getenv("DB_MAX_OVERFLOW") or self._auto_max_overflow())
|
||||
self.db_pool_timeout = int(os.getenv("DB_POOL_TIMEOUT", "60"))
|
||||
self.db_pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600"))
|
||||
self.db_pool_warn_threshold = int(os.getenv("DB_POOL_WARN_THRESHOLD", "70"))
|
||||
|
||||
# 验证连接池配置
|
||||
self._validate_pool_config()
|
||||
|
||||
def _auto_pool_size(self) -> int:
|
||||
"""
|
||||
智能计算连接池大小 - 根据 Worker 数量和 PostgreSQL 限制计算
|
||||
|
||||
公式: (pg_max_connections - reserved) / workers / 2
|
||||
除以 2 是因为还要预留 max_overflow 的空间
|
||||
"""
|
||||
available_connections = self.pg_max_connections - self.pg_reserved_connections
|
||||
# 每个 Worker 可用的连接数(pool_size + max_overflow)
|
||||
per_worker_total = available_connections // max(self.worker_processes, 1)
|
||||
# pool_size 取总数的一半,另一半留给 overflow
|
||||
pool_size = max(per_worker_total // 2, 5) # 最小 5 个连接
|
||||
return min(pool_size, 30) # 最大 30 个连接
|
||||
|
||||
def _auto_max_overflow(self) -> int:
|
||||
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
||||
return self.db_pool_size
|
||||
|
||||
def _validate_pool_config(self) -> None:
|
||||
"""验证连接池配置是否安全"""
|
||||
total_per_worker = self.db_pool_size + self.db_max_overflow
|
||||
total_all_workers = total_per_worker * self.worker_processes
|
||||
safe_limit = self.pg_max_connections - self.pg_reserved_connections
|
||||
|
||||
if total_all_workers > safe_limit:
|
||||
# 记录警告(不抛出异常,避免阻止启动)
|
||||
self._pool_config_warning = (
|
||||
f"[WARN] 数据库连接池配置可能超过 PostgreSQL 限制: "
|
||||
f"{self.worker_processes} workers x {total_per_worker} connections = "
|
||||
f"{total_all_workers} > {safe_limit} (pg_max_connections - reserved). "
|
||||
f"建议调整 DB_POOL_SIZE 或 PG_MAX_CONNECTIONS 环境变量。"
|
||||
)
|
||||
else:
|
||||
self._pool_config_warning = None
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""
|
||||
数据库 URL(延迟验证)
|
||||
|
||||
在测试环境中可以通过依赖注入覆盖,而不会在导入时崩溃
|
||||
"""
|
||||
if not self._database_url:
|
||||
raise ValueError(
|
||||
"DATABASE_URL environment variable is required. "
|
||||
"Example: postgresql://username:password@localhost:5432/dbname"
|
||||
)
|
||||
return self._database_url
|
||||
|
||||
@database_url.setter
|
||||
def database_url(self, value: str):
|
||||
"""允许在测试中设置数据库 URL"""
|
||||
self._database_url = value
|
||||
|
||||
def log_startup_warnings(self) -> None:
|
||||
"""
|
||||
记录启动时的安全警告
|
||||
这个方法应该在 logger 初始化后调用
|
||||
"""
|
||||
from src.core.logger import logger
|
||||
|
||||
# 连接池配置警告
|
||||
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
|
||||
logger.warning(self._pool_config_warning)
|
||||
|
||||
# 管理员密码检查(必须在环境变量中设置)
|
||||
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
|
||||
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")
|
||||
raise ValueError("ADMIN_PASSWORD environment variable must be set!")
|
||||
|
||||
# JWT 密钥警告
|
||||
if not self.jwt_secret_key:
|
||||
if self.environment == "production":
|
||||
logger.error(
|
||||
"生产环境未设置 JWT_SECRET_KEY! 这是严重的安全漏洞。"
|
||||
"使用 'python generate_keys.py' 生成安全密钥。"
|
||||
)
|
||||
else:
|
||||
logger.warning("JWT_SECRET_KEY 未设置,将使用默认密钥(仅限开发环境)")
|
||||
|
||||
# 加密密钥警告
|
||||
if not self.encryption_key and self.environment != "production":
|
||||
logger.warning(
|
||||
"ENCRYPTION_KEY 未设置,使用开发环境默认密钥。生产环境必须设置。"
|
||||
)
|
||||
|
||||
# CORS 配置警告(生产环境)
|
||||
if self.environment == "production" and not self.cors_origins:
|
||||
logger.warning("生产环境 CORS 未配置,前端将无法访问 API。请设置 CORS_ORIGINS。")
|
||||
|
||||
def validate_security_config(self) -> list[str]:
|
||||
"""
|
||||
验证安全配置,返回错误列表
|
||||
生产环境会阻止启动,开发环境仅警告
|
||||
|
||||
Returns:
|
||||
错误消息列表(空列表表示验证通过)
|
||||
"""
|
||||
errors: list[str] = []
|
||||
|
||||
if self.environment == "production":
|
||||
# 生产环境必须设置 JWT 密钥
|
||||
if not self.jwt_secret_key:
|
||||
errors.append(
|
||||
"JWT_SECRET_KEY must be set in production. "
|
||||
"Use 'python generate_keys.py' to generate a secure key."
|
||||
)
|
||||
elif len(self.jwt_secret_key) < 32:
|
||||
errors.append("JWT_SECRET_KEY must be at least 32 characters in production.")
|
||||
|
||||
# 生产环境必须设置加密密钥
|
||||
if not self.encryption_key:
|
||||
errors.append(
|
||||
"ENCRYPTION_KEY must be set in production. "
|
||||
"Use 'python generate_keys.py' to generate a secure key."
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
def __repr__(self):
|
||||
"""配置信息字符串表示"""
|
||||
return f"""
|
||||
Configuration:
|
||||
Server: {self.host}:{self.port}
|
||||
Log Level: {self.log_level}
|
||||
Environment: {self.environment}
|
||||
"""
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
config = Config()
|
||||
|
||||
# 在调试模式下记录配置(延迟到日志系统初始化后)
|
||||
# 这个配置信息会在应用启动时通过日志系统输出
|
||||
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
272
src/core/api_format_metadata.py
Normal file
272
src/core/api_format_metadata.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
集中维护 API 格式的元数据,避免新增格式时到处修改常量。
|
||||
|
||||
此模块与 src/formats/ 的 FormatProtocol 系统配合使用:
|
||||
- api_format_metadata: 定义格式的元数据(别名、默认路径)
|
||||
- src/formats/: 定义格式的协议实现(解析、转换、验证)
|
||||
|
||||
使用方式:
|
||||
# 解析格式别名
|
||||
from src.core.api_format_metadata import resolve_api_format
|
||||
api_format = resolve_api_format("claude") # -> APIFormat.CLAUDE
|
||||
|
||||
# 获取格式协议
|
||||
from src.core.api_format_metadata import get_format_protocol
|
||||
protocol = get_format_protocol(APIFormat.CLAUDE) # -> ClaudeProtocol
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from types import MappingProxyType
|
||||
from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Union
|
||||
|
||||
from .enums import APIFormat
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ApiFormatDefinition:
|
||||
"""
|
||||
描述一个 API 格式的所有通用信息。
|
||||
|
||||
- aliases: 用于 detect_api_format 的 provider 别名或快捷名称
|
||||
- default_path: 上游默认请求路径(如 /v1/messages),可通过 Endpoint.custom_path 覆盖
|
||||
- path_prefix: 本站路径前缀(如 /claude, /openai),为空表示无前缀
|
||||
- auth_header: 认证头名称 (如 "x-api-key", "x-goog-api-key")
|
||||
- auth_type: 认证类型 ("header" 直接放值, "bearer" 加 Bearer 前缀)
|
||||
"""
|
||||
|
||||
api_format: APIFormat
|
||||
aliases: Sequence[str] = field(default_factory=tuple)
|
||||
default_path: str = "/" # 上游默认请求路径
|
||||
path_prefix: str = "" # 本站路径前缀,为空表示无前缀
|
||||
auth_header: str = "Authorization"
|
||||
auth_type: str = "bearer" # "bearer" or "header"
|
||||
|
||||
def iter_aliases(self) -> Iterable[str]:
|
||||
"""返回大小写统一后的别名集合,包含枚举名本身。"""
|
||||
yield normalize_alias_value(self.api_format.value)
|
||||
for alias in self.aliases:
|
||||
normalized = normalize_alias_value(alias)
|
||||
if normalized:
|
||||
yield normalized
|
||||
|
||||
|
||||
_DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
||||
APIFormat.CLAUDE: ApiFormatDefinition(
|
||||
api_format=APIFormat.CLAUDE,
|
||||
aliases=("claude", "anthropic", "claude_compatible"),
|
||||
default_path="/v1/messages",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/claude"
|
||||
auth_header="x-api-key",
|
||||
auth_type="header",
|
||||
),
|
||||
APIFormat.CLAUDE_CLI: ApiFormatDefinition(
|
||||
api_format=APIFormat.CLAUDE_CLI,
|
||||
aliases=("claude_cli", "claude-cli"),
|
||||
default_path="/v1/messages",
|
||||
path_prefix="", # 与 CLAUDE 共享入口,通过 header 区分
|
||||
auth_header="authorization",
|
||||
auth_type="bearer",
|
||||
),
|
||||
APIFormat.OPENAI: ApiFormatDefinition(
|
||||
api_format=APIFormat.OPENAI,
|
||||
aliases=(
|
||||
"openai",
|
||||
"deepseek",
|
||||
"grok",
|
||||
"moonshot",
|
||||
"zhipu",
|
||||
"qwen",
|
||||
"baichuan",
|
||||
"minimax",
|
||||
"openai_compatible",
|
||||
),
|
||||
default_path="/v1/chat/completions",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/openai"
|
||||
auth_header="Authorization",
|
||||
auth_type="bearer",
|
||||
),
|
||||
APIFormat.OPENAI_CLI: ApiFormatDefinition(
|
||||
api_format=APIFormat.OPENAI_CLI,
|
||||
aliases=("openai_cli", "responses"),
|
||||
default_path="/responses",
|
||||
path_prefix="",
|
||||
auth_header="Authorization",
|
||||
auth_type="bearer",
|
||||
),
|
||||
APIFormat.GEMINI: ApiFormatDefinition(
|
||||
api_format=APIFormat.GEMINI,
|
||||
aliases=("gemini", "google", "vertex"),
|
||||
default_path="/v1beta/models/{model}:{action}",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/gemini"
|
||||
auth_header="x-goog-api-key",
|
||||
auth_type="header",
|
||||
),
|
||||
APIFormat.GEMINI_CLI: ApiFormatDefinition(
|
||||
api_format=APIFormat.GEMINI_CLI,
|
||||
aliases=("gemini_cli", "gemini-cli"),
|
||||
default_path="/v1beta/models/{model}:{action}",
|
||||
path_prefix="", # 与 GEMINI 共享入口
|
||||
auth_header="x-goog-api-key",
|
||||
auth_type="header",
|
||||
),
|
||||
}
|
||||
|
||||
# 对外只暴露只读视图,避免被随意修改
|
||||
API_FORMAT_DEFINITIONS: Mapping[APIFormat, ApiFormatDefinition] = MappingProxyType(_DEFINITIONS)
|
||||
|
||||
|
||||
def get_api_format_definition(api_format: APIFormat) -> ApiFormatDefinition:
|
||||
"""获取指定格式的定义,不存在时抛出 KeyError。"""
|
||||
return API_FORMAT_DEFINITIONS[api_format]
|
||||
|
||||
|
||||
def list_api_format_definitions() -> List[ApiFormatDefinition]:
|
||||
"""返回所有定义的浅拷贝列表,供遍历使用。"""
|
||||
return list(API_FORMAT_DEFINITIONS.values())
|
||||
|
||||
|
||||
def build_alias_lookup() -> Dict[str, APIFormat]:
|
||||
"""
|
||||
构建 alias -> APIFormat 的查找表。
|
||||
每次调用都会返回新的 dict,避免可变全局引发并发问题。
|
||||
"""
|
||||
lookup: MutableMapping[str, APIFormat] = {}
|
||||
for definition in API_FORMAT_DEFINITIONS.values():
|
||||
for alias in definition.iter_aliases():
|
||||
lookup.setdefault(alias, definition.api_format)
|
||||
return dict(lookup)
|
||||
|
||||
|
||||
def get_default_path(api_format: APIFormat) -> str:
|
||||
"""
|
||||
获取该格式的上游默认请求路径。
|
||||
|
||||
可通过 Endpoint.custom_path 覆盖。
|
||||
"""
|
||||
definition = API_FORMAT_DEFINITIONS.get(api_format)
|
||||
return definition.default_path if definition else "/"
|
||||
|
||||
|
||||
def get_local_path(api_format: APIFormat) -> str:
|
||||
"""
|
||||
获取该格式的本站入口路径。
|
||||
|
||||
本站入口路径 = path_prefix + default_path
|
||||
例如:path_prefix="/openai" + default_path="/v1/chat/completions" -> "/openai/v1/chat/completions"
|
||||
"""
|
||||
definition = API_FORMAT_DEFINITIONS.get(api_format)
|
||||
if definition:
|
||||
prefix = definition.path_prefix or ""
|
||||
return prefix + definition.default_path
|
||||
return "/"
|
||||
|
||||
|
||||
def get_auth_config(api_format: APIFormat) -> tuple[str, str]:
|
||||
"""
|
||||
获取该格式的认证配置。
|
||||
|
||||
Returns:
|
||||
(auth_header, auth_type) 元组
|
||||
- auth_header: 认证头名称
|
||||
- auth_type: "bearer" 或 "header"
|
||||
"""
|
||||
definition = API_FORMAT_DEFINITIONS.get(api_format)
|
||||
if definition:
|
||||
return definition.auth_header, definition.auth_type
|
||||
return "Authorization", "bearer"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _alias_lookup_cache() -> Dict[str, APIFormat]:
|
||||
"""缓存 alias -> APIFormat 查找表,减少重复构建。"""
|
||||
return build_alias_lookup()
|
||||
|
||||
|
||||
def resolve_api_format_alias(value: str) -> Optional[APIFormat]:
|
||||
"""根据别名查找 APIFormat,找不到时返回 None。"""
|
||||
if not value:
|
||||
return None
|
||||
normalized = normalize_alias_value(value)
|
||||
if not normalized:
|
||||
return None
|
||||
return _alias_lookup_cache().get(normalized)
|
||||
|
||||
|
||||
def resolve_api_format(
|
||||
value: Union[str, APIFormat, None],
|
||||
default: Optional[APIFormat] = None,
|
||||
) -> Optional[APIFormat]:
|
||||
"""
|
||||
将任意字符串/枚举值解析为 APIFormat。
|
||||
|
||||
Args:
|
||||
value: 可以是 APIFormat 或任意字符串/别名
|
||||
default: 未解析成功时返回的默认值
|
||||
"""
|
||||
if isinstance(value, APIFormat):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return default
|
||||
upper = stripped.upper()
|
||||
if upper in APIFormat.__members__:
|
||||
return APIFormat[upper]
|
||||
alias = resolve_api_format_alias(stripped)
|
||||
if alias:
|
||||
return alias
|
||||
return default
|
||||
|
||||
|
||||
def register_api_format_definition(definition: ApiFormatDefinition, *, override: bool = False):
|
||||
"""
|
||||
注册或覆盖 API 格式定义,允许运行时扩展。
|
||||
|
||||
Args:
|
||||
definition: 要注册的定义
|
||||
override: 若目标枚举已存在,是否允许覆盖
|
||||
"""
|
||||
existing = _DEFINITIONS.get(definition.api_format)
|
||||
if existing and not override:
|
||||
raise ValueError(f"{definition.api_format.value} 已存在,如需覆盖请设置 override=True")
|
||||
_DEFINITIONS[definition.api_format] = definition
|
||||
_refresh_metadata_cache()
|
||||
|
||||
|
||||
def _refresh_metadata_cache():
|
||||
"""更新别名缓存,供注册函数调用。"""
|
||||
_alias_lookup_cache.cache_clear()
|
||||
|
||||
|
||||
def normalize_alias_value(value: str) -> str:
|
||||
"""统一别名格式:去空白、转小写,并将非字母数字转为单个下划线。"""
|
||||
if value is None:
|
||||
return ""
|
||||
text = value.strip().lower()
|
||||
# 将所有非字母数字字符替换为下划线,并折叠连续的下划线
|
||||
text = re.sub(r"[^a-z0-9]+", "_", text)
|
||||
return text.strip("_")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 格式判断工具
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def is_cli_api_format(api_format: APIFormat) -> bool:
|
||||
"""
|
||||
判断是否为 CLI 透传格式。
|
||||
|
||||
Args:
|
||||
api_format: APIFormat 枚举值
|
||||
|
||||
Returns:
|
||||
True 如果是 CLI 格式
|
||||
"""
|
||||
from src.api.handlers.base.parsers import is_cli_format
|
||||
|
||||
return is_cli_format(api_format.value)
|
||||
115
src/core/batch_committer.py
Normal file
115
src/core/batch_committer.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
批量提交器 - 减少数据库 commit 次数,提升并发能力
|
||||
|
||||
核心思想:
|
||||
- 非关键数据(监控、统计)不立即 commit
|
||||
- 在后台定期批量 commit
|
||||
- 关键数据(计费)仍然立即 commit
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Set
|
||||
|
||||
from src.core.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class BatchCommitter:
|
||||
"""批量提交管理器"""
|
||||
|
||||
def __init__(self, interval_seconds: float = 1.0):
|
||||
"""
|
||||
Args:
|
||||
interval_seconds: 批量提交间隔(秒)
|
||||
"""
|
||||
self.interval_seconds = interval_seconds
|
||||
self._pending_sessions: Set[Session] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._task = None
|
||||
|
||||
async def start(self):
|
||||
"""启动后台批量提交任务"""
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._batch_commit_loop())
|
||||
logger.info(f"批量提交器已启动,间隔: {self.interval_seconds}s")
|
||||
|
||||
async def stop(self):
|
||||
"""停止后台任务"""
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._task = None
|
||||
logger.info("批量提交器已停止")
|
||||
|
||||
def mark_dirty(self, session: Session):
|
||||
"""标记 Session 有待提交的更改"""
|
||||
self._pending_sessions.add(session)
|
||||
|
||||
async def _batch_commit_loop(self):
|
||||
"""后台批量提交循环"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.interval_seconds)
|
||||
await self._commit_all()
|
||||
except asyncio.CancelledError:
|
||||
# 关闭前提交所有待处理的
|
||||
await self._commit_all()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量提交出错: {e}")
|
||||
|
||||
async def _commit_all(self):
|
||||
"""提交所有待处理的 Session"""
|
||||
async with self._lock:
|
||||
if not self._pending_sessions:
|
||||
return
|
||||
|
||||
sessions_to_commit = list(self._pending_sessions)
|
||||
self._pending_sessions.clear()
|
||||
|
||||
committed = 0
|
||||
failed = 0
|
||||
|
||||
for session in sessions_to_commit:
|
||||
try:
|
||||
session.commit()
|
||||
committed += 1
|
||||
except Exception as e:
|
||||
logger.error(f"提交 Session 失败: {e}")
|
||||
try:
|
||||
session.rollback()
|
||||
except:
|
||||
pass
|
||||
failed += 1
|
||||
|
||||
if committed > 0:
|
||||
logger.debug(f"批量提交完成: {committed} 个 Session")
|
||||
if failed > 0:
|
||||
logger.warning(f"批量提交失败: {failed} 个 Session")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_batch_committer: BatchCommitter = None
|
||||
|
||||
|
||||
def get_batch_committer() -> BatchCommitter:
|
||||
"""获取全局批量提交器"""
|
||||
global _batch_committer
|
||||
if _batch_committer is None:
|
||||
_batch_committer = BatchCommitter(interval_seconds=1.0)
|
||||
return _batch_committer
|
||||
|
||||
|
||||
async def init_batch_committer():
|
||||
"""初始化并启动批量提交器"""
|
||||
committer = get_batch_committer()
|
||||
await committer.start()
|
||||
|
||||
|
||||
async def shutdown_batch_committer():
|
||||
"""关闭批量提交器"""
|
||||
committer = get_batch_committer()
|
||||
await committer.stop()
|
||||
174
src/core/cache_service.py
Normal file
174
src/core/cache_service.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
缓存服务 - 统一的缓存抽象层
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.clients.redis_client import get_redis_client
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class CacheService:
|
||||
"""缓存服务"""
|
||||
|
||||
@staticmethod
|
||||
async def get(key: str) -> Optional[Any]:
|
||||
"""
|
||||
从缓存获取数据
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
缓存的值,如果不存在则返回 None
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis_client(require_redis=False)
|
||||
if not redis:
|
||||
return None
|
||||
|
||||
value = await redis.get(key)
|
||||
if value:
|
||||
# 尝试 JSON 反序列化
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"缓存读取失败: {key} - {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def set(key: str, value: Any, ttl_seconds: int = 60) -> bool:
|
||||
"""
|
||||
设置缓存
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值
|
||||
ttl_seconds: 过期时间(秒),默认 60 秒
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis_client(require_redis=False)
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
# JSON 序列化
|
||||
if isinstance(value, (dict, list)):
|
||||
value = json.dumps(value)
|
||||
elif not isinstance(value, (str, bytes)):
|
||||
value = str(value)
|
||||
|
||||
await redis.setex(key, ttl_seconds, value)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"缓存写入失败: {key} - {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def delete(key: str) -> bool:
|
||||
"""
|
||||
删除缓存
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis_client(require_redis=False)
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
await redis.delete(key)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"缓存删除失败: {key} - {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def exists(key: str) -> bool:
|
||||
"""
|
||||
检查缓存是否存在
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis_client(require_redis=False)
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
return await redis.exists(key) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"缓存检查失败: {key} - {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 缓存键前缀
|
||||
class CacheKeys:
|
||||
"""缓存键定义"""
|
||||
|
||||
# User 缓存(TTL 60秒)
|
||||
USER_BY_ID = "user:id:{user_id}"
|
||||
USER_BY_EMAIL = "user:email:{email}"
|
||||
|
||||
# API Key 缓存(TTL 30秒)
|
||||
APIKEY_HASH = "apikey:hash:{key_hash}"
|
||||
APIKEY_AUTH = "apikey:auth:{key_hash}" # 认证结果缓存
|
||||
|
||||
# Provider 配置缓存(TTL 300秒)
|
||||
PROVIDER_BY_ID = "provider:id:{provider_id}"
|
||||
ENDPOINT_BY_ID = "endpoint:id:{endpoint_id}"
|
||||
API_KEY_BY_ID = "api_key:id:{api_key_id}"
|
||||
|
||||
@staticmethod
|
||||
def user_by_id(user_id: str) -> str:
|
||||
"""User ID 缓存键"""
|
||||
return CacheKeys.USER_BY_ID.format(user_id=user_id)
|
||||
|
||||
@staticmethod
|
||||
def user_by_email(email: str) -> str:
|
||||
"""User Email 缓存键"""
|
||||
return CacheKeys.USER_BY_EMAIL.format(email=email)
|
||||
|
||||
@staticmethod
|
||||
def apikey_hash(key_hash: str) -> str:
|
||||
"""API Key Hash 缓存键"""
|
||||
return CacheKeys.APIKEY_HASH.format(key_hash=key_hash)
|
||||
|
||||
@staticmethod
|
||||
def apikey_auth(key_hash: str) -> str:
|
||||
"""API Key 认证结果缓存键"""
|
||||
return CacheKeys.APIKEY_AUTH.format(key_hash=key_hash)
|
||||
|
||||
@staticmethod
|
||||
def provider_by_id(provider_id: str) -> str:
|
||||
"""Provider ID 缓存键"""
|
||||
return CacheKeys.PROVIDER_BY_ID.format(provider_id=provider_id)
|
||||
|
||||
@staticmethod
|
||||
def endpoint_by_id(endpoint_id: str) -> str:
|
||||
"""Endpoint ID 缓存键"""
|
||||
return CacheKeys.ENDPOINT_BY_ID.format(endpoint_id=endpoint_id)
|
||||
|
||||
@staticmethod
|
||||
def api_key_by_id(api_key_id: str) -> str:
|
||||
"""API Key ID 缓存键"""
|
||||
return CacheKeys.API_KEY_BY_ID.format(api_key_id=api_key_id)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user