refactor: rework rate limiting and health monitoring to distinguish by API format

- Rename adaptive_concurrency to adaptive_rpm, moving from concurrency control to RPM control
- Health monitor now manages health scores and circuit-breaker state independently per API format
- Add a model_permissions module for configuring allowed models per format
- Rework the provider-related frontend form components; add a Collapsible UI component
- Add database migration scripts for the new data structures
fawney19
2026-01-10 18:43:53 +08:00
parent dd2fbf4424
commit 09e0f594ff
97 changed files with 6642 additions and 4169 deletions


@@ -1,12 +1,12 @@
"""
自适应并发管理 API 端点
自适应 RPM 管理 API 端点
设计原则:
- 自适应模式由 max_concurrent 字段决定:
- max_concurrent = NULL:启用自适应模式,系统自动学习并调整并发限制
- max_concurrent = 数字:固定限制模式,使用用户指定的并发限制
- learned_max_concurrent:自适应模式下学习到的并发限制值
- adaptive_mode 是计算字段,基于 max_concurrent 是否为 NULL
- 自适应模式由 rpm_limit 字段决定:
- rpm_limit = NULL:启用自适应模式,系统自动学习并调整 RPM 限制
- rpm_limit = 数字:固定限制模式,使用用户指定的 RPM 限制
- learned_rpm_limit:自适应模式下学习到的 RPM 限制值
- adaptive_mode 是计算字段,基于 rpm_limit 是否为 NULL
"""
from dataclasses import dataclass
@@ -18,12 +18,13 @@ from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.config.constants import RPMDefaults
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.database import get_db
from src.models.database import ProviderAPIKey
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
from src.services.rate_limit.adaptive_rpm import get_adaptive_rpm_manager
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"])
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive RPM"])
pipeline = ApiRequestPipeline()
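For reference, a minimal sketch of how the effective limit described in the docstring above is derived (hypothetical standalone helper; RPMDefaults.INITIAL_LIMIT's real value is defined elsewhere, so a placeholder default is used here):

def effective_rpm_limit(rpm_limit, learned_rpm_limit, initial_limit=10):
    # Adaptive mode: rpm_limit is NULL, so fall back to the learned value,
    # or to an initial default before anything has been learned yet.
    if rpm_limit is None:
        return learned_rpm_limit if learned_rpm_limit is not None else initial_limit
    # Fixed mode: the user-configured limit wins.
    return rpm_limit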
@@ -35,19 +36,19 @@ class EnableAdaptiveRequest(BaseModel):
enabled: bool = Field(..., description="是否启用自适应模式(true=自适应,false=固定限制)")
fixed_limit: Optional[int] = Field(
None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)"
None, ge=1, le=100, description="固定 RPM 限制(仅当 enabled=false 时生效,1-100)"
)
class AdaptiveStatsResponse(BaseModel):
"""自适应统计响应"""
adaptive_mode: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制(NULL=自适应)")
adaptive_mode: bool = Field(..., description="是否为自适应模式(rpm_limit=NULL)")
rpm_limit: Optional[int] = Field(None, description="用户配置的固定限制(NULL=自适应)")
effective_limit: Optional[int] = Field(
None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
)
learned_limit: Optional[int] = Field(None, description="学习到的并发限制")
learned_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
concurrent_429_count: int
rpm_429_count: int
last_429_at: Optional[str]
@@ -61,11 +62,12 @@ class KeyListItem(BaseModel):
id: str
name: Optional[str]
endpoint_id: str
is_adaptive: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
max_concurrent: Optional[int] = Field(None, description="固定并发限制(NULL=自适应)")
provider_id: str
api_formats: List[str] = Field(default_factory=list)
is_adaptive: bool = Field(..., description="是否为自适应模式(rpm_limit=NULL)")
rpm_limit: Optional[int] = Field(None, description="固定 RPM 限制(NULL=自适应)")
effective_limit: Optional[int] = Field(None, description="当前有效限制")
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
learned_rpm_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
concurrent_429_count: int
rpm_429_count: int
@@ -80,22 +82,22 @@ class KeyListItem(BaseModel):
)
async def list_adaptive_keys(
request: Request,
endpoint_id: Optional[str] = Query(None, description="Endpoint 过滤"),
provider_id: Optional[str] = Query(None, description="Provider 过滤"),
db: Session = Depends(get_db),
):
"""
获取所有启用自适应模式的Key列表
可选参数:
- endpoint_id: 按 Endpoint 过滤
- provider_id: 按 Provider 过滤
"""
adapter = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id)
adapter = ListAdaptiveKeysAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch(
"/keys/{key_id}/mode",
summary="Toggle key's concurrency control mode",
summary="Toggle key's RPM control mode",
)
async def toggle_adaptive_mode(
key_id: str,
@@ -103,10 +105,10 @@ async def toggle_adaptive_mode(
db: Session = Depends(get_db),
):
"""
Toggle the concurrency control mode for a specific key
Toggle the RPM control mode for a specific key
Parameters:
- enabled: true=adaptive mode (max_concurrent=NULL), false=fixed limit mode
- enabled: true=adaptive mode (rpm_limit=NULL), false=fixed limit mode
- fixed_limit: fixed limit value (required when enabled=false)
"""
adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
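A hypothetical client call for this route, assuming a local deployment and omitting auth headers; the path combines the router prefix /api/admin/adaptive with /keys/{key_id}/mode:

import httpx

# Switch a key to fixed-limit mode at 30 RPM (enabled=false requires fixed_limit).
resp = httpx.patch(
    "http://localhost:8000/api/admin/adaptive/keys/<key_id>/mode",
    json={"enabled": False, "fixed_limit": 30},
)
print(resp.json())  # message, key_id, is_adaptive, rpm_limit, effective_limit

# Switch back to adaptive mode (rpm_limit becomes NULL and the system re-learns).
httpx.patch(
    "http://localhost:8000/api/admin/adaptive/keys/<key_id>/mode",
    json={"enabled": True},
)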
@@ -124,7 +126,7 @@ async def get_adaptive_stats(
db: Session = Depends(get_db),
):
"""
获取指定Key的自适应并发统计信息
获取指定Key的自适应 RPM 统计信息
包括:
- 当前配置
@@ -149,12 +151,12 @@ async def reset_adaptive_learning(
Reset the adaptive learning state for a specific key
Clears:
- Learned concurrency limit (learned_max_concurrent)
- Learned RPM limit (learned_rpm_limit)
- 429 error counts
- Adjustment history
Does not change:
- max_concurrent config (determines adaptive mode)
- rpm_limit config (determines adaptive mode)
"""
adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -162,40 +164,40 @@ async def reset_adaptive_learning(
@router.patch(
"/keys/{key_id}/limit",
summary="Set key to fixed concurrency limit mode",
summary="Set key to fixed RPM limit mode",
)
async def set_concurrent_limit(
async def set_rpm_limit(
key_id: str,
request: Request,
limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"),
limit: int = Query(..., ge=1, le=100, description="RPM limit value (1-100)"),
db: Session = Depends(get_db),
):
"""
Set key to fixed concurrency limit mode
Set key to fixed RPM limit mode
Note:
- After setting this value, key switches to fixed limit mode and won't auto-adjust
- To restore adaptive mode, use PATCH /keys/{key_id}/mode
"""
adapter = SetConcurrentLimitAdapter(key_id=key_id, limit=limit)
adapter = SetRPMLimitAdapter(key_id=key_id, limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get(
"/summary",
summary="获取自适应并发的全局统计",
summary="获取自适应 RPM 的全局统计",
)
async def get_adaptive_summary(
request: Request,
db: Session = Depends(get_db),
):
"""
获取自适应并发的全局统计摘要
获取自适应 RPM 的全局统计摘要
包括:
- 启用自适应模式的Key数量
- 总429错误数
- 并发限制调整次数
- RPM 限制调整次数
"""
adapter = AdaptiveSummaryAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -206,26 +208,29 @@ async def get_adaptive_summary(
@dataclass
class ListAdaptiveKeysAdapter(AdminApiAdapter):
endpoint_id: Optional[str] = None
provider_id: Optional[str] = None
async def handle(self, context): # type: ignore[override]
# 自适应模式:max_concurrent = NULL
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None))
if self.endpoint_id:
query = query.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
# 自适应模式:rpm_limit = NULL
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None))
if self.provider_id:
query = query.filter(ProviderAPIKey.provider_id == self.provider_id)
keys = query.all()
return [
KeyListItem(
id=key.id,
name=key.name,
endpoint_id=key.endpoint_id,
is_adaptive=key.max_concurrent is None,
max_concurrent=key.max_concurrent,
provider_id=key.provider_id,
api_formats=key.api_formats or [],
is_adaptive=key.rpm_limit is None,
rpm_limit=key.rpm_limit,
effective_limit=(
key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if key.rpm_limit is None
else key.rpm_limit
),
learned_max_concurrent=key.learned_max_concurrent,
learned_rpm_limit=key.learned_rpm_limit,
concurrent_429_count=key.concurrent_429_count or 0,
rpm_429_count=key.rpm_429_count or 0,
)
@@ -252,28 +257,32 @@ class ToggleAdaptiveModeAdapter(AdminApiAdapter):
raise InvalidRequestException("请求数据验证失败")
if body.enabled:
# 启用自适应模式:将 max_concurrent 设为 NULL
key.max_concurrent = None
message = "已切换为自适应模式,系统将自动学习并调整并发限制"
# 启用自适应模式:将 rpm_limit 设为 NULL
key.rpm_limit = None
message = "已切换为自适应模式,系统将自动学习并调整 RPM 限制"
else:
# 禁用自适应模式:设置固定限制
if body.fixed_limit is None:
raise HTTPException(
status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
)
key.max_concurrent = body.fixed_limit
message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}"
key.rpm_limit = body.fixed_limit
message = f"已切换为固定限制模式,RPM 限制设为 {body.fixed_limit}"
context.db.commit()
context.db.refresh(key)
is_adaptive = key.max_concurrent is None
is_adaptive = key.rpm_limit is None
return {
"message": message,
"key_id": key.id,
"is_adaptive": is_adaptive,
"max_concurrent": key.max_concurrent,
"effective_limit": key.learned_max_concurrent if is_adaptive else key.max_concurrent,
"rpm_limit": key.rpm_limit,
"effective_limit": (
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if is_adaptive
else key.rpm_limit
),
}
@@ -286,13 +295,13 @@ class GetAdaptiveStatsAdapter(AdminApiAdapter):
if not key:
raise HTTPException(status_code=404, detail="Key not found")
adaptive_manager = get_adaptive_manager()
adaptive_manager = get_adaptive_rpm_manager()
stats = adaptive_manager.get_adjustment_stats(key)
# 转换字段名以匹配响应模型
return AdaptiveStatsResponse(
adaptive_mode=stats["adaptive_mode"],
max_concurrent=stats["max_concurrent"],
rpm_limit=stats["rpm_limit"],
effective_limit=stats["effective_limit"],
learned_limit=stats["learned_limit"],
concurrent_429_count=stats["concurrent_429_count"],
@@ -313,13 +322,13 @@ class ResetAdaptiveLearningAdapter(AdminApiAdapter):
if not key:
raise HTTPException(status_code=404, detail="Key not found")
adaptive_manager = get_adaptive_manager()
adaptive_manager = get_adaptive_rpm_manager()
adaptive_manager.reset_learning(context.db, key)
return {"message": "学习状态已重置", "key_id": key.id}
@dataclass
class SetConcurrentLimitAdapter(AdminApiAdapter):
class SetRPMLimitAdapter(AdminApiAdapter):
key_id: str
limit: int
@@ -328,25 +337,25 @@ class SetConcurrentLimitAdapter(AdminApiAdapter):
if not key:
raise HTTPException(status_code=404, detail="Key not found")
was_adaptive = key.max_concurrent is None
key.max_concurrent = self.limit
was_adaptive = key.rpm_limit is None
key.rpm_limit = self.limit
context.db.commit()
context.db.refresh(key)
return {
"message": f"已设置为固定限制模式,并发限制为 {self.limit}",
"message": f"已设置为固定限制模式,RPM 限制为 {self.limit}",
"key_id": key.id,
"is_adaptive": False,
"max_concurrent": key.max_concurrent,
"rpm_limit": key.rpm_limit,
"previous_mode": "adaptive" if was_adaptive else "fixed",
}
class AdaptiveSummaryAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
# 自适应模式:max_concurrent = NULL
# 自适应模式:rpm_limit = NULL
adaptive_keys = (
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None)).all()
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None)).all()
)
total_keys = len(adaptive_keys)


@@ -1,9 +1,8 @@
"""
Endpoint 并发控制管理 API
Key RPM 限制管理 API
"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
@@ -12,83 +11,56 @@ from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.database import get_db
from src.models.database import ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
ConcurrencyStatusResponse,
ResetConcurrencyRequest,
)
from src.models.database import ProviderAPIKey
from src.models.endpoint_models import KeyRpmStatusResponse
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
router = APIRouter(tags=["Concurrency Control"])
router = APIRouter(tags=["RPM Control"])
pipeline = ApiRequestPipeline()
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
async def get_endpoint_concurrency(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
"""
获取 Endpoint 当前并发状态
查询指定 Endpoint 的实时并发使用情况,包括当前并发数和最大并发限制。
**路径参数**:
- `endpoint_id`: Endpoint ID
**返回字段**:
- `endpoint_id`: Endpoint ID
- `endpoint_current_concurrency`: 当前并发数
- `endpoint_max_concurrent`: 最大并发限制
"""
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
async def get_key_concurrency(
@router.get("/rpm/key/{key_id}", response_model=KeyRpmStatusResponse)
async def get_key_rpm(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
) -> KeyRpmStatusResponse:
"""
获取 Key 当前并发状态
获取 Key 当前 RPM 状态
查询指定 API Key 的实时并发使用情况,包括当前并发数和最大并发限制。
查询指定 API Key 的实时 RPM 使用情况,包括当前 RPM 计数和最大 RPM 限制。
**路径参数**:
- `key_id`: API Key ID
**返回字段**:
- `key_id`: API Key ID
- `key_current_concurrency`: 当前并发数
- `key_max_concurrent`: 最大并发限制
- `current_rpm`: 当前 RPM 计数
- `rpm_limit`: RPM 限制
"""
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
adapter = AdminKeyRpmAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/concurrency")
async def reset_concurrency(
request: ResetConcurrencyRequest,
@router.delete("/rpm/key/{key_id}")
async def reset_key_rpm(
key_id: str,
http_request: Request,
db: Session = Depends(get_db),
) -> dict:
"""
重置并发计数器
重置 Key RPM 计数器
重置指定 Endpoint 或 Key 的并发计数器,用于解决计数不准确的问题。
重置指定 API Key 的 RPM 计数器,用于解决计数不准确的问题。
管理员功能,请谨慎使用。
**请求体字段**:
- `endpoint_id`: Endpoint ID(可选)
- `key_id`: API Key ID(可选)
**路径参数**:
- `key_id`: API Key ID
**返回字段**:
- `message`: 操作结果消息
"""
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
adapter = AdminResetKeyRpmAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
@@ -96,31 +68,7 @@ async def reset_concurrency(
@dataclass
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
endpoint_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
concurrency_manager = await get_concurrency_manager()
endpoint_count, _ = await concurrency_manager.get_current_concurrency(
endpoint_id=self.endpoint_id
)
return ConcurrencyStatusResponse(
endpoint_id=self.endpoint_id,
endpoint_current_concurrency=endpoint_count,
endpoint_max_concurrent=endpoint.max_concurrent,
)
@dataclass
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
class AdminKeyRpmAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
@@ -130,23 +78,20 @@ class AdminKeyConcurrencyAdapter(AdminApiAdapter):
raise NotFoundException(f"Key {self.key_id} 不存在")
concurrency_manager = await get_concurrency_manager()
_, key_count = await concurrency_manager.get_current_concurrency(key_id=self.key_id)
key_count = await concurrency_manager.get_key_rpm_count(key_id=self.key_id)
return ConcurrencyStatusResponse(
return KeyRpmStatusResponse(
key_id=self.key_id,
key_current_concurrency=key_count,
key_max_concurrent=key.max_concurrent,
current_rpm=key_count,
rpm_limit=key.rpm_limit,
)
@dataclass
class AdminResetConcurrencyAdapter(AdminApiAdapter):
endpoint_id: Optional[str]
key_id: Optional[str]
class AdminResetKeyRpmAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
concurrency_manager = await get_concurrency_manager()
await concurrency_manager.reset_concurrency(
endpoint_id=self.endpoint_id, key_id=self.key_id
)
return {"message": "并发计数已重置"}
await concurrency_manager.reset_key_rpm(key_id=self.key_id)
return {"message": "RPM 计数已重置"}


@@ -5,7 +5,7 @@ Endpoint 健康监控 API
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func
@@ -128,29 +128,32 @@ async def get_api_format_health_monitor(
async def get_key_health(
key_id: str,
request: Request,
api_format: Optional[str] = Query(None, description="API 格式(可选,如 CLAUDE、OPENAI)"),
db: Session = Depends(get_db),
) -> HealthStatusResponse:
"""
获取 Key 健康状态
获取指定 API Key 的健康状态详情,包括健康分数、连续失败次数、
熔断器状态等信息。
熔断器状态等信息。支持按 API 格式查询。
**路径参数**:
- `key_id`: API Key ID
**查询参数**:
- `api_format`: 可选,指定 API 格式(如 CLAUDE、OPENAI)
- 指定时返回该格式的健康度详情
- 不指定时返回所有格式的健康度摘要
**返回字段**:
- `key_id`: API Key ID
- `key_health_score`: 健康分数(0.0-1.0)
- `key_consecutive_failures`: 连续失败次数
- `key_last_failure_at`: 最后失败时间
- `key_is_active`: 是否活跃
- `key_statistics`: 统计信息
- `circuit_breaker_open`: 熔断器是否打开
- `circuit_breaker_open_at`: 熔断器打开时间
- `next_probe_at`: 下次探测时间
- `health_by_format`: 按格式的健康度数据(无 api_format 参数时)
- `circuit_breaker_open`: 熔断器是否打开(有 api_format 参数时)
"""
adapter = AdminKeyHealthAdapter(key_id=key_id)
adapter = AdminKeyHealthAdapter(key_id=key_id, api_format=api_format)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -158,17 +161,23 @@ async def get_key_health(
async def recover_key_health(
key_id: str,
request: Request,
api_format: Optional[str] = Query(None, description="API 格式(可选,不指定则恢复所有格式)"),
db: Session = Depends(get_db),
) -> dict:
"""
恢复 Key 健康状态
手动恢复指定 Key 的健康状态,将健康分数重置为 1.0,关闭熔断器,
取消自动禁用,并重置所有失败计数。
取消自动禁用,并重置所有失败计数。支持按 API 格式恢复。
**路径参数**:
- `key_id`: API Key ID
**查询参数**:
- `api_format`: 可选,指定 API 格式(如 CLAUDE、OPENAI)
- 指定时仅恢复该格式的健康度
- 不指定时恢复所有格式
**返回字段**:
- `message`: 操作结果消息
- `details`: 详细信息
@@ -176,7 +185,7 @@ async def recover_key_health(
- `circuit_breaker_open`: 熔断器状态
- `is_active`: 是否活跃
"""
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id, api_format=api_format)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -276,34 +285,9 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
)
all_formats[api_format] = provider_count
# 1.1 获取所有活跃的 API 格式及其 API Key 数量
active_keys = (
db.query(
ProviderEndpoint.api_format,
func.count(ProviderAPIKey.id).label("key_count"),
)
.join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
)
.group_by(ProviderEndpoint.api_format)
.all()
)
# 构建所有格式的 key_count 映射
key_counts: Dict[str, int] = {}
for api_format_enum, key_count in active_keys:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
key_counts[api_format] = key_count
# 1.2 建立每个 API 格式对应的 Endpoint ID 列表,供 Usage 时间线生成使用
# 1.1 建立每个 API 格式对应的 Endpoint ID 列表(用于时间线生成),并收集活跃的 provider+format 组合
endpoint_rows = (
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id, ProviderEndpoint.provider_id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
@@ -312,11 +296,32 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
.all()
)
endpoint_map: Dict[str, List[str]] = defaultdict(list)
for api_format_enum, endpoint_id in endpoint_rows:
active_provider_formats: set[tuple[str, str]] = set()
for api_format_enum, endpoint_id, provider_id in endpoint_rows:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
endpoint_map[api_format].append(endpoint_id)
active_provider_formats.add((str(provider_id), api_format))
# 1.2 统计每个 API 格式可用的活跃 Key 数量Key 属于 Provider通过 api_formats 关联格式)
key_counts: Dict[str, int] = {}
if active_provider_formats:
active_provider_keys = (
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.api_formats)
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
.filter(
Provider.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
)
.all()
)
for provider_id, api_formats in active_provider_keys:
pid = str(provider_id)
for fmt in (api_formats or []):
if (pid, fmt) not in active_provider_formats:
continue
key_counts[fmt] = key_counts.get(fmt, 0) + 1
# 2. 统计窗口内每个 API 格式的请求状态分布(真实统计)
# 只统计最终状态(success, failed, skipped)
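A standalone sketch of the per-format key counting step above (hypothetical helper; the adapter does this inline over SQLAlchemy result rows):

def count_keys_per_format(active_provider_formats, active_keys):
    # active_provider_formats: set of (provider_id, api_format) pairs that have an active endpoint
    # active_keys: iterable of (provider_id, api_formats) rows for active keys of active providers
    key_counts = {}
    for provider_id, api_formats in active_keys:
        pid = str(provider_id)
        for fmt in (api_formats or []):
            if (pid, fmt) not in active_provider_formats:
                continue  # the key declares a format its provider has no active endpoint for
            key_counts[fmt] = key_counts.get(fmt, 0) + 1
    return key_counts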
@@ -457,28 +462,45 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
@dataclass
class AdminKeyHealthAdapter(AdminApiAdapter):
key_id: str
api_format: Optional[str] = None
async def handle(self, context): # type: ignore[override]
health_data = health_monitor.get_key_health(context.db, self.key_id)
health_data = health_monitor.get_key_health(context.db, self.key_id, self.api_format)
if not health_data:
raise NotFoundException(f"Key {self.key_id} 不存在")
return HealthStatusResponse(
key_id=health_data["key_id"],
key_health_score=health_data["health_score"],
key_consecutive_failures=health_data["consecutive_failures"],
key_last_failure_at=health_data["last_failure_at"],
key_is_active=health_data["is_active"],
key_statistics=health_data["statistics"],
circuit_breaker_open=health_data["circuit_breaker_open"],
circuit_breaker_open_at=health_data["circuit_breaker_open_at"],
next_probe_at=health_data["next_probe_at"],
)
# 构建响应
response_data = {
"key_id": health_data["key_id"],
"key_is_active": health_data["is_active"],
"key_statistics": health_data.get("statistics"),
"key_health_score": health_data.get("health_score", 1.0),
}
if self.api_format:
# 单格式查询
response_data["api_format"] = self.api_format
response_data["key_consecutive_failures"] = health_data.get("consecutive_failures")
response_data["key_last_failure_at"] = health_data.get("last_failure_at")
circuit = health_data.get("circuit_breaker", {})
response_data["circuit_breaker_open"] = circuit.get("open", False)
response_data["circuit_breaker_open_at"] = circuit.get("open_at")
response_data["next_probe_at"] = circuit.get("next_probe_at")
response_data["half_open_until"] = circuit.get("half_open_until")
response_data["half_open_successes"] = circuit.get("half_open_successes", 0)
response_data["half_open_failures"] = circuit.get("half_open_failures", 0)
else:
# 全格式查询
response_data["any_circuit_open"] = health_data.get("any_circuit_open", False)
response_data["health_by_format"] = health_data.get("health_by_format")
return HealthStatusResponse(**response_data)
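For reference, an illustrative shape of the data returned by health_monitor.get_key_health, with field names taken from the .get() calls above and made-up values:

# Single-format query (api_format given): circuit-breaker details are included.
example_single_format = {
    "key_id": "key-123",
    "is_active": True,
    "statistics": {},  # aggregate counters; structure defined elsewhere
    "health_score": 0.85,
    "consecutive_failures": 2,
    "last_failure_at": "2026-01-10T10:00:00Z",
    "circuit_breaker": {
        "open": False,
        "open_at": None,
        "next_probe_at": None,
        "half_open_until": None,
        "half_open_successes": 0,
        "half_open_failures": 0,
    },
}

# All-format query (no api_format): a per-format summary is returned instead.
example_all_formats = {
    "key_id": "key-123",
    "is_active": True,
    "health_score": 0.85,
    "any_circuit_open": False,
    "health_by_format": {"CLAUDE": {"health_score": 0.85}, "OPENAI": {"health_score": 1.0}},
}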
@dataclass
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
key_id: str
api_format: Optional[str] = None
async def handle(self, context): # type: ignore[override]
db = context.db
@@ -486,28 +508,38 @@ class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
key.health_score = 1.0
key.consecutive_failures = 0
key.last_failure_at = None
key.circuit_breaker_open = False
key.circuit_breaker_open_at = None
key.next_probe_at = None
# 使用 health_monitor.reset_health 重置健康度
success = health_monitor.reset_health(db, key_id=self.key_id, api_format=self.api_format)
if not success:
raise Exception("重置健康度失败")
# 如果 Key 被禁用,重新启用
if not key.is_active:
key.is_active = True
key.is_active = True # type: ignore[assignment]
db.commit()
admin_name = context.user.username if context.user else "admin"
logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")
return {
"message": "Key已完全恢复",
"details": {
"health_score": 1.0,
"circuit_breaker_open": False,
"is_active": True,
},
}
if self.api_format:
logger.info(f"管理员恢复Key健康状态: {self.key_id}/{self.api_format}")
return {
"message": f"Key 的 {self.api_format} 格式已恢复",
"details": {
"api_format": self.api_format,
"health_score": 1.0,
"circuit_breaker_open": False,
"is_active": True,
},
}
else:
logger.info(f"管理员恢复Key健康状态: {self.key_id} (所有格式)")
return {
"message": "Key 所有格式已恢复",
"details": {
"health_score": 1.0,
"circuit_breaker_open": False,
"is_active": True,
},
}
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
@@ -516,10 +548,17 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
# 查找所有熔断的 Key
circuit_open_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open == True).all()
)
# 查找所有熔断格式的 Key(检查 circuit_breaker_by_format JSON 字段)
all_keys = db.query(ProviderAPIKey).all()
# 筛选出有任何格式熔断的 Key
circuit_open_keys = []
for key in all_keys:
circuit_by_format = key.circuit_breaker_by_format or {}
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
circuit_open_keys.append(key)
break
if not circuit_open_keys:
return {
@@ -530,17 +569,15 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
recovered_keys = []
for key in circuit_open_keys:
key.health_score = 1.0
key.consecutive_failures = 0
key.last_failure_at = None
key.circuit_breaker_open = False
key.circuit_breaker_open_at = None
key.next_probe_at = None
# 重置所有格式的健康度
key.health_by_format = {} # type: ignore[assignment]
key.circuit_breaker_by_format = {} # type: ignore[assignment]
recovered_keys.append(
{
"key_id": key.id,
"key_name": key.name,
"endpoint_id": key.endpoint_id,
"provider_id": key.provider_id,
"api_formats": key.api_formats,
}
)
@@ -552,7 +589,6 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
HealthMonitor._open_circuit_keys = 0
health_open_circuits.set(0)
admin_name = context.user.username if context.user else "admin"
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
return {


@@ -1,5 +1,5 @@
"""
Endpoint API Keys 管理
Provider API Keys 管理
"""
import uuid
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.config.constants import RPMDefaults
from src.core.crypto import crypto_service
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.key_capabilities import get_capability
@@ -20,96 +21,14 @@ from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.services.cache.provider_cache import ProviderCacheService
from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate,
EndpointAPIKeyResponse,
EndpointAPIKeyUpdate,
)
router = APIRouter(tags=["Endpoint Keys"])
router = APIRouter(tags=["Provider Keys"])
pipeline = ApiRequestPipeline()
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
async def list_endpoint_keys(
endpoint_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
"""
获取 Endpoint 的所有 Keys
获取指定 Endpoint 下的所有 API Key 列表,包括 Key 的配置、统计信息等。
结果按优先级和创建时间排序。
**路径参数**:
- `endpoint_id`: Endpoint ID
**查询参数**:
- `skip`: 跳过的记录数,用于分页(默认 0)
- `limit`: 返回的最大记录数(1-1000,默认 100)
**返回字段**:
- `id`: Key ID
- `name`: Key 名称
- `api_key_masked`: 脱敏后的 API Key
- `internal_priority`: 内部优先级
- `global_priority`: 全局优先级
- `rate_multiplier`: 速率倍数
- `max_concurrent`: 最大并发数(null 表示自适应模式)
- `is_adaptive`: 是否为自适应并发模式
- `effective_limit`: 有效并发限制
- `success_rate`: 成功率
- `avg_response_time_ms`: 平均响应时间(毫秒)
- 其他配置和统计字段
"""
adapter = AdminListEndpointKeysAdapter(
endpoint_id=endpoint_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_endpoint_key(
endpoint_id: str,
key_data: EndpointAPIKeyCreate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""
为 Endpoint 添加 Key
为指定 Endpoint 添加新的 API Key,支持配置并发限制、速率倍数、
优先级、配额限制、能力限制等。
**路径参数**:
- `endpoint_id`: Endpoint ID
**请求体字段**:
- `endpoint_id`: Endpoint ID(必须与路径参数一致)
- `api_key`: API Key 原文(将被加密存储)
- `name`: Key 名称
- `note`: 备注(可选)
- `rate_multiplier`: 速率倍数(默认 1.0)
- `internal_priority`: 内部优先级(默认 100)
- `max_concurrent`: 最大并发数(null 表示自适应模式)
- `rate_limit`: 每分钟请求限制(可选)
- `daily_limit`: 每日请求限制(可选)
- `monthly_limit`: 每月请求限制(可选)
- `allowed_models`: 允许的模型列表(可选)
- `capabilities`: 能力配置(可选)
**返回字段**:
- 包含完整的 Key 信息,其中 `api_key_plain` 为原文(仅在创建时返回)
"""
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
async def update_endpoint_key(
key_id: str,
@@ -118,7 +37,7 @@ async def update_endpoint_key(
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""
更新 Endpoint Key
更新 Provider Key
更新指定 Key 的配置,支持修改并发限制、速率倍数、优先级、
配额限制、能力限制等。支持部分更新。
@@ -132,10 +51,7 @@ async def update_endpoint_key(
- `note`: 备注
- `rate_multiplier`: 速率倍数
- `internal_priority`: 内部优先级
- `max_concurrent`: 最大并发数(设置为 null 可切换到自适应模式)
- `rate_limit`: 每分钟请求限制
- `daily_limit`: 每日请求限制
- `monthly_limit`: 每月请求限制
- `rpm_limit`: RPM 限制(设置为 null 可切换到自适应模式)
- `allowed_models`: 允许的模型列表
- `capabilities`: 能力配置
- `is_active`: 是否活跃
@@ -210,7 +126,7 @@ async def delete_endpoint_key(
db: Session = Depends(get_db),
) -> dict:
"""
删除 Endpoint Key
删除 Provider Key
删除指定的 API Key。此操作不可逆,请谨慎使用。
@@ -224,163 +140,66 @@ async def delete_endpoint_key(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{endpoint_id}/keys/batch-priority")
async def batch_update_key_priority(
endpoint_id: str,
request: Request,
priority_data: BatchUpdateKeyPriorityRequest,
db: Session = Depends(get_db),
) -> dict:
"""
批量更新 Endpoint 下 Keys 的优先级
# ========== Provider Keys API ==========
批量更新指定 Endpoint 下多个 Key 的内部优先级,用于拖动排序。
所有 Key 必须属于指定的 Endpoint。
@router.get("/providers/{provider_id}/keys", response_model=List[EndpointAPIKeyResponse])
async def list_provider_keys(
provider_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
"""
获取 Provider 的所有 Keys
获取指定 Provider 下的所有 API Key 列表,支持多 API 格式。
结果按优先级和创建时间排序。
**路径参数**:
- `endpoint_id`: Endpoint ID
- `provider_id`: Provider ID
**查询参数**:
- `skip`: 跳过的记录数,用于分页(默认 0)
- `limit`: 返回的最大记录数(1-1000,默认 100)
"""
adapter = AdminListProviderKeysAdapter(
provider_id=provider_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/providers/{provider_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_provider_key(
provider_id: str,
key_data: EndpointAPIKeyCreate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""
为 Provider 添加 Key
为指定 Provider 添加新的 API Key,支持配置多个 API 格式。
**路径参数**:
- `provider_id`: Provider ID
**请求体字段**:
- `priorities`: 优先级列表
- `key_id`: Key ID
- `internal_priority`: 新的内部优先级
**返回字段**:
- `message`: 操作结果消息
- `updated_count`: 实际更新的 Key 数量
- `api_formats`: 支持的 API 格式列表(必填)
- `api_key`: API Key 原文(将被加密存储)
- `name`: Key 名称
- 其他配置字段同 Key
"""
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
adapter = AdminCreateProviderKeyAdapter(provider_id=provider_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminListEndpointKeysAdapter(AdminApiAdapter):
endpoint_id: str
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.offset(self.skip)
.limit(self.limit)
.all()
)
result: List[EndpointAPIKeyResponse] = []
for key in keys:
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.max_concurrent is None
key_dict = key.__dict__.copy()
key_dict.pop("_sa_instance_state", None)
key_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": None,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
key.learned_max_concurrent if is_adaptive else key.max_concurrent
),
}
)
result.append(EndpointAPIKeyResponse(**key_dict))
return result
@dataclass
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
endpoint_id: str
key_data: EndpointAPIKeyCreate
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
if self.key_data.endpoint_id != self.endpoint_id:
raise InvalidRequestException("endpoint_id 不匹配")
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
now = datetime.now(timezone.utc)
# max_concurrent=NULL 表示自适应模式,数字表示固定限制
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=self.endpoint_id,
api_key=encrypted_key,
name=self.key_data.name,
note=self.key_data.note,
rate_multiplier=self.key_data.rate_multiplier,
internal_priority=self.key_data.internal_priority,
max_concurrent=self.key_data.max_concurrent, # NULL=自适应模式
rate_limit=self.key_data.rate_limit,
daily_limit=self.key_data.daily_limit,
monthly_limit=self.key_data.monthly_limit,
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
request_count=0,
success_count=0,
error_count=0,
total_response_time_ms=0,
is_active=True,
last_used_at=None,
created_at=now,
updated_at=now,
)
db.add(new_key)
db.commit()
db.refresh(new_key)
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
is_adaptive = new_key.max_concurrent is None
response_dict = new_key.__dict__.copy()
response_dict.pop("_sa_instance_state", None)
response_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": self.key_data.api_key,
"success_rate": 0.0,
"avg_response_time_ms": 0.0,
"is_adaptive": is_adaptive,
"effective_limit": (
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
),
}
)
return EndpointAPIKeyResponse(**response_dict)
@dataclass
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
key_id: str
@@ -396,14 +215,21 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
if "api_key" in update_data:
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
# 特殊处理 max_concurrent:需要区分"未提供"和"显式设置为 null"
# 当 max_concurrent 被显式设置时(在 model_fields_set 中),即使值为 None 也应该更新
if "max_concurrent" in self.key_data.model_fields_set:
update_data["max_concurrent"] = self.key_data.max_concurrent
# 切换到自适应模式时,清空学习到的并发限制,让系统重新学习
if self.key_data.max_concurrent is None:
update_data["learned_max_concurrent"] = None
logger.info("Key %s 切换为自适应并发模式", self.key_id)
# 特殊处理 rpm_limit:需要区分"未提供"和"显式设置为 null"
if "rpm_limit" in self.key_data.model_fields_set:
update_data["rpm_limit"] = self.key_data.rpm_limit
if self.key_data.rpm_limit is None:
update_data["learned_rpm_limit"] = None
logger.info("Key %s 切换为自适应 RPM 模式", self.key_id)
# 统一处理 allowed_models:空列表/空字典 -> None(表示不限制)
if "allowed_models" in update_data:
am = update_data["allowed_models"]
if am is not None and (
(isinstance(am, list) and len(am) == 0)
or (isinstance(am, dict) and len(am) == 0)
):
update_data["allowed_models"] = None
for field, value in update_data.items():
setattr(key, field, value)
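The "omitted vs. explicitly null" distinction above relies on Pydantic's model_fields_set; a minimal sketch using a hypothetical model (not the real EndpointAPIKeyUpdate):

from typing import Optional
from pydantic import BaseModel

class KeyUpdateSketch(BaseModel):
    rpm_limit: Optional[int] = None

update = KeyUpdateSketch.model_validate({"rpm_limit": None})
print("rpm_limit" in update.model_fields_set)  # True: explicit null -> switch to adaptive mode

update = KeyUpdateSketch.model_validate({})
print("rpm_limit" in update.model_fields_set)  # False: field omitted -> leave the limit unchanged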
@@ -412,39 +238,13 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
db.commit()
db.refresh(key)
# 如果更新了 rate_multiplier,清除缓存
if "rate_multiplier" in update_data:
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
# 任何字段更新都清除缓存,确保缓存一致性
# 包括 is_active、allowed_models、capabilities 等影响权限和行为的字段
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.max_concurrent is None
response_dict = key.__dict__.copy()
response_dict.pop("_sa_instance_state", None)
response_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": None,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
key.learned_max_concurrent if is_adaptive else key.max_concurrent
),
}
)
return EndpointAPIKeyResponse(**response_dict)
return _build_key_response(key)
@dataclass
@@ -481,7 +281,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
endpoint_id = key.endpoint_id
provider_id = key.provider_id
try:
db.delete(key)
db.commit()
@@ -490,7 +290,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
raise
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Provider={provider_id}")
return {"message": f"Key {self.key_id} 已删除"}
@@ -498,31 +298,51 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
# Key 属于 Provider,按 key.api_formats 分组展示
keys = (
db.query(ProviderAPIKey, ProviderEndpoint, Provider)
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
db.query(ProviderAPIKey, Provider)
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
.filter(
ProviderAPIKey.is_active.is_(True),
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
)
.order_by(
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
ProviderAPIKey.global_priority.asc().nullslast(),
ProviderAPIKey.internal_priority.asc(),
)
.all()
)
provider_ids = {str(provider.id) for _key, provider in keys}
endpoints = (
db.query(
ProviderEndpoint.provider_id,
ProviderEndpoint.api_format,
ProviderEndpoint.base_url,
)
.filter(
ProviderEndpoint.provider_id.in_(provider_ids),
ProviderEndpoint.is_active.is_(True),
)
.all()
)
endpoint_base_url_map: Dict[tuple[str, str], str] = {}
for provider_id, api_format, base_url in endpoints:
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
endpoint_base_url_map[(str(provider_id), fmt)] = base_url
grouped: Dict[str, List[dict]] = {}
for key, endpoint, provider in keys:
api_format = endpoint.api_format
if api_format not in grouped:
grouped[api_format] = []
for key, provider in keys:
api_formats = key.api_formats or []
if not api_formats:
continue # 跳过没有 API 格式的 Key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
except Exception as e:
logger.error(f"解密 Key 失败: key_id={key.id}, error={e}")
masked_key = "***ERROR***"
# 计算健康度指标
@@ -541,73 +361,209 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
cap_def = get_capability(cap_name)
caps_list.append(cap_def.short_name if cap_def else cap_name)
grouped[api_format].append(
{
"id": key.id,
"name": key.name,
"api_key_masked": masked_key,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"rate_multiplier": key.rate_multiplier,
"is_active": key.is_active,
"circuit_breaker_open": key.circuit_breaker_open,
"provider_name": provider.display_name or provider.name,
"endpoint_base_url": endpoint.base_url,
"api_format": api_format,
"capabilities": caps_list,
"health_score": key.health_score,
"success_rate": success_rate,
"avg_response_time_ms": avg_response_time_ms,
"request_count": key.request_count,
}
)
# 构建 Key 信息(基础数据)
key_info = {
"id": key.id,
"name": key.name,
"api_key_masked": masked_key,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"rate_multiplier": key.rate_multiplier,
"is_active": key.is_active,
"provider_name": provider.name,
"api_formats": api_formats,
"capabilities": caps_list,
"success_rate": success_rate,
"avg_response_time_ms": avg_response_time_ms,
"request_count": key.request_count,
}
# 将 Key 添加到每个支持的格式分组中,并附加格式特定的健康度数据
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
provider_id = str(provider.id)
for api_format in api_formats:
if api_format not in grouped:
grouped[api_format] = []
# 为每个格式创建副本,设置当前格式
format_key_info = key_info.copy()
format_key_info["api_format"] = api_format
format_key_info["endpoint_base_url"] = endpoint_base_url_map.get(
(provider_id, api_format)
)
# 添加格式特定的健康度数据
format_health = health_by_format.get(api_format, {})
format_circuit = circuit_by_format.get(api_format, {})
format_key_info["health_score"] = float(format_health.get("health_score") or 1.0)
format_key_info["circuit_breaker_open"] = bool(format_circuit.get("open", False))
grouped[api_format].append(format_key_info)
# 直接返回分组对象,供前端使用
return grouped
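An illustrative shape of the grouped response (only a subset of the fields built above is shown; a key that declares two formats appears under both groups with format-specific health data):

example_grouped = {
    "CLAUDE": [
        {
            "id": "key-123",
            "name": "primary",
            "api_formats": ["CLAUDE", "OPENAI"],
            "api_format": "CLAUDE",
            "endpoint_base_url": "https://claude.example.com",
            "health_score": 0.9,
            "circuit_breaker_open": False,
        },
    ],
    "OPENAI": [
        # the same key repeated, with OPENAI-specific health_score / circuit state
    ],
}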
# ========== Adapters ==========
def _build_key_response(
key: ProviderAPIKey, api_key_plain: str | None = None
) -> EndpointAPIKeyResponse:
"""构建 Key 响应对象的辅助函数"""
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.rpm_limit is None
key_dict = key.__dict__.copy()
key_dict.pop("_sa_instance_state", None)
# 从 health_by_format 计算汇总字段(便于列表展示)
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
# 计算整体健康度(取所有格式中的最低值)
if health_by_format:
health_scores = [
float(h.get("health_score") or 1.0) for h in health_by_format.values()
]
min_health_score = min(health_scores) if health_scores else 1.0
# 取最大的连续失败次数
max_consecutive = max(
(int(h.get("consecutive_failures") or 0) for h in health_by_format.values()),
default=0,
)
# 取最近的失败时间
failure_times = [
h.get("last_failure_at")
for h in health_by_format.values()
if h.get("last_failure_at")
]
last_failure = max(failure_times) if failure_times else None
else:
min_health_score = 1.0
max_consecutive = 0
last_failure = None
# 检查是否有任何格式的熔断器打开
any_circuit_open = any(c.get("open", False) for c in circuit_by_format.values())
key_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": api_key_plain,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if is_adaptive
else key.rpm_limit
),
# 汇总字段
"health_score": min_health_score,
"consecutive_failures": max_consecutive,
"last_failure_at": last_failure,
"circuit_breaker_open": any_circuit_open,
}
)
# 防御性:确保 api_formats 存在(历史数据可能为空/缺失)
if "api_formats" not in key_dict or key_dict["api_formats"] is None:
key_dict["api_formats"] = []
return EndpointAPIKeyResponse(**key_dict)
@dataclass
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
endpoint_id: str
priority_data: BatchUpdateKeyPriorityRequest
class AdminListProviderKeysAdapter(AdminApiAdapter):
"""获取 Provider 的所有 Keys"""
provider_id: str
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException(f"Provider {self.provider_id} 不存在")
# 获取所有需要更新的 Key ID
key_ids = [item.key_id for item in self.priority_data.priorities]
# 验证所有 Key 都属于该 Endpoint
keys = (
db.query(ProviderAPIKey)
.filter(
ProviderAPIKey.id.in_(key_ids),
ProviderAPIKey.endpoint_id == self.endpoint_id,
)
.filter(ProviderAPIKey.provider_id == self.provider_id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.offset(self.skip)
.limit(self.limit)
.all()
)
if len(keys) != len(key_ids):
found_ids = {k.id for k in keys}
missing_ids = set(key_ids) - found_ids
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
return [_build_key_response(key) for key in keys]
# 批量更新优先级
key_map = {k.id: k for k in keys}
updated_count = 0
for item in self.priority_data.priorities:
key = key_map.get(item.key_id)
if key and key.internal_priority != item.internal_priority:
key.internal_priority = item.internal_priority
key.updated_at = datetime.now(timezone.utc)
updated_count += 1
@dataclass
class AdminCreateProviderKeyAdapter(AdminApiAdapter):
"""为 Provider 添加 Key"""
provider_id: str
key_data: EndpointAPIKeyCreate
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException(f"Provider {self.provider_id} 不存在")
# 验证 api_formats 必填
if not self.key_data.api_formats:
raise InvalidRequestException("api_formats 为必填字段")
# 允许同一个 API Key 在同一 Provider 下添加多次
# 用户可以为不同的 API 格式创建独立的配置记录,便于分开管理
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
now = datetime.now(timezone.utc)
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
provider_id=self.provider_id,
api_formats=self.key_data.api_formats,
api_key=encrypted_key,
name=self.key_data.name,
note=self.key_data.note,
rate_multiplier=self.key_data.rate_multiplier,
rate_multipliers=self.key_data.rate_multipliers, # 按 API 格式的成本倍率
internal_priority=self.key_data.internal_priority,
rpm_limit=self.key_data.rpm_limit,
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
cache_ttl_minutes=self.key_data.cache_ttl_minutes,
max_probe_interval_minutes=self.key_data.max_probe_interval_minutes,
request_count=0,
success_count=0,
error_count=0,
total_response_time_ms=0,
health_by_format={}, # 按格式存储健康度
circuit_breaker_by_format={}, # 按格式存储熔断器状态
is_active=True,
last_used_at=None,
created_at=now,
updated_at=now,
)
db.add(new_key)
db.commit()
db.refresh(new_key)
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}
logger.info(
f"[OK] 添加 Key: Provider={self.provider_id}, "
f"Formats={self.key_data.api_formats}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}"
)
return _build_key_response(new_key, api_key_plain=self.key_data.api_key)
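A hypothetical request body for POST /providers/{provider_id}/keys, limited to the fields the adapter reads above (the authoritative schema is EndpointAPIKeyCreate, which is not shown in this diff):

create_key_payload = {
    "api_key": "sk-...",                   # stored encrypted
    "name": "primary key",
    "api_formats": ["CLAUDE", "OPENAI"],   # required
    "rate_multiplier": 1.0,
    "rate_multipliers": {"CLAUDE": 1.0},   # optional per-format cost multipliers
    "internal_priority": 100,
    "rpm_limit": None,                     # None = adaptive RPM mode
    "allowed_models": None,                # None = no restriction
    "capabilities": None,
}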


@@ -67,8 +67,6 @@ async def list_provider_endpoints(
- `custom_path`: 自定义路径
- `timeout`: 超时时间(秒)
- `max_retries`: 最大重试次数
- `max_concurrent`: 最大并发数
- `rate_limit`: 速率限制
- `is_active`: 是否活跃
- `total_keys`: Key 总数
- `active_keys`: 活跃 Key 数量
@@ -107,8 +105,6 @@ async def create_provider_endpoint(
- `headers`: 自定义请求头(可选)
- `timeout`: 超时时间(秒,默认 300)
- `max_retries`: 最大重试次数(默认 2)
- `max_concurrent`: 最大并发数(可选)
- `rate_limit`: 速率限制(可选)
- `config`: 额外配置(可选)
- `proxy`: 代理配置(可选)
@@ -145,8 +141,6 @@ async def get_endpoint(
- `custom_path`: 自定义路径
- `timeout`: 超时时间(秒)
- `max_retries`: 最大重试次数
- `max_concurrent`: 最大并发数
- `rate_limit`: 速率限制
- `is_active`: 是否活跃
- `total_keys`: Key 总数
- `active_keys`: 活跃 Key 数量
@@ -178,8 +172,6 @@ async def update_endpoint(
- `headers`: 自定义请求头
- `timeout`: 超时时间(秒)
- `max_retries`: 最大重试次数
- `max_concurrent`: 最大并发数
- `rate_limit`: 速率限制
- `is_active`: 是否活跃
- `config`: 额外配置
- `proxy`: 代理配置(设置为 null 可清除代理)
@@ -203,15 +195,15 @@ async def delete_endpoint(
"""
删除 Endpoint
删除指定的 Endpoint,同时级联删除所有关联的 API Keys
此操作不可逆,请谨慎使用
删除指定的 Endpoint,会影响该 Provider 在该 API 格式下的路由能力
Key 不会被删除,但包含该 API 格式的 Key 将无法被调度使用(直到重新创建该格式的 Endpoint)
**路径参数**:
- `endpoint_id`: Endpoint ID
**返回字段**:
- `message`: 操作结果消息
- `deleted_keys_count`: 同时删除的 Key 数量
- `affected_keys_count`: 受影响的 Key 数量(包含该 API 格式)
"""
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -241,39 +233,33 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
.all()
)
endpoint_ids = [ep.id for ep in endpoints]
total_keys_map = {}
active_keys_map = {}
if endpoint_ids:
total_rows = (
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
.group_by(ProviderAPIKey.endpoint_id)
.all()
)
total_keys_map = {row.endpoint_id: row.total for row in total_rows}
active_rows = (
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
.filter(
and_(
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
ProviderAPIKey.is_active.is_(True),
)
)
.group_by(ProviderAPIKey.endpoint_id)
.all()
)
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
# Key 是 Provider 级别资源:按 key.api_formats 归类到各 Endpoint.api_format 下
keys = (
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
.filter(ProviderAPIKey.provider_id == self.provider_id)
.all()
)
total_keys_map: dict[str, int] = {}
active_keys_map: dict[str, int] = {}
for api_formats, is_active in keys:
for fmt in (api_formats or []):
total_keys_map[fmt] = total_keys_map.get(fmt, 0) + 1
if is_active:
active_keys_map[fmt] = active_keys_map.get(fmt, 0) + 1
result: List[ProviderEndpointResponse] = []
for endpoint in endpoints:
endpoint_format = (
endpoint.api_format
if isinstance(endpoint.api_format, str)
else endpoint.api_format.value
)
endpoint_dict = {
**endpoint.__dict__,
"provider_name": provider.name,
"api_format": endpoint.api_format,
"total_keys": total_keys_map.get(endpoint.id, 0),
"active_keys": active_keys_map.get(endpoint.id, 0),
"total_keys": total_keys_map.get(endpoint_format, 0),
"active_keys": active_keys_map.get(endpoint_format, 0),
"proxy": mask_proxy_password(endpoint.proxy),
}
endpoint_dict.pop("_sa_instance_state", None)
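The enum-or-string normalization above recurs in several adapters in this file; a small helper of the same shape (hypothetical, shown only to make the pattern explicit):

def normalize_api_format(api_format):
    # ProviderEndpoint.api_format may arrive as a plain str or as an Enum member,
    # so always compare against the string value.
    if isinstance(api_format, str):
        return api_format
    return api_format.value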
@@ -321,8 +307,6 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
headers=self.endpoint_data.headers,
timeout=self.endpoint_data.timeout,
max_retries=self.endpoint_data.max_retries,
max_concurrent=self.endpoint_data.max_concurrent,
rate_limit=self.endpoint_data.rate_limit,
is_active=True,
config=self.endpoint_data.config,
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
@@ -367,19 +351,23 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
endpoint_obj, provider = endpoint
total_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
endpoint_format = (
endpoint_obj.api_format
if isinstance(endpoint_obj.api_format, str)
else endpoint_obj.api_format.value
)
active_keys = (
db.query(ProviderAPIKey)
.filter(
and_(
ProviderAPIKey.endpoint_id == self.endpoint_id,
ProviderAPIKey.is_active.is_(True),
)
)
.count()
keys = (
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
.filter(ProviderAPIKey.provider_id == endpoint_obj.provider_id)
.all()
)
total_keys = 0
active_keys = 0
for api_formats, is_active in keys:
if endpoint_format in (api_formats or []):
total_keys += 1
if is_active:
active_keys += 1
endpoint_dict = {
k: v
@@ -431,19 +419,21 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
total_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
endpoint_format = (
endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
)
active_keys = (
db.query(ProviderAPIKey)
.filter(
and_(
ProviderAPIKey.endpoint_id == self.endpoint_id,
ProviderAPIKey.is_active.is_(True),
)
)
.count()
keys = (
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
.all()
)
total_keys = 0
active_keys = 0
for api_formats, is_active in keys:
if endpoint_format in (api_formats or []):
total_keys += 1
if is_active:
active_keys += 1
endpoint_dict = {
k: v
@@ -472,12 +462,26 @@ class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
keys_count = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
endpoint_format = (
endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
)
keys = (
db.query(ProviderAPIKey.api_formats)
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
.all()
)
affected_keys_count = sum(
1 for (api_formats,) in keys if endpoint_format in (api_formats or [])
)
db.delete(endpoint)
db.commit()
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
logger.warning(
f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, Format={endpoint_format}, "
f"AffectedKeys={affected_keys_count}"
)
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
return {
"message": f"Endpoint {self.endpoint_id} 已删除",
"affected_keys_count": affected_keys_count,
}


@@ -125,7 +125,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
ModelCatalogProviderDetail(
provider_id=provider.id,
provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id,
target_model=model.provider_model_name,
# 显示有效价格


@@ -452,7 +452,6 @@ class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
ModelCatalogProviderDetail(
provider_id=provider.id,
provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id,
target_model=model.provider_model_name,
input_price_per_1m=model.get_effective_input_price(),


@@ -819,7 +819,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
"username": user.username if user else None,
"email": user.email if user else None,
"provider_id": provider_id,
"provider_name": provider.display_name if provider else None,
"provider_name": provider.name if provider else None,
"endpoint_id": endpoint_id,
"endpoint_api_format": (
endpoint.api_format if endpoint and endpoint.api_format else None
@@ -1369,9 +1369,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
for model, provider in models:
# 检查是否是主模型名称
if model.provider_model_name == mapping_name:
provider_names.append(
provider.display_name or provider.name
)
provider_names.append(provider.name)
continue
# 检查是否在映射列表中
if model.provider_model_mappings:
@@ -1381,9 +1379,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if isinstance(a, dict)
]
if mapping_name in mapping_list:
provider_names.append(
provider.display_name or provider.name
)
provider_names.append(provider.name)
provider_names = sorted(list(set(provider_names)))
mappings.append({
@@ -1473,7 +1469,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
provider_model_mappings.append({
"provider_id": provider_id,
"provider_name": provider.display_name or provider.name,
"provider_name": provider.name,
"global_model_id": global_model_id,
"global_model_name": global_model.name,
"global_model_display_name": global_model.display_name,


@@ -13,10 +13,11 @@ from sqlalchemy.orm import Session, joinedload
from src.api.handlers.base.chat_adapter_base import get_adapter_class
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
from src.config.constants import TimeoutDefaults
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderEndpoint, User
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
@@ -81,10 +82,13 @@ async def query_available_models(
Returns:
所有端点的模型列表(合并)
"""
# 获取提供商及其端点
# 获取提供商及其端点和 API Keys
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.options(
joinedload(Provider.endpoints),
joinedload(Provider.api_keys),
)
.filter(Provider.id == request.provider_id)
.first()
)
@@ -95,49 +99,70 @@ async def query_available_models(
# 收集所有活跃端点的配置
endpoint_configs: list[dict] = []
# 构建 api_format -> endpoint 映射
format_to_endpoint: dict[str, ProviderEndpoint] = {}
for endpoint in provider.endpoints:
if endpoint.is_active:
format_to_endpoint[endpoint.api_format] = endpoint
if request.api_key_id:
# 指定了特定的 API Key,只使用该 Key 对应的端点
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break
if endpoint_configs:
break
# 指定了特定的 API Key(从 provider.api_keys 查找)
api_key = next(
(key for key in provider.api_keys if key.id == request.api_key_id),
None
)
if not api_key:
raise HTTPException(status_code=404, detail="API Key not found")
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# 根据 Key 的 api_formats 找对应的 Endpoint
key_formats = api_key.api_formats or []
for fmt in key_formats:
endpoint = format_to_endpoint.get(fmt)
if endpoint:
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": fmt,
"extra_headers": endpoint.headers,
})
if not endpoint_configs:
raise HTTPException(status_code=404, detail="API Key not found")
raise HTTPException(
status_code=400,
detail="No matching endpoint found for this API Key's formats"
)
else:
# 遍历所有活跃端点,每个端点取第一个可用的 Key
# 遍历所有活跃端点,每个端点找一个支持该格式的 Key
for endpoint in provider.endpoints:
if not endpoint.is_active or not endpoint.api_keys:
if not endpoint.is_active:
continue
# 找第一个可用 Key
for api_key in endpoint.api_keys:
if api_key.is_active:
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
continue # 尝试下一个 Key
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break # 只取第一个可用的 Key
# 找第一个支持该格式的可用 Key
for api_key in provider.api_keys:
if not api_key.is_active:
continue
key_formats = api_key.api_formats or []
if endpoint.api_format not in key_formats:
continue
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
continue
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break # 只取第一个可用的 Key
if not endpoint_configs:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
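
A hedged sketch of the format-based matching used above: build an api_format -> endpoint map from the provider's active endpoints, then pair each endpoint with the first active key whose api_formats declares that format. The dataclasses are illustrative stand-ins for ProviderEndpoint / ProviderAPIKey, not the real models:

from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass
class Endpoint:  # stand-in for ProviderEndpoint
    api_format: str
    base_url: str
    is_active: bool = True

@dataclass
class APIKey:  # stand-in for ProviderAPIKey
    id: str
    api_formats: List[str]
    is_active: bool = True

def pair_endpoints_with_keys(endpoints: List[Endpoint], keys: List[APIKey]) -> List[Tuple[Endpoint, APIKey]]:
    format_to_endpoint: Dict[str, Endpoint] = {e.api_format: e for e in endpoints if e.is_active}
    pairs: List[Tuple[Endpoint, APIKey]] = []
    for fmt, endpoint in format_to_endpoint.items():
        # first active key that declares support for this format
        key = next((k for k in keys if k.is_active and fmt in (k.api_formats or [])), None)
        if key is not None:
            pairs.append((endpoint, key))
    return pairs

endpoints = [Endpoint("CLAUDE", "https://a.example"), Endpoint("OPENAI", "https://b.example")]
keys = [APIKey("k1", ["OPENAI"]), APIKey("k2", ["CLAUDE", "OPENAI"])]
print([(e.api_format, k.id) for e, k in pair_endpoints_with_keys(endpoints, keys)])
# [('CLAUDE', 'k2'), ('OPENAI', 'k1')]: the first matching active key wins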
@@ -214,7 +239,6 @@ async def query_available_models(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@@ -229,17 +253,14 @@ async def test_model(
测试模型连接性
向指定提供商的指定模型发送测试请求,验证模型是否可用
Args:
request: 测试请求
Returns:
测试结果
"""
# 获取提供商及其端点
# 获取提供商及其端点和 Keys
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.options(
joinedload(Provider.endpoints),
joinedload(Provider.api_keys),
)
.filter(Provider.id == request.provider_id)
.first()
)
@@ -247,28 +268,38 @@ async def test_model(
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 找到合适的端点和API Key
endpoint_config = None
# 构建 api_format -> endpoint 映射
format_to_endpoint: dict[str, ProviderEndpoint] = {}
for ep in provider.endpoints:
if ep.is_active:
format_to_endpoint[ep.api_format] = ep
# 找到合适的端点和 API Key
endpoint = None
api_key = None
if request.api_key_id:
# 使用指定的API Key
for ep in provider.endpoints:
for key in ep.api_keys:
if key.id == request.api_key_id and key.is_active and ep.is_active:
endpoint = ep
api_key = key
# 使用指定的 API Key
api_key = next(
(key for key in provider.api_keys if key.id == request.api_key_id and key.is_active),
None
)
if api_key:
# 找到该 Key 支持的第一个活跃 Endpoint
for fmt in (api_key.api_formats or []):
if fmt in format_to_endpoint:
endpoint = format_to_endpoint[fmt]
break
if endpoint:
break
else:
# 使用第一个可用的端点和密钥
for ep in provider.endpoints:
if not ep.is_active or not ep.api_keys:
if not ep.is_active:
continue
for key in ep.api_keys:
if key.is_active:
# 找支持该格式的第一个可用 Key
for key in provider.api_keys:
if not key.is_active:
continue
if ep.api_format in (key.api_formats or []):
endpoint = ep
api_key = key
break
@@ -284,14 +315,14 @@ async def test_model(
logger.error(f"[test-model] Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# 构建请求配置
# 构建请求配置timeout 从 Provider 读取)
endpoint_config = {
"api_key": api_key_value,
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
"timeout": endpoint.timeout or 30.0,
"timeout": provider.timeout or TimeoutDefaults.HTTP_REQUEST,
}
try:
@@ -304,7 +335,6 @@ async def test_model(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
@@ -325,7 +355,6 @@ async def test_model(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
@@ -415,7 +444,6 @@ async def test_model(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
@@ -433,7 +461,6 @@ async def test_model(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {

View File

@@ -78,7 +78,7 @@ async def get_provider_stats(
"""
获取提供商统计数据
获取指定提供商的计费信息、RPM 使用情况和使用统计数据。
获取指定提供商的计费信息和使用统计数据。
**路径参数**:
- `provider_id`: 提供商 ID
@@ -96,10 +96,6 @@ async def get_provider_stats(
- `monthly_used_usd`: 月度已使用
- `quota_remaining_usd`: 剩余配额
- `quota_expires_at`: 配额过期时间
- `rpm_info`: RPM 信息
- `rpm_limit`: RPM 限制
- `rpm_used`: 已使用 RPM
- `rpm_reset_at`: RPM 重置时间
- `usage_stats`: 使用统计
- `total_requests`: 总请求数
- `successful_requests`: 成功请求数
@@ -165,7 +161,6 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
provider.billing_type = config.billing_type
provider.monthly_quota_usd = config.monthly_quota_usd
provider.quota_reset_day = config.quota_reset_day
provider.rpm_limit = config.rpm_limit
provider.provider_priority = config.provider_priority
from dateutil import parser
@@ -262,13 +257,6 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
),
},
"rpm_info": {
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": (
provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
),
},
"usage_stats": {
"total_requests": total_requests,
"successful_requests": total_success,
@@ -296,8 +284,6 @@ class AdminProviderResetQuotaAdapter(AdminApiAdapter):
old_used = provider.monthly_used_usd
provider.monthly_used_usd = 0.0
provider.rpm_used = 0
provider.rpm_reset_at = None
db.commit()
logger.info(f"Manually reset quota for provider {provider.name}")

View File

@@ -338,27 +338,29 @@ async def import_models_from_upstream(
"""
从上游提供商导入模型
从上游提供商导入模型列表。如果全局模型不存在,将自动创建。
从上游提供商导入模型列表。导入的模型作为独立的 ProviderModel 存储,
不会自动创建 GlobalModel。后续需要手动关联 GlobalModel 才能参与路由。
**流程说明**:
1. 根据 model_ids 检查全局模型是否存在(按 name 匹配)
2. 如不存在,自动创建新的 GlobalModel(使用默认免费配置
3. 创建 Model 关联到当前 Provider
4. 如模型已关联,则记录到成功列表中
1. 检查模型是否存在于当前 Provider按 provider_model_name 匹配)
2. 创建新的 ProviderModelglobal_model_id = NULL
3. 支持设置价格覆盖tiered_pricing, price_per_request
**路径参数**:
- `provider_id`: 提供商 ID
**请求体字段**:
- `model_ids`: 模型 ID 数组(必填,每个 ID 长度 1-100 字符)
- `tiered_pricing`: 可选的阶梯计费配置(应用于所有导入的模型)
- `price_per_request`: 可选的按次计费价格(应用于所有导入的模型)
**返回字段**:
- `success`: 成功导入的模型数组,每项包含:
- `model_id`: 模型 ID
- `global_model_id`: 全局模型 ID
- `global_model_name`: 全局模型名称
- `provider_model_id`: 提供商模型 ID
- `created_global_model`: 是否新创建了全局模型
- `global_model_id`: 全局模型 ID如果已关联
- `global_model_name`: 全局模型名称(如果已关联)
- `created_global_model`: 是否新创建了全局模型(始终为 false
- `errors`: 失败的模型数组,每项包含:
- `model_id`: 模型 ID
- `error`: 错误信息
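
Illustrative request/response payloads for the import endpoint documented above, assuming the optional price overrides apply uniformly to every imported model; all IDs and prices below are made-up placeholders:

# Request body
import_request = {
    "model_ids": ["gpt-4o-mini", "gpt-4o"],
    "tiered_pricing": {
        "tiers": [{"up_to": None, "input_price_per_1m": 0.15, "output_price_per_1m": 0.60}]
    },
    "price_per_request": None,
}

# Response body: imported models are independent ProviderModels, so the GlobalModel
# fields stay empty and created_global_model is always False.
import_response = {
    "success": [
        {
            "model_id": "gpt-4o-mini",
            "global_model_id": "",
            "global_model_name": "",
            "provider_model_id": "pm_123",
            "created_global_model": False,
        }
    ],
    "errors": [],
}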
@@ -638,7 +640,7 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
@dataclass
class AdminImportFromUpstreamAdapter(AdminApiAdapter):
"""从上游提供商导入模型"""
"""从上游提供商导入模型(不创建 GlobalModel作为独立 ProviderModel"""
provider_id: str
payload: ImportFromUpstreamRequest
@@ -652,16 +654,13 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
success: list[ImportFromUpstreamSuccessItem] = []
errors: list[ImportFromUpstreamErrorItem] = []
# 默认阶梯计费配置(免费)
default_tiered_pricing = {
"tiers": [
{
"up_to": None,
"input_price_per_1m": 0.0,
"output_price_per_1m": 0.0,
}
]
}
# 获取价格覆盖配置
tiered_pricing = None
price_per_request = None
if hasattr(self.payload, 'tiered_pricing') and self.payload.tiered_pricing:
tiered_pricing = self.payload.tiered_pricing
if hasattr(self.payload, 'price_per_request') and self.payload.price_per_request is not None:
price_per_request = self.payload.price_per_request
for model_id in self.payload.model_ids:
# 输入验证:检查 model_id 长度
@@ -678,56 +677,37 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
# 使用 savepoint 确保单个模型导入的原子性
savepoint = db.begin_nested()
try:
# 1. 检查是否已存在同名的 GlobalModel
global_model = (
db.query(GlobalModel).filter(GlobalModel.name == model_id).first()
)
created_global_model = False
if not global_model:
# 2. 创建新的 GlobalModel
global_model = GlobalModel(
name=model_id,
display_name=model_id,
default_tiered_pricing=default_tiered_pricing,
is_active=True,
)
db.add(global_model)
db.flush()
created_global_model = True
logger.info(
f"Created new GlobalModel: {model_id} during upstream import"
)
# 3. 检查是否已存在关联
# 1. 检查是否已存在同名的 ProviderModel
existing = (
db.query(Model)
.filter(
Model.provider_id == self.provider_id,
Model.global_model_id == global_model.id,
Model.provider_model_name == model_id,
)
.first()
)
if existing:
# 已存在关联,提交 savepoint 并记录成功
# 已存在,提交 savepoint 并记录成功
savepoint.commit()
success.append(
ImportFromUpstreamSuccessItem(
model_id=model_id,
global_model_id=global_model.id,
global_model_name=global_model.name,
global_model_id=existing.global_model_id or "",
global_model_name=existing.global_model.name if existing.global_model else "",
provider_model_id=existing.id,
created_global_model=created_global_model,
created_global_model=False,
)
)
continue
# 4. 创建新的 Model 记录
# 2. 创建新的 Model 记录(不关联 GlobalModel
new_model = Model(
provider_id=self.provider_id,
global_model_id=global_model.id,
provider_model_name=global_model.name,
global_model_id=None, # 独立模型,不关联 GlobalModel
provider_model_name=model_id,
is_active=True,
tiered_pricing=tiered_pricing,
price_per_request=price_per_request,
)
db.add(new_model)
db.flush()
@@ -737,12 +717,15 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
success.append(
ImportFromUpstreamSuccessItem(
model_id=model_id,
global_model_id=global_model.id,
global_model_name=global_model.name,
global_model_id="", # 未关联
global_model_name="", # 未关联
provider_model_id=new_model.id,
created_global_model=created_global_model,
created_global_model=False,
)
)
logger.info(
f"Created independent ProviderModel: {model_id} for provider {provider.name}"
)
except Exception as e:
# 回滚到 savepoint
savepoint.rollback()
@@ -753,11 +736,9 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
db.commit()
logger.info(
f"Imported {len(success)} models from upstream to provider {provider.name} by {context.user.username}"
f"Imported {len(success)} independent models to provider {provider.name} by {context.user.username}"
)
# 清除 /v1/models 列表缓存
if success:
await invalidate_models_list_cache()
# 不需要清除 /v1/models 缓存,因为独立模型不参与路由
return ImportFromUpstreamResponse(success=success, errors=errors)
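
A generic sketch of the per-item savepoint pattern the adapter above relies on: each model import runs inside a nested transaction, so a failure rolls back only that item while earlier successes survive until the final commit. Names here are placeholders; only the SQLAlchemy session calls (begin_nested/commit/rollback) mirror the diff:

from typing import Callable, Iterable, List, Tuple
from sqlalchemy.orm import Session

def import_items(
    db: Session,
    items: Iterable[str],
    create_row: Callable[[Session, str], None],
) -> Tuple[List[str], List[str]]:
    success: List[str] = []
    errors: List[str] = []
    for item in items:
        savepoint = db.begin_nested()  # SAVEPOINT: isolates this item
        try:
            create_row(db, item)       # e.g. db.add(Model(...)); db.flush()
            savepoint.commit()
            success.append(item)
        except Exception as exc:
            savepoint.rollback()       # only this item is undone
            errors.append(f"{item}: {exc}")
    db.commit()                        # persist everything that succeeded
    return success, errors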

View File

@@ -41,8 +41,7 @@ async def list_providers(
**返回字段**:
- `id`: 提供商 ID
- `name`: 提供商名称(唯一标识
- `display_name`: 显示名称
- `name`: 提供商名称(唯一)
- `api_format`: API 格式(如 claude、openai、gemini 等)
- `base_url`: API 基础 URL
- `api_key`: API 密钥(脱敏显示)
@@ -63,8 +62,7 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
创建一个新的 AI 模型提供商配置。
**请求体字段**:
- `name`: 提供商名称(必填,唯一,用于系统标识
- `display_name`: 显示名称(必填)
- `name`: 提供商名称(必填,唯一)
- `description`: 描述信息(可选)
- `website`: 官网地址(可选)
- `billing_type`: 计费类型可选pay_as_you_go/subscription/prepaid默认 pay_as_you_go
@@ -72,16 +70,17 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
- `quota_reset_day`: 配额重置日期1-31可选
- `quota_last_reset_at`: 上次配额重置时间(可选)
- `quota_expires_at`: 配额过期时间(可选)
- `rpm_limit`: 每分钟请求数限制(可选)
- `provider_priority`: 提供商优先级(数字越小优先级越高,默认 100
- `is_active`: 是否启用(默认 true
- `concurrent_limit`: 并发限制(可选)
- `timeout`: 请求超时(秒,可选)
- `max_retries`: 最大重试次数(可选)
- `proxy`: 代理配置(可选)
- `config`: 额外配置信息JSON可选
**返回字段**:
- `id`: 新创建的提供商 ID
- `name`: 提供商名称
- `display_name`: 显示名称
- `message`: 成功提示信息
"""
adapter = AdminCreateProviderAdapter()
@@ -100,7 +99,6 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
**请求体字段**(所有字段可选):
- `name`: 提供商名称
- `display_name`: 显示名称
- `description`: 描述信息
- `website`: 官网地址
- `billing_type`: 计费类型pay_as_you_go/subscription/prepaid
@@ -108,10 +106,12 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
- `quota_reset_day`: 配额重置日期1-31
- `quota_last_reset_at`: 上次配额重置时间
- `quota_expires_at`: 配额过期时间
- `rpm_limit`: 每分钟请求数限制
- `provider_priority`: 提供商优先级
- `is_active`: 是否启用
- `concurrent_limit`: 并发限制
- `timeout`: 请求超时(秒)
- `max_retries`: 最大重试次数
- `proxy`: 代理配置
- `config`: 额外配置信息JSON
**返回字段**:
@@ -165,7 +165,6 @@ class AdminListProvidersAdapter(AdminApiAdapter):
{
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"api_format": api_format.value if api_format else None,
"base_url": base_url,
"api_key": "***" if api_key else None,
@@ -217,7 +216,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
# 创建 Provider 对象
provider = Provider(
name=validated_data.name,
display_name=validated_data.display_name,
description=validated_data.description,
website=validated_data.website,
billing_type=billing_type,
@@ -225,10 +223,12 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
quota_reset_day=validated_data.quota_reset_day,
quota_last_reset_at=validated_data.quota_last_reset_at,
quota_expires_at=validated_data.quota_expires_at,
rpm_limit=validated_data.rpm_limit,
provider_priority=validated_data.provider_priority,
is_active=validated_data.is_active,
concurrent_limit=validated_data.concurrent_limit,
timeout=validated_data.timeout,
max_retries=validated_data.max_retries,
proxy=validated_data.proxy.model_dump() if validated_data.proxy else None,
config=validated_data.config,
)
@@ -248,7 +248,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
return {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"message": "提供商创建成功",
}
except InvalidRequestException:
@@ -291,6 +290,9 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
if field == "billing_type" and value is not None:
# billing_type 需要转换为枚举
setattr(provider, field, ProviderBillingType(value))
elif field == "proxy" and value is not None:
# proxy 需要转换为 dict如果是 Pydantic 模型)
setattr(provider, field, value if isinstance(value, dict) else value.model_dump())
else:
setattr(provider, field, value)
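
A condensed, runnable sketch of the update loop above, showing the two coercions it performs: billing_type is converted to its enum, and proxy accepts either a plain dict or a Pydantic model exposing model_dump(). The enum members mirror the billing types listed in the docstrings (pay_as_you_go/subscription/prepaid); everything else is a placeholder:

from enum import Enum
from types import SimpleNamespace
from typing import Any, Dict

class ProviderBillingType(str, Enum):  # assumed enum, mirroring the documented values
    PAY_AS_YOU_GO = "pay_as_you_go"
    SUBSCRIPTION = "subscription"
    PREPAID = "prepaid"

def apply_updates(provider: Any, updates: Dict[str, Any]) -> None:
    for field, value in updates.items():
        if field == "billing_type" and value is not None:
            setattr(provider, field, ProviderBillingType(value))
        elif field == "proxy" and value is not None:
            # accept either a plain dict or a Pydantic model
            setattr(provider, field, value if isinstance(value, dict) else value.model_dump())
        else:
            setattr(provider, field, value)

provider = SimpleNamespace(billing_type=None, proxy=None, timeout=None)
apply_updates(provider, {"billing_type": "prepaid", "proxy": {"http": "http://127.0.0.1:7890"}, "timeout": 60})
print(provider.billing_type, provider.proxy, provider.timeout)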

View File

@@ -48,7 +48,6 @@ async def get_providers_summary(
**返回字段**(数组,每项包含):
- `id`: 提供商 ID
- `name`: 提供商名称
- `display_name`: 显示名称
- `description`: 描述信息
- `website`: 官网地址
- `provider_priority`: 优先级
@@ -59,9 +58,9 @@ async def get_providers_summary(
- `quota_reset_day`: 配额重置日期
- `quota_last_reset_at`: 上次配额重置时间
- `quota_expires_at`: 配额过期时间
- `rpm_limit`: RPM 限制
- `rpm_used`: 已使用 RPM
- `rpm_reset_at`: RPM 重置时间
- `timeout`: 默认请求超时(秒)
- `max_retries`: 默认最大重试次数
- `proxy`: 默认代理配置
- `total_endpoints`: 端点总数
- `active_endpoints`: 活跃端点数
- `total_keys`: 密钥总数
@@ -96,7 +95,6 @@ async def get_provider_summary(
**返回字段**:
- `id`: 提供商 ID
- `name`: 提供商名称
- `display_name`: 显示名称
- `description`: 描述信息
- `website`: 官网地址
- `provider_priority`: 优先级
@@ -107,9 +105,9 @@ async def get_provider_summary(
- `quota_reset_day`: 配额重置日期
- `quota_last_reset_at`: 上次配额重置时间
- `quota_expires_at`: 配额过期时间
- `rpm_limit`: RPM 限制
- `rpm_used`: 已使用 RPM
- `rpm_reset_at`: RPM 重置时间
- `timeout`: 默认请求超时(秒)
- `max_retries`: 默认最大重试次数
- `proxy`: 默认代理配置
- `total_endpoints`: 端点总数
- `active_endpoints`: 活跃端点数
- `total_keys`: 密钥总数
@@ -185,13 +183,13 @@ async def update_provider_settings(
"""
更新提供商基础配置
更新提供商的基础配置信息,如显示名称、描述、优先级等。只需传入需要更新的字段。
更新提供商的基础配置信息,如名称、描述、优先级等。只需传入需要更新的字段。
**路径参数**:
- `provider_id`: 提供商 ID
**请求体字段**(所有字段可选):
- `display_name`: 显示名称
- `name`: 提供商名称
- `description`: 描述信息
- `website`: 官网地址
- `provider_priority`: 优先级
@@ -199,9 +197,10 @@ async def update_provider_settings(
- `billing_type`: 计费类型
- `monthly_quota_usd`: 月度配额(美元)
- `quota_reset_day`: 配额重置日期
- `quota_last_reset_at`: 上次配额重置时间
- `quota_expires_at`: 配额过期时间
- `rpm_limit`: RPM 限制
- `timeout`: 默认请求超时(秒)
- `max_retries`: 默认最大重试次数
- `proxy`: 默认代理配置
**返回字段**: 返回更新后的提供商摘要信息(与 GET /summary 接口返回格式相同)
"""
@@ -215,18 +214,18 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
total_endpoints = len(endpoints)
active_endpoints = sum(1 for e in endpoints if e.is_active)
endpoint_ids = [e.id for e in endpoints]
# Key 统计(合并为单个查询)
total_keys = 0
active_keys = 0
if endpoint_ids:
key_stats = db.query(
key_stats = (
db.query(
func.count(ProviderAPIKey.id).label("total"),
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).first()
total_keys = key_stats.total or 0
active_keys = int(key_stats.active or 0)
)
.filter(ProviderAPIKey.provider_id == provider.id)
.first()
)
total_keys = key_stats.total or 0
active_keys = int(key_stats.active or 0)
# Model 统计(合并为单个查询)
model_stats = db.query(
@@ -238,25 +237,34 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
api_formats = [e.api_format for e in endpoints]
# 优化: 一次性加载所有 endpoint 的 keys避免 N+1 查询
all_keys = []
if endpoint_ids:
all_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
)
# 优化: 一次性加载 Provider 的 keys避免 N+1 查询
all_keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.provider_id == provider.id).all()
# 按 endpoint_id 分组 keys
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {}
# 按 api_formats 分组 keys通过 api_formats 关联)
format_to_endpoint_id: dict[str, str] = {e.api_format: e.id for e in endpoints}
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {e.id: [] for e in endpoints}
for key in all_keys:
if key.endpoint_id not in keys_by_endpoint:
keys_by_endpoint[key.endpoint_id] = []
keys_by_endpoint[key.endpoint_id].append(key)
formats = key.api_formats or []
for fmt in formats:
endpoint_id = format_to_endpoint_id.get(fmt)
if endpoint_id:
keys_by_endpoint[endpoint_id].append(key)
endpoint_health_map: dict[str, float] = {}
for endpoint in endpoints:
keys = keys_by_endpoint.get(endpoint.id, [])
if keys:
health_scores = [k.health_score for k in keys if k.health_score is not None]
# 从 health_by_format 获取对应格式的健康度
api_fmt = endpoint.api_format
health_scores = []
for k in keys:
health_by_format = k.health_by_format or {}
if api_fmt in health_by_format:
score = health_by_format[api_fmt].get("health_score")
if score is not None:
health_scores.append(float(score))
else:
health_scores.append(1.0) # 默认健康度
avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0
endpoint_health_map[endpoint.id] = avg_health
else:
@@ -284,7 +292,6 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
return ProviderWithEndpointsSummary(
id=provider.id,
name=provider.name,
display_name=provider.display_name,
description=provider.description,
website=provider.website,
provider_priority=provider.provider_priority,
@@ -295,9 +302,9 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
quota_reset_day=provider.quota_reset_day,
quota_last_reset_at=provider.quota_last_reset_at,
quota_expires_at=provider.quota_expires_at,
rpm_limit=provider.rpm_limit,
rpm_used=provider.rpm_used,
rpm_reset_at=provider.rpm_reset_at,
timeout=provider.timeout,
max_retries=provider.max_retries,
proxy=provider.proxy,
total_endpoints=total_endpoints,
active_endpoints=active_endpoints,
total_keys=total_keys,
@@ -341,7 +348,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
if not endpoint_ids:
response = ProviderEndpointHealthMonitorResponse(
provider_id=provider.id,
provider_name=provider.display_name or provider.name,
provider_name=provider.name,
generated_at=now,
endpoints=[],
)
@@ -416,7 +423,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
response = ProviderEndpointHealthMonitorResponse(
provider_id=provider.id,
provider_name=provider.display_name or provider.name,
provider_name=provider.name,
generated_at=now,
endpoints=endpoint_monitors,
)
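
A self-contained sketch of the per-format health averaging in _build_provider_summary above, assuming health_by_format is a JSON object keyed by API format whose entries may carry a health_score, and that a key with no entry for the format counts as fully healthy (1.0):

from typing import Dict, List, Optional

def endpoint_health(api_format: str, keys_health: List[Optional[Dict[str, dict]]]) -> float:
    scores: List[float] = []
    for health_by_format in keys_health:
        hbf = health_by_format or {}
        if api_format in hbf:
            score = hbf[api_format].get("health_score")
            if score is not None:
                scores.append(float(score))
        else:
            scores.append(1.0)  # no record for this format yet: assume healthy
    return sum(scores) / len(scores) if scores else 1.0

# One degraded key and one key without a CLAUDE record average to 0.7.
print(endpoint_health("CLAUDE", [{"CLAUDE": {"health_score": 0.4}}, None]))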

View File

@@ -730,36 +730,6 @@ class AdminExportConfigAdapter(AdminApiAdapter):
)
endpoints_data = []
for ep in endpoints:
# 导出 Endpoint Keys
keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
)
keys_data = []
for key in keys:
# 解密 API Key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception:
decrypted_key = ""
keys_data.append(
{
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"rate_multiplier": key.rate_multiplier,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"max_concurrent": key.max_concurrent,
"rate_limit": key.rate_limit,
"daily_limit": key.daily_limit,
"monthly_limit": key.monthly_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"is_active": key.is_active,
}
)
endpoints_data.append(
{
"api_format": ep.api_format,
@@ -767,12 +737,44 @@ class AdminExportConfigAdapter(AdminApiAdapter):
"headers": ep.headers,
"timeout": ep.timeout,
"max_retries": ep.max_retries,
"max_concurrent": ep.max_concurrent,
"rate_limit": ep.rate_limit,
"is_active": ep.is_active,
"custom_path": ep.custom_path,
"config": ep.config,
"keys": keys_data,
"proxy": ep.proxy,
}
)
# 导出 Provider Keys按 provider_id 归属,包含 api_formats
keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.provider_id == provider.id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.all()
)
keys_data = []
for key in keys:
# 解密 API Key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception:
decrypted_key = ""
keys_data.append(
{
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"api_formats": key.api_formats or [],
"rate_multiplier": key.rate_multiplier,
"rate_multipliers": key.rate_multipliers,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"rpm_limit": key.rpm_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"cache_ttl_minutes": key.cache_ttl_minutes,
"max_probe_interval_minutes": key.max_probe_interval_minutes,
"is_active": key.is_active,
}
)
@@ -804,24 +806,26 @@ class AdminExportConfigAdapter(AdminApiAdapter):
providers_data.append(
{
"name": provider.name,
"display_name": provider.display_name,
"description": provider.description,
"website": provider.website,
"billing_type": provider.billing_type.value if provider.billing_type else None,
"monthly_quota_usd": provider.monthly_quota_usd,
"quota_reset_day": provider.quota_reset_day,
"rpm_limit": provider.rpm_limit,
"provider_priority": provider.provider_priority,
"is_active": provider.is_active,
"concurrent_limit": provider.concurrent_limit,
"timeout": provider.timeout,
"max_retries": provider.max_retries,
"proxy": provider.proxy,
"config": provider.config,
"endpoints": endpoints_data,
"api_keys": keys_data,
"models": models_data,
}
)
return {
"version": "1.0",
"version": "2.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"global_models": global_models_data,
"providers": providers_data,
@@ -850,7 +854,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
# 验证配置版本
version = payload.get("version")
if version != "1.0":
if version != "2.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
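
A hedged illustration of the version-2.0 export shape implied by the export code above: API keys now live at the provider level and carry api_formats, endpoints no longer embed keys, and provider-level timeout/max_retries/proxy are included. Only fields visible in the diff are shown; all concrete values are placeholders:

exported_config_v2 = {
    "version": "2.0",
    "exported_at": "2026-01-10T10:00:00+00:00",
    "global_models": [],
    "providers": [
        {
            "name": "example-provider",
            "billing_type": "pay_as_you_go",
            "provider_priority": 100,
            "timeout": 300,
            "max_retries": 2,
            "proxy": None,
            "endpoints": [
                {"api_format": "CLAUDE", "timeout": 300, "max_retries": 2, "is_active": True}
            ],
            "api_keys": [
                {
                    "api_key": "sk-xxxx",
                    "name": "primary",
                    "api_formats": ["CLAUDE"],
                    "internal_priority": 50,
                    "rpm_limit": None,
                    "is_active": True,
                }
            ],
            "models": [],
        }
    ],
}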
# 获取导入选项
@@ -939,8 +943,8 @@ class AdminImportConfigAdapter(AdminApiAdapter):
)
elif merge_mode == "overwrite":
# 更新现有记录
existing_provider.display_name = prov_data.get(
"display_name", existing_provider.display_name
existing_provider.name = prov_data.get(
"name", existing_provider.name
)
existing_provider.description = prov_data.get("description")
existing_provider.website = prov_data.get("website")
@@ -954,7 +958,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
existing_provider.quota_reset_day = prov_data.get(
"quota_reset_day", 30
)
existing_provider.rpm_limit = prov_data.get("rpm_limit")
existing_provider.provider_priority = prov_data.get(
"provider_priority", 100
)
@@ -962,6 +965,11 @@ class AdminImportConfigAdapter(AdminApiAdapter):
existing_provider.concurrent_limit = prov_data.get(
"concurrent_limit"
)
existing_provider.timeout = prov_data.get("timeout", existing_provider.timeout)
existing_provider.max_retries = prov_data.get(
"max_retries", existing_provider.max_retries
)
existing_provider.proxy = prov_data.get("proxy", existing_provider.proxy)
existing_provider.config = prov_data.get("config")
existing_provider.updated_at = datetime.now(timezone.utc)
stats["providers"]["updated"] += 1
@@ -974,16 +982,17 @@ class AdminImportConfigAdapter(AdminApiAdapter):
new_provider = Provider(
id=str(uuid.uuid4()),
name=prov_data["name"],
display_name=prov_data.get("display_name", prov_data["name"]),
description=prov_data.get("description"),
website=prov_data.get("website"),
billing_type=billing_type,
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
quota_reset_day=prov_data.get("quota_reset_day", 30),
rpm_limit=prov_data.get("rpm_limit"),
provider_priority=prov_data.get("provider_priority", 100),
is_active=prov_data.get("is_active", True),
concurrent_limit=prov_data.get("concurrent_limit"),
timeout=prov_data.get("timeout"),
max_retries=prov_data.get("max_retries"),
proxy=prov_data.get("proxy"),
config=prov_data.get("config"),
)
db.add(new_provider)
@@ -1003,7 +1012,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
)
if existing_ep:
endpoint_id = existing_ep.id
if merge_mode == "skip":
stats["endpoints"]["skipped"] += 1
elif merge_mode == "error":
@@ -1017,11 +1025,10 @@ class AdminImportConfigAdapter(AdminApiAdapter):
existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 2)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True)
existing_ep.custom_path = ep_data.get("custom_path")
existing_ep.config = ep_data.get("config")
existing_ep.proxy = ep_data.get("proxy")
existing_ep.updated_at = datetime.now(timezone.utc)
stats["endpoints"]["updated"] += 1
else:
@@ -1033,68 +1040,106 @@ class AdminImportConfigAdapter(AdminApiAdapter):
headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 2),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True),
custom_path=ep_data.get("custom_path"),
config=ep_data.get("config"),
proxy=ep_data.get("proxy"),
)
db.add(new_ep)
db.flush()
endpoint_id = new_ep.id
stats["endpoints"]["created"] += 1
# 导入 Keys
# 获取当前 endpoint 下所有已有的 keys用于去重
existing_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
.all()
)
# 解密已有 keys 用于比对
existing_key_values = set()
for ek in existing_keys:
try:
decrypted = crypto_service.decrypt(ek.api_key)
existing_key_values.add(decrypted)
except Exception:
pass
# 导入 Provider Keys按 provider_id 归属)
endpoint_format_rows = (
db.query(ProviderEndpoint.api_format)
.filter(ProviderEndpoint.provider_id == provider_id)
.all()
)
endpoint_formats: set[str] = set()
for (api_format,) in endpoint_format_rows:
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
endpoint_formats.add(fmt.strip().upper())
existing_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.provider_id == provider_id)
.all()
)
existing_key_values = set()
for ek in existing_keys:
try:
decrypted = crypto_service.decrypt(ek.api_key)
existing_key_values.add(decrypted)
except Exception:
pass
for key_data in ep_data.get("keys", []):
if not key_data.get("api_key"):
stats["errors"].append(
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
)
continue
# 检查是否已存在相同的 Key通过明文比对
if key_data["api_key"] in existing_key_values:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(key_data["api_key"])
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=endpoint_id,
api_key=encrypted_key,
name=key_data.get("name"),
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
internal_priority=key_data.get("internal_priority", 100),
global_priority=key_data.get("global_priority"),
max_concurrent=key_data.get("max_concurrent"),
rate_limit=key_data.get("rate_limit"),
daily_limit=key_data.get("daily_limit"),
monthly_limit=key_data.get("monthly_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
is_active=key_data.get("is_active", True),
for key_data in prov_data.get("api_keys", []):
if not key_data.get("api_key"):
stats["errors"].append(
f"跳过空 API Key (Provider: {prov_data['name']})"
)
db.add(new_key)
# 添加到已有集合,防止同一批导入中重复
existing_key_values.add(key_data["api_key"])
stats["keys"]["created"] += 1
continue
plaintext_key = key_data["api_key"]
if plaintext_key in existing_key_values:
stats["keys"]["skipped"] += 1
continue
raw_formats = key_data.get("api_formats") or []
if not isinstance(raw_formats, list) or len(raw_formats) == 0:
stats["errors"].append(
f"跳过无 api_formats 的 Key (Provider: {prov_data['name']})"
)
continue
normalized_formats: list[str] = []
seen: set[str] = set()
missing_formats: list[str] = []
for fmt in raw_formats:
if not isinstance(fmt, str):
continue
fmt_upper = fmt.strip().upper()
if not fmt_upper or fmt_upper in seen:
continue
seen.add(fmt_upper)
if endpoint_formats and fmt_upper not in endpoint_formats:
missing_formats.append(fmt_upper)
continue
normalized_formats.append(fmt_upper)
if missing_formats:
stats["errors"].append(
f"Key (Provider: {prov_data['name']}) 的 api_formats 未配置对应 Endpoint已跳过: {missing_formats}"
)
if len(normalized_formats) == 0:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(plaintext_key)
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
provider_id=provider_id,
api_formats=normalized_formats,
api_key=encrypted_key,
name=key_data.get("name") or "Imported Key",
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
rate_multipliers=key_data.get("rate_multipliers"),
internal_priority=key_data.get("internal_priority", 50),
global_priority=key_data.get("global_priority"),
rpm_limit=key_data.get("rpm_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
cache_ttl_minutes=key_data.get("cache_ttl_minutes", 5),
max_probe_interval_minutes=key_data.get("max_probe_interval_minutes", 32),
is_active=key_data.get("is_active", True),
health_by_format={},
circuit_breaker_by_format={},
)
db.add(new_key)
existing_key_values.add(plaintext_key)
stats["keys"]["created"] += 1
# 导入 Models
for model_data in prov_data.get("models", []):

View File

@@ -247,7 +247,8 @@ async def get_usage_detail(
- `request_headers`: 请求头
- `request_body`: 请求体
- `provider_request_headers`: 提供商请求头
- `response_headers`: 响应头
- `response_headers`: 提供商响应头
- `client_response_headers`: 返回给客户端的响应头
- `response_body`: 响应体
- `metadata`: 提供商响应元数据
- `tiered_pricing`: 阶梯计费信息(如适用)
@@ -916,6 +917,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
"request_body": usage_record.get_request_body(),
"provider_request_headers": usage_record.provider_request_headers,
"response_headers": usage_record.response_headers,
"client_response_headers": usage_record.client_response_headers,
"response_body": usage_record.get_response_body(),
"metadata": usage_record.request_metadata,
"tiered_pricing": tiered_pricing_info,