mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-13 13:07:22 +08:00
refactor: 重构限流系统和健康监控,支持按 API 格式区分
- 将 adaptive_concurrency 重命名为 adaptive_rpm,从并发控制改为 RPM 控制 - 健康监控器支持按 API 格式独立管理健康度和熔断器状态 - 新增 model_permissions 模块,支持按格式配置允许的模型 - 重构前端提供商相关表单组件,新增 Collapsible UI 组件 - 新增数据库迁移脚本支持新的数据结构
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
"""
|
||||
自适应并发管理 API 端点
|
||||
自适应 RPM 管理 API 端点
|
||||
|
||||
设计原则:
|
||||
- 自适应模式由 max_concurrent 字段决定:
|
||||
- max_concurrent = NULL:启用自适应模式,系统自动学习并调整并发限制
|
||||
- max_concurrent = 数字:固定限制模式,使用用户指定的并发限制
|
||||
- learned_max_concurrent:自适应模式下学习到的并发限制值
|
||||
- adaptive_mode 是计算字段,基于 max_concurrent 是否为 NULL
|
||||
- 自适应模式由 rpm_limit 字段决定:
|
||||
- rpm_limit = NULL:启用自适应模式,系统自动学习并调整 RPM 限制
|
||||
- rpm_limit = 数字:固定限制模式,使用用户指定的 RPM 限制
|
||||
- learned_rpm_limit:自适应模式下学习到的 RPM 限制值
|
||||
- adaptive_mode 是计算字段,基于 rpm_limit 是否为 NULL
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
@@ -18,12 +18,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.config.constants import RPMDefaults
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.database import get_db
|
||||
from src.models.database import ProviderAPIKey
|
||||
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
|
||||
from src.services.rate_limit.adaptive_rpm import get_adaptive_rpm_manager
|
||||
|
||||
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"])
|
||||
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive RPM"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@@ -35,19 +36,19 @@ class EnableAdaptiveRequest(BaseModel):
|
||||
|
||||
enabled: bool = Field(..., description="是否启用自适应模式(true=自适应,false=固定限制)")
|
||||
fixed_limit: Optional[int] = Field(
|
||||
None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)"
|
||||
None, ge=1, le=100, description="固定 RPM 限制(仅当 enabled=false 时生效,1-100)"
|
||||
)
|
||||
|
||||
|
||||
class AdaptiveStatsResponse(BaseModel):
|
||||
"""自适应统计响应"""
|
||||
|
||||
adaptive_mode: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
|
||||
max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制(NULL=自适应)")
|
||||
adaptive_mode: bool = Field(..., description="是否为自适应模式(rpm_limit=NULL)")
|
||||
rpm_limit: Optional[int] = Field(None, description="用户配置的固定限制(NULL=自适应)")
|
||||
effective_limit: Optional[int] = Field(
|
||||
None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
|
||||
)
|
||||
learned_limit: Optional[int] = Field(None, description="学习到的并发限制")
|
||||
learned_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
|
||||
concurrent_429_count: int
|
||||
rpm_429_count: int
|
||||
last_429_at: Optional[str]
|
||||
@@ -61,11 +62,12 @@ class KeyListItem(BaseModel):
|
||||
|
||||
id: str
|
||||
name: Optional[str]
|
||||
endpoint_id: str
|
||||
is_adaptive: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
|
||||
max_concurrent: Optional[int] = Field(None, description="固定并发限制(NULL=自适应)")
|
||||
provider_id: str
|
||||
api_formats: List[str] = Field(default_factory=list)
|
||||
is_adaptive: bool = Field(..., description="是否为自适应模式(rpm_limit=NULL)")
|
||||
rpm_limit: Optional[int] = Field(None, description="固定 RPM 限制(NULL=自适应)")
|
||||
effective_limit: Optional[int] = Field(None, description="当前有效限制")
|
||||
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
|
||||
learned_rpm_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
|
||||
concurrent_429_count: int
|
||||
rpm_429_count: int
|
||||
|
||||
@@ -80,22 +82,22 @@ class KeyListItem(BaseModel):
|
||||
)
|
||||
async def list_adaptive_keys(
|
||||
request: Request,
|
||||
endpoint_id: Optional[str] = Query(None, description="按 Endpoint 过滤"),
|
||||
provider_id: Optional[str] = Query(None, description="按 Provider 过滤"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取所有启用自适应模式的Key列表
|
||||
|
||||
可选参数:
|
||||
- endpoint_id: 按 Endpoint 过滤
|
||||
- provider_id: 按 Provider 过滤
|
||||
"""
|
||||
adapter = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id)
|
||||
adapter = ListAdaptiveKeysAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/keys/{key_id}/mode",
|
||||
summary="Toggle key's concurrency control mode",
|
||||
summary="Toggle key's RPM control mode",
|
||||
)
|
||||
async def toggle_adaptive_mode(
|
||||
key_id: str,
|
||||
@@ -103,10 +105,10 @@ async def toggle_adaptive_mode(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Toggle the concurrency control mode for a specific key
|
||||
Toggle the RPM control mode for a specific key
|
||||
|
||||
Parameters:
|
||||
- enabled: true=adaptive mode (max_concurrent=NULL), false=fixed limit mode
|
||||
- enabled: true=adaptive mode (rpm_limit=NULL), false=fixed limit mode
|
||||
- fixed_limit: fixed limit value (required when enabled=false)
|
||||
"""
|
||||
adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
|
||||
@@ -124,7 +126,7 @@ async def get_adaptive_stats(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定Key的自适应并发统计信息
|
||||
获取指定Key的自适应 RPM 统计信息
|
||||
|
||||
包括:
|
||||
- 当前配置
|
||||
@@ -149,12 +151,12 @@ async def reset_adaptive_learning(
|
||||
Reset the adaptive learning state for a specific key
|
||||
|
||||
Clears:
|
||||
- Learned concurrency limit (learned_max_concurrent)
|
||||
- Learned RPM limit (learned_rpm_limit)
|
||||
- 429 error counts
|
||||
- Adjustment history
|
||||
|
||||
Does not change:
|
||||
- max_concurrent config (determines adaptive mode)
|
||||
- rpm_limit config (determines adaptive mode)
|
||||
"""
|
||||
adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -162,40 +164,40 @@ async def reset_adaptive_learning(
|
||||
|
||||
@router.patch(
|
||||
"/keys/{key_id}/limit",
|
||||
summary="Set key to fixed concurrency limit mode",
|
||||
summary="Set key to fixed RPM limit mode",
|
||||
)
|
||||
async def set_concurrent_limit(
|
||||
async def set_rpm_limit(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"),
|
||||
limit: int = Query(..., ge=1, le=100, description="RPM limit value (1-100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Set key to fixed concurrency limit mode
|
||||
Set key to fixed RPM limit mode
|
||||
|
||||
Note:
|
||||
- After setting this value, key switches to fixed limit mode and won't auto-adjust
|
||||
- To restore adaptive mode, use PATCH /keys/{key_id}/mode
|
||||
"""
|
||||
adapter = SetConcurrentLimitAdapter(key_id=key_id, limit=limit)
|
||||
adapter = SetRPMLimitAdapter(key_id=key_id, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/summary",
|
||||
summary="获取自适应并发的全局统计",
|
||||
summary="获取自适应 RPM 的全局统计",
|
||||
)
|
||||
async def get_adaptive_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取自适应并发的全局统计摘要
|
||||
获取自适应 RPM 的全局统计摘要
|
||||
|
||||
包括:
|
||||
- 启用自适应模式的Key数量
|
||||
- 总429错误数
|
||||
- 并发限制调整次数
|
||||
- RPM 限制调整次数
|
||||
"""
|
||||
adapter = AdaptiveSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -206,26 +208,29 @@ async def get_adaptive_summary(
|
||||
|
||||
@dataclass
|
||||
class ListAdaptiveKeysAdapter(AdminApiAdapter):
|
||||
endpoint_id: Optional[str] = None
|
||||
provider_id: Optional[str] = None
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
# 自适应模式:max_concurrent = NULL
|
||||
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None))
|
||||
if self.endpoint_id:
|
||||
query = query.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
|
||||
# 自适应模式:rpm_limit = NULL
|
||||
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None))
|
||||
if self.provider_id:
|
||||
query = query.filter(ProviderAPIKey.provider_id == self.provider_id)
|
||||
|
||||
keys = query.all()
|
||||
return [
|
||||
KeyListItem(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
endpoint_id=key.endpoint_id,
|
||||
is_adaptive=key.max_concurrent is None,
|
||||
max_concurrent=key.max_concurrent,
|
||||
provider_id=key.provider_id,
|
||||
api_formats=key.api_formats or [],
|
||||
is_adaptive=key.rpm_limit is None,
|
||||
rpm_limit=key.rpm_limit,
|
||||
effective_limit=(
|
||||
key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
|
||||
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
|
||||
if key.rpm_limit is None
|
||||
else key.rpm_limit
|
||||
),
|
||||
learned_max_concurrent=key.learned_max_concurrent,
|
||||
learned_rpm_limit=key.learned_rpm_limit,
|
||||
concurrent_429_count=key.concurrent_429_count or 0,
|
||||
rpm_429_count=key.rpm_429_count or 0,
|
||||
)
|
||||
@@ -252,28 +257,32 @@ class ToggleAdaptiveModeAdapter(AdminApiAdapter):
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
if body.enabled:
|
||||
# 启用自适应模式:将 max_concurrent 设为 NULL
|
||||
key.max_concurrent = None
|
||||
message = "已切换为自适应模式,系统将自动学习并调整并发限制"
|
||||
# 启用自适应模式:将 rpm_limit 设为 NULL
|
||||
key.rpm_limit = None
|
||||
message = "已切换为自适应模式,系统将自动学习并调整 RPM 限制"
|
||||
else:
|
||||
# 禁用自适应模式:设置固定限制
|
||||
if body.fixed_limit is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
|
||||
)
|
||||
key.max_concurrent = body.fixed_limit
|
||||
message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}"
|
||||
key.rpm_limit = body.fixed_limit
|
||||
message = f"已切换为固定限制模式,RPM 限制设为 {body.fixed_limit}"
|
||||
|
||||
context.db.commit()
|
||||
context.db.refresh(key)
|
||||
|
||||
is_adaptive = key.max_concurrent is None
|
||||
is_adaptive = key.rpm_limit is None
|
||||
return {
|
||||
"message": message,
|
||||
"key_id": key.id,
|
||||
"is_adaptive": is_adaptive,
|
||||
"max_concurrent": key.max_concurrent,
|
||||
"effective_limit": key.learned_max_concurrent if is_adaptive else key.max_concurrent,
|
||||
"rpm_limit": key.rpm_limit,
|
||||
"effective_limit": (
|
||||
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
|
||||
if is_adaptive
|
||||
else key.rpm_limit
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -286,13 +295,13 @@ class GetAdaptiveStatsAdapter(AdminApiAdapter):
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
adaptive_manager = get_adaptive_manager()
|
||||
adaptive_manager = get_adaptive_rpm_manager()
|
||||
stats = adaptive_manager.get_adjustment_stats(key)
|
||||
|
||||
# 转换字段名以匹配响应模型
|
||||
return AdaptiveStatsResponse(
|
||||
adaptive_mode=stats["adaptive_mode"],
|
||||
max_concurrent=stats["max_concurrent"],
|
||||
rpm_limit=stats["rpm_limit"],
|
||||
effective_limit=stats["effective_limit"],
|
||||
learned_limit=stats["learned_limit"],
|
||||
concurrent_429_count=stats["concurrent_429_count"],
|
||||
@@ -313,13 +322,13 @@ class ResetAdaptiveLearningAdapter(AdminApiAdapter):
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
adaptive_manager = get_adaptive_manager()
|
||||
adaptive_manager = get_adaptive_rpm_manager()
|
||||
adaptive_manager.reset_learning(context.db, key)
|
||||
return {"message": "学习状态已重置", "key_id": key.id}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetConcurrentLimitAdapter(AdminApiAdapter):
|
||||
class SetRPMLimitAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
limit: int
|
||||
|
||||
@@ -328,25 +337,25 @@ class SetConcurrentLimitAdapter(AdminApiAdapter):
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
was_adaptive = key.max_concurrent is None
|
||||
key.max_concurrent = self.limit
|
||||
was_adaptive = key.rpm_limit is None
|
||||
key.rpm_limit = self.limit
|
||||
context.db.commit()
|
||||
context.db.refresh(key)
|
||||
|
||||
return {
|
||||
"message": f"已设置为固定限制模式,并发限制为 {self.limit}",
|
||||
"message": f"已设置为固定限制模式,RPM 限制为 {self.limit}",
|
||||
"key_id": key.id,
|
||||
"is_adaptive": False,
|
||||
"max_concurrent": key.max_concurrent,
|
||||
"rpm_limit": key.rpm_limit,
|
||||
"previous_mode": "adaptive" if was_adaptive else "fixed",
|
||||
}
|
||||
|
||||
|
||||
class AdaptiveSummaryAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
# 自适应模式:max_concurrent = NULL
|
||||
# 自适应模式:rpm_limit = NULL
|
||||
adaptive_keys = (
|
||||
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None)).all()
|
||||
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None)).all()
|
||||
)
|
||||
|
||||
total_keys = len(adaptive_keys)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""
|
||||
Endpoint 并发控制管理 API
|
||||
Key RPM 限制管理 API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -12,83 +11,56 @@ from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
ConcurrencyStatusResponse,
|
||||
ResetConcurrencyRequest,
|
||||
)
|
||||
from src.models.database import ProviderAPIKey
|
||||
from src.models.endpoint_models import KeyRpmStatusResponse
|
||||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||||
|
||||
router = APIRouter(tags=["Concurrency Control"])
|
||||
router = APIRouter(tags=["RPM Control"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
|
||||
async def get_endpoint_concurrency(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""
|
||||
获取 Endpoint 当前并发状态
|
||||
|
||||
查询指定 Endpoint 的实时并发使用情况,包括当前并发数和最大并发限制。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**返回字段**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
- `endpoint_current_concurrency`: 当前并发数
|
||||
- `endpoint_max_concurrent`: 最大并发限制
|
||||
"""
|
||||
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
|
||||
async def get_key_concurrency(
|
||||
@router.get("/rpm/key/{key_id}", response_model=KeyRpmStatusResponse)
|
||||
async def get_key_rpm(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
) -> KeyRpmStatusResponse:
|
||||
"""
|
||||
获取 Key 当前并发状态
|
||||
获取 Key 当前 RPM 状态
|
||||
|
||||
查询指定 API Key 的实时并发使用情况,包括当前并发数和最大并发限制。
|
||||
查询指定 API Key 的实时 RPM 使用情况,包括当前 RPM 计数和最大 RPM 限制。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `key_id`: API Key ID
|
||||
- `key_current_concurrency`: 当前并发数
|
||||
- `key_max_concurrent`: 最大并发限制
|
||||
- `current_rpm`: 当前 RPM 计数
|
||||
- `rpm_limit`: RPM 限制
|
||||
"""
|
||||
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
|
||||
adapter = AdminKeyRpmAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/concurrency")
|
||||
async def reset_concurrency(
|
||||
request: ResetConcurrencyRequest,
|
||||
@router.delete("/rpm/key/{key_id}")
|
||||
async def reset_key_rpm(
|
||||
key_id: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
重置并发计数器
|
||||
重置 Key RPM 计数器
|
||||
|
||||
重置指定 Endpoint 或 Key 的并发计数器,用于解决计数不准确的问题。
|
||||
重置指定 API Key 的 RPM 计数器,用于解决计数不准确的问题。
|
||||
管理员功能,请谨慎使用。
|
||||
|
||||
**请求体字段**:
|
||||
- `endpoint_id`: Endpoint ID(可选)
|
||||
- `key_id`: API Key ID(可选)
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
"""
|
||||
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
|
||||
adapter = AdminResetKeyRpmAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -96,31 +68,7 @@ async def reset_concurrency(
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
endpoint_count, _ = await concurrency_manager.get_current_concurrency(
|
||||
endpoint_id=self.endpoint_id
|
||||
)
|
||||
|
||||
return ConcurrencyStatusResponse(
|
||||
endpoint_id=self.endpoint_id,
|
||||
endpoint_current_concurrency=endpoint_count,
|
||||
endpoint_max_concurrent=endpoint.max_concurrent,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
|
||||
class AdminKeyRpmAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -130,23 +78,20 @@ class AdminKeyConcurrencyAdapter(AdminApiAdapter):
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
_, key_count = await concurrency_manager.get_current_concurrency(key_id=self.key_id)
|
||||
key_count = await concurrency_manager.get_key_rpm_count(key_id=self.key_id)
|
||||
|
||||
return ConcurrencyStatusResponse(
|
||||
return KeyRpmStatusResponse(
|
||||
key_id=self.key_id,
|
||||
key_current_concurrency=key_count,
|
||||
key_max_concurrent=key.max_concurrent,
|
||||
current_rpm=key_count,
|
||||
rpm_limit=key.rpm_limit,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminResetConcurrencyAdapter(AdminApiAdapter):
|
||||
endpoint_id: Optional[str]
|
||||
key_id: Optional[str]
|
||||
class AdminResetKeyRpmAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
await concurrency_manager.reset_concurrency(
|
||||
endpoint_id=self.endpoint_id, key_id=self.key_id
|
||||
)
|
||||
return {"message": "并发计数已重置"}
|
||||
await concurrency_manager.reset_key_rpm(key_id=self.key_id)
|
||||
return {"message": "RPM 计数已重置"}
|
||||
|
||||
@@ -5,7 +5,7 @@ Endpoint 健康监控 API
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import func
|
||||
@@ -128,29 +128,32 @@ async def get_api_format_health_monitor(
|
||||
async def get_key_health(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
api_format: Optional[str] = Query(None, description="API 格式(可选,如 CLAUDE、OPENAI)"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthStatusResponse:
|
||||
"""
|
||||
获取 Key 健康状态
|
||||
|
||||
获取指定 API Key 的健康状态详情,包括健康分数、连续失败次数、
|
||||
熔断器状态等信息。
|
||||
熔断器状态等信息。支持按 API 格式查询。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**查询参数**:
|
||||
- `api_format`: 可选,指定 API 格式(如 CLAUDE、OPENAI)。
|
||||
- 指定时返回该格式的健康度详情
|
||||
- 不指定时返回所有格式的健康度摘要
|
||||
|
||||
**返回字段**:
|
||||
- `key_id`: API Key ID
|
||||
- `key_health_score`: 健康分数(0.0-1.0)
|
||||
- `key_consecutive_failures`: 连续失败次数
|
||||
- `key_last_failure_at`: 最后失败时间
|
||||
- `key_is_active`: 是否活跃
|
||||
- `key_statistics`: 统计信息
|
||||
- `circuit_breaker_open`: 熔断器是否打开
|
||||
- `circuit_breaker_open_at`: 熔断器打开时间
|
||||
- `next_probe_at`: 下次探测时间
|
||||
- `health_by_format`: 按格式的健康度数据(无 api_format 参数时)
|
||||
- `circuit_breaker_open`: 熔断器是否打开(有 api_format 参数时)
|
||||
"""
|
||||
adapter = AdminKeyHealthAdapter(key_id=key_id)
|
||||
adapter = AdminKeyHealthAdapter(key_id=key_id, api_format=api_format)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -158,17 +161,23 @@ async def get_key_health(
|
||||
async def recover_key_health(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
api_format: Optional[str] = Query(None, description="API 格式(可选,不指定则恢复所有格式)"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
恢复 Key 健康状态
|
||||
|
||||
手动恢复指定 Key 的健康状态,将健康分数重置为 1.0,关闭熔断器,
|
||||
取消自动禁用,并重置所有失败计数。
|
||||
取消自动禁用,并重置所有失败计数。支持按 API 格式恢复。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**查询参数**:
|
||||
- `api_format`: 可选,指定 API 格式(如 CLAUDE、OPENAI)
|
||||
- 指定时仅恢复该格式的健康度
|
||||
- 不指定时恢复所有格式
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
- `details`: 详细信息
|
||||
@@ -176,7 +185,7 @@ async def recover_key_health(
|
||||
- `circuit_breaker_open`: 熔断器状态
|
||||
- `is_active`: 是否活跃
|
||||
"""
|
||||
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
|
||||
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id, api_format=api_format)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -276,34 +285,9 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
|
||||
)
|
||||
all_formats[api_format] = provider_count
|
||||
|
||||
# 1.1 获取所有活跃的 API 格式及其 API Key 数量
|
||||
active_keys = (
|
||||
db.query(
|
||||
ProviderEndpoint.api_format,
|
||||
func.count(ProviderAPIKey.id).label("key_count"),
|
||||
)
|
||||
.join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
.group_by(ProviderEndpoint.api_format)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构建所有格式的 key_count 映射
|
||||
key_counts: Dict[str, int] = {}
|
||||
for api_format_enum, key_count in active_keys:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
key_counts[api_format] = key_count
|
||||
|
||||
# 1.2 建立每个 API 格式对应的 Endpoint ID 列表,供 Usage 时间线生成使用
|
||||
# 1.1 建立每个 API 格式对应的 Endpoint ID 列表(用于时间线生成),并收集活跃的 provider+format 组合
|
||||
endpoint_rows = (
|
||||
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
|
||||
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id, ProviderEndpoint.provider_id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
@@ -312,11 +296,32 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
|
||||
.all()
|
||||
)
|
||||
endpoint_map: Dict[str, List[str]] = defaultdict(list)
|
||||
for api_format_enum, endpoint_id in endpoint_rows:
|
||||
active_provider_formats: set[tuple[str, str]] = set()
|
||||
for api_format_enum, endpoint_id, provider_id in endpoint_rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
endpoint_map[api_format].append(endpoint_id)
|
||||
active_provider_formats.add((str(provider_id), api_format))
|
||||
|
||||
# 1.2 统计每个 API 格式可用的活跃 Key 数量(Key 属于 Provider,通过 api_formats 关联格式)
|
||||
key_counts: Dict[str, int] = {}
|
||||
if active_provider_formats:
|
||||
active_provider_keys = (
|
||||
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.api_formats)
|
||||
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
|
||||
.filter(
|
||||
Provider.is_active.is_(True),
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for provider_id, api_formats in active_provider_keys:
|
||||
pid = str(provider_id)
|
||||
for fmt in (api_formats or []):
|
||||
if (pid, fmt) not in active_provider_formats:
|
||||
continue
|
||||
key_counts[fmt] = key_counts.get(fmt, 0) + 1
|
||||
|
||||
# 2. 统计窗口内每个 API 格式的请求状态分布(真实统计)
|
||||
# 只统计最终状态:success, failed, skipped
|
||||
@@ -457,28 +462,45 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
|
||||
@dataclass
|
||||
class AdminKeyHealthAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
api_format: Optional[str] = None
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
health_data = health_monitor.get_key_health(context.db, self.key_id)
|
||||
health_data = health_monitor.get_key_health(context.db, self.key_id, self.api_format)
|
||||
if not health_data:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
return HealthStatusResponse(
|
||||
key_id=health_data["key_id"],
|
||||
key_health_score=health_data["health_score"],
|
||||
key_consecutive_failures=health_data["consecutive_failures"],
|
||||
key_last_failure_at=health_data["last_failure_at"],
|
||||
key_is_active=health_data["is_active"],
|
||||
key_statistics=health_data["statistics"],
|
||||
circuit_breaker_open=health_data["circuit_breaker_open"],
|
||||
circuit_breaker_open_at=health_data["circuit_breaker_open_at"],
|
||||
next_probe_at=health_data["next_probe_at"],
|
||||
)
|
||||
# 构建响应
|
||||
response_data = {
|
||||
"key_id": health_data["key_id"],
|
||||
"key_is_active": health_data["is_active"],
|
||||
"key_statistics": health_data.get("statistics"),
|
||||
"key_health_score": health_data.get("health_score", 1.0),
|
||||
}
|
||||
|
||||
if self.api_format:
|
||||
# 单格式查询
|
||||
response_data["api_format"] = self.api_format
|
||||
response_data["key_consecutive_failures"] = health_data.get("consecutive_failures")
|
||||
response_data["key_last_failure_at"] = health_data.get("last_failure_at")
|
||||
circuit = health_data.get("circuit_breaker", {})
|
||||
response_data["circuit_breaker_open"] = circuit.get("open", False)
|
||||
response_data["circuit_breaker_open_at"] = circuit.get("open_at")
|
||||
response_data["next_probe_at"] = circuit.get("next_probe_at")
|
||||
response_data["half_open_until"] = circuit.get("half_open_until")
|
||||
response_data["half_open_successes"] = circuit.get("half_open_successes", 0)
|
||||
response_data["half_open_failures"] = circuit.get("half_open_failures", 0)
|
||||
else:
|
||||
# 全格式查询
|
||||
response_data["any_circuit_open"] = health_data.get("any_circuit_open", False)
|
||||
response_data["health_by_format"] = health_data.get("health_by_format")
|
||||
|
||||
return HealthStatusResponse(**response_data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
api_format: Optional[str] = None
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
@@ -486,28 +508,38 @@ class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
key.health_score = 1.0
|
||||
key.consecutive_failures = 0
|
||||
key.last_failure_at = None
|
||||
key.circuit_breaker_open = False
|
||||
key.circuit_breaker_open_at = None
|
||||
key.next_probe_at = None
|
||||
# 使用 health_monitor.reset_health 重置健康度
|
||||
success = health_monitor.reset_health(db, key_id=self.key_id, api_format=self.api_format)
|
||||
if not success:
|
||||
raise Exception("重置健康度失败")
|
||||
|
||||
# 如果 Key 被禁用,重新启用
|
||||
if not key.is_active:
|
||||
key.is_active = True
|
||||
key.is_active = True # type: ignore[assignment]
|
||||
|
||||
db.commit()
|
||||
|
||||
admin_name = context.user.username if context.user else "admin"
|
||||
logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")
|
||||
|
||||
return {
|
||||
"message": "Key已完全恢复",
|
||||
"details": {
|
||||
"health_score": 1.0,
|
||||
"circuit_breaker_open": False,
|
||||
"is_active": True,
|
||||
},
|
||||
}
|
||||
if self.api_format:
|
||||
logger.info(f"管理员恢复Key健康状态: {self.key_id}/{self.api_format}")
|
||||
return {
|
||||
"message": f"Key 的 {self.api_format} 格式已恢复",
|
||||
"details": {
|
||||
"api_format": self.api_format,
|
||||
"health_score": 1.0,
|
||||
"circuit_breaker_open": False,
|
||||
"is_active": True,
|
||||
},
|
||||
}
|
||||
else:
|
||||
logger.info(f"管理员恢复Key健康状态: {self.key_id} (所有格式)")
|
||||
return {
|
||||
"message": "Key 所有格式已恢复",
|
||||
"details": {
|
||||
"health_score": 1.0,
|
||||
"circuit_breaker_open": False,
|
||||
"is_active": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
||||
@@ -516,10 +548,17 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 查找所有熔断的 Key
|
||||
circuit_open_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open == True).all()
|
||||
)
|
||||
# 查找所有有熔断格式的 Key(检查 circuit_breaker_by_format JSON 字段)
|
||||
all_keys = db.query(ProviderAPIKey).all()
|
||||
|
||||
# 筛选出有任何格式熔断的 Key
|
||||
circuit_open_keys = []
|
||||
for key in all_keys:
|
||||
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||
for fmt, circuit_data in circuit_by_format.items():
|
||||
if circuit_data.get("open"):
|
||||
circuit_open_keys.append(key)
|
||||
break
|
||||
|
||||
if not circuit_open_keys:
|
||||
return {
|
||||
@@ -530,17 +569,15 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
||||
|
||||
recovered_keys = []
|
||||
for key in circuit_open_keys:
|
||||
key.health_score = 1.0
|
||||
key.consecutive_failures = 0
|
||||
key.last_failure_at = None
|
||||
key.circuit_breaker_open = False
|
||||
key.circuit_breaker_open_at = None
|
||||
key.next_probe_at = None
|
||||
# 重置所有格式的健康度
|
||||
key.health_by_format = {} # type: ignore[assignment]
|
||||
key.circuit_breaker_by_format = {} # type: ignore[assignment]
|
||||
recovered_keys.append(
|
||||
{
|
||||
"key_id": key.id,
|
||||
"key_name": key.name,
|
||||
"endpoint_id": key.endpoint_id,
|
||||
"provider_id": key.provider_id,
|
||||
"api_formats": key.api_formats,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -552,7 +589,6 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
||||
HealthMonitor._open_circuit_keys = 0
|
||||
health_open_circuits.set(0)
|
||||
|
||||
admin_name = context.user.username if context.user else "admin"
|
||||
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Endpoint API Keys 管理
|
||||
Provider API Keys 管理
|
||||
"""
|
||||
|
||||
import uuid
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.config.constants import RPMDefaults
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.key_capabilities import get_capability
|
||||
@@ -20,96 +21,14 @@ from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
from src.services.cache.provider_cache import ProviderCacheService
|
||||
from src.models.endpoint_models import (
|
||||
BatchUpdateKeyPriorityRequest,
|
||||
EndpointAPIKeyCreate,
|
||||
EndpointAPIKeyResponse,
|
||||
EndpointAPIKeyUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Endpoint Keys"])
|
||||
router = APIRouter(tags=["Provider Keys"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
|
||||
async def list_endpoint_keys(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[EndpointAPIKeyResponse]:
|
||||
"""
|
||||
获取 Endpoint 的所有 Keys
|
||||
|
||||
获取指定 Endpoint 下的所有 API Key 列表,包括 Key 的配置、统计信息等。
|
||||
结果按优先级和创建时间排序。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**查询参数**:
|
||||
- `skip`: 跳过的记录数,用于分页(默认 0)
|
||||
- `limit`: 返回的最大记录数(1-1000,默认 100)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: Key ID
|
||||
- `name`: Key 名称
|
||||
- `api_key_masked`: 脱敏后的 API Key
|
||||
- `internal_priority`: 内部优先级
|
||||
- `global_priority`: 全局优先级
|
||||
- `rate_multiplier`: 速率倍数
|
||||
- `max_concurrent`: 最大并发数(null 表示自适应模式)
|
||||
- `is_adaptive`: 是否为自适应并发模式
|
||||
- `effective_limit`: 有效并发限制
|
||||
- `success_rate`: 成功率
|
||||
- `avg_response_time_ms`: 平均响应时间(毫秒)
|
||||
- 其他配置和统计字段
|
||||
"""
|
||||
adapter = AdminListEndpointKeysAdapter(
|
||||
endpoint_id=endpoint_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
|
||||
async def add_endpoint_key(
|
||||
endpoint_id: str,
|
||||
key_data: EndpointAPIKeyCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""
|
||||
为 Endpoint 添加 Key
|
||||
|
||||
为指定 Endpoint 添加新的 API Key,支持配置并发限制、速率倍数、
|
||||
优先级、配额限制、能力限制等。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**请求体字段**:
|
||||
- `endpoint_id`: Endpoint ID(必须与路径参数一致)
|
||||
- `api_key`: API Key 原文(将被加密存储)
|
||||
- `name`: Key 名称
|
||||
- `note`: 备注(可选)
|
||||
- `rate_multiplier`: 速率倍数(默认 1.0)
|
||||
- `internal_priority`: 内部优先级(默认 100)
|
||||
- `max_concurrent`: 最大并发数(null 表示自适应模式)
|
||||
- `rate_limit`: 每分钟请求限制(可选)
|
||||
- `daily_limit`: 每日请求限制(可选)
|
||||
- `monthly_limit`: 每月请求限制(可选)
|
||||
- `allowed_models`: 允许的模型列表(可选)
|
||||
- `capabilities`: 能力配置(可选)
|
||||
|
||||
**返回字段**:
|
||||
- 包含完整的 Key 信息,其中 `api_key_plain` 为原文(仅在创建时返回)
|
||||
"""
|
||||
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
|
||||
async def update_endpoint_key(
|
||||
key_id: str,
|
||||
@@ -118,7 +37,7 @@ async def update_endpoint_key(
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""
|
||||
更新 Endpoint Key
|
||||
更新 Provider Key
|
||||
|
||||
更新指定 Key 的配置,支持修改并发限制、速率倍数、优先级、
|
||||
配额限制、能力限制等。支持部分更新。
|
||||
@@ -132,10 +51,7 @@ async def update_endpoint_key(
|
||||
- `note`: 备注
|
||||
- `rate_multiplier`: 速率倍数
|
||||
- `internal_priority`: 内部优先级
|
||||
- `max_concurrent`: 最大并发数(设置为 null 可切换到自适应模式)
|
||||
- `rate_limit`: 每分钟请求限制
|
||||
- `daily_limit`: 每日请求限制
|
||||
- `monthly_limit`: 每月请求限制
|
||||
- `rpm_limit`: RPM 限制(设置为 null 可切换到自适应模式)
|
||||
- `allowed_models`: 允许的模型列表
|
||||
- `capabilities`: 能力配置
|
||||
- `is_active`: 是否活跃
|
||||
@@ -210,7 +126,7 @@ async def delete_endpoint_key(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
删除 Endpoint Key
|
||||
删除 Provider Key
|
||||
|
||||
删除指定的 API Key。此操作不可逆,请谨慎使用。
|
||||
|
||||
@@ -224,163 +140,66 @@ async def delete_endpoint_key(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{endpoint_id}/keys/batch-priority")
|
||||
async def batch_update_key_priority(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
priority_data: BatchUpdateKeyPriorityRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
批量更新 Endpoint 下 Keys 的优先级
|
||||
# ========== Provider Keys API ==========
|
||||
|
||||
批量更新指定 Endpoint 下多个 Key 的内部优先级,用于拖动排序。
|
||||
所有 Key 必须属于指定的 Endpoint。
|
||||
|
||||
@router.get("/providers/{provider_id}/keys", response_model=List[EndpointAPIKeyResponse])
|
||||
async def list_provider_keys(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[EndpointAPIKeyResponse]:
|
||||
"""
|
||||
获取 Provider 的所有 Keys
|
||||
|
||||
获取指定 Provider 下的所有 API Key 列表,支持多 API 格式。
|
||||
结果按优先级和创建时间排序。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
- `provider_id`: Provider ID
|
||||
|
||||
**查询参数**:
|
||||
- `skip`: 跳过的记录数,用于分页(默认 0)
|
||||
- `limit`: 返回的最大记录数(1-1000,默认 100)
|
||||
"""
|
||||
adapter = AdminListProviderKeysAdapter(
|
||||
provider_id=provider_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id}/keys", response_model=EndpointAPIKeyResponse)
|
||||
async def add_provider_key(
|
||||
provider_id: str,
|
||||
key_data: EndpointAPIKeyCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""
|
||||
为 Provider 添加 Key
|
||||
|
||||
为指定 Provider 添加新的 API Key,支持配置多个 API 格式。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: Provider ID
|
||||
|
||||
**请求体字段**:
|
||||
- `priorities`: 优先级列表
|
||||
- `key_id`: Key ID
|
||||
- `internal_priority`: 新的内部优先级
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
- `updated_count`: 实际更新的 Key 数量
|
||||
- `api_formats`: 支持的 API 格式列表(必填)
|
||||
- `api_key`: API Key 原文(将被加密存储)
|
||||
- `name`: Key 名称
|
||||
- 其他配置字段同 Key
|
||||
"""
|
||||
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
|
||||
adapter = AdminCreateProviderKeyAdapter(provider_id=provider_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListEndpointKeysAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
|
||||
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
||||
.offset(self.skip)
|
||||
.limit(self.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
result: List[EndpointAPIKeyResponse] = []
|
||||
for key in keys:
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
||||
avg_response_time_ms = (
|
||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
is_adaptive = key.max_concurrent is None
|
||||
key_dict = key.__dict__.copy()
|
||||
key_dict.pop("_sa_instance_state", None)
|
||||
key_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": None,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
result.append(EndpointAPIKeyResponse(**key_dict))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
key_data: EndpointAPIKeyCreate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
if self.key_data.endpoint_id != self.endpoint_id:
|
||||
raise InvalidRequestException("endpoint_id 不匹配")
|
||||
|
||||
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
|
||||
now = datetime.now(timezone.utc)
|
||||
# max_concurrent=NULL 表示自适应模式,数字表示固定限制
|
||||
new_key = ProviderAPIKey(
|
||||
id=str(uuid.uuid4()),
|
||||
endpoint_id=self.endpoint_id,
|
||||
api_key=encrypted_key,
|
||||
name=self.key_data.name,
|
||||
note=self.key_data.note,
|
||||
rate_multiplier=self.key_data.rate_multiplier,
|
||||
internal_priority=self.key_data.internal_priority,
|
||||
max_concurrent=self.key_data.max_concurrent, # NULL=自适应模式
|
||||
rate_limit=self.key_data.rate_limit,
|
||||
daily_limit=self.key_data.daily_limit,
|
||||
monthly_limit=self.key_data.monthly_limit,
|
||||
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
|
||||
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
|
||||
request_count=0,
|
||||
success_count=0,
|
||||
error_count=0,
|
||||
total_response_time_ms=0,
|
||||
is_active=True,
|
||||
last_used_at=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
db.add(new_key)
|
||||
db.commit()
|
||||
db.refresh(new_key)
|
||||
|
||||
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
|
||||
|
||||
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
|
||||
is_adaptive = new_key.max_concurrent is None
|
||||
response_dict = new_key.__dict__.copy()
|
||||
response_dict.pop("_sa_instance_state", None)
|
||||
response_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": self.key_data.api_key,
|
||||
"success_rate": 0.0,
|
||||
"avg_response_time_ms": 0.0,
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return EndpointAPIKeyResponse(**response_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
@@ -396,14 +215,21 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||
if "api_key" in update_data:
|
||||
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
||||
|
||||
# 特殊处理 max_concurrent:需要区分"未提供"和"显式设置为 null"
|
||||
# 当 max_concurrent 被显式设置时(在 model_fields_set 中),即使值为 None 也应该更新
|
||||
if "max_concurrent" in self.key_data.model_fields_set:
|
||||
update_data["max_concurrent"] = self.key_data.max_concurrent
|
||||
# 切换到自适应模式时,清空学习到的并发限制,让系统重新学习
|
||||
if self.key_data.max_concurrent is None:
|
||||
update_data["learned_max_concurrent"] = None
|
||||
logger.info("Key %s 切换为自适应并发模式", self.key_id)
|
||||
# 特殊处理 rpm_limit:需要区分"未提供"和"显式设置为 null"
|
||||
if "rpm_limit" in self.key_data.model_fields_set:
|
||||
update_data["rpm_limit"] = self.key_data.rpm_limit
|
||||
if self.key_data.rpm_limit is None:
|
||||
update_data["learned_rpm_limit"] = None
|
||||
logger.info("Key %s 切换为自适应 RPM 模式", self.key_id)
|
||||
|
||||
# 统一处理 allowed_models:空列表/空字典 -> None(表示不限制)
|
||||
if "allowed_models" in update_data:
|
||||
am = update_data["allowed_models"]
|
||||
if am is not None and (
|
||||
(isinstance(am, list) and len(am) == 0)
|
||||
or (isinstance(am, dict) and len(am) == 0)
|
||||
):
|
||||
update_data["allowed_models"] = None
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(key, field, value)
|
||||
@@ -412,39 +238,13 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
|
||||
# 如果更新了 rate_multiplier,清除缓存
|
||||
if "rate_multiplier" in update_data:
|
||||
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
|
||||
# 任何字段更新都清除缓存,确保缓存一致性
|
||||
# 包括 is_active、allowed_models、capabilities 等影响权限和行为的字段
|
||||
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
|
||||
|
||||
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
|
||||
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
||||
avg_response_time_ms = (
|
||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
is_adaptive = key.max_concurrent is None
|
||||
response_dict = key.__dict__.copy()
|
||||
response_dict.pop("_sa_instance_state", None)
|
||||
response_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": None,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
return EndpointAPIKeyResponse(**response_dict)
|
||||
return _build_key_response(key)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -481,7 +281,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
endpoint_id = key.endpoint_id
|
||||
provider_id = key.provider_id
|
||||
try:
|
||||
db.delete(key)
|
||||
db.commit()
|
||||
@@ -490,7 +290,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
||||
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
|
||||
raise
|
||||
|
||||
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
|
||||
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Provider={provider_id}")
|
||||
return {"message": f"Key {self.key_id} 已删除"}
|
||||
|
||||
|
||||
@@ -498,31 +298,51 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# Key 属于 Provider:按 key.api_formats 分组展示
|
||||
keys = (
|
||||
db.query(ProviderAPIKey, ProviderEndpoint, Provider)
|
||||
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
db.query(ProviderAPIKey, Provider)
|
||||
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.order_by(
|
||||
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
|
||||
ProviderAPIKey.global_priority.asc().nullslast(),
|
||||
ProviderAPIKey.internal_priority.asc(),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_ids = {str(provider.id) for _key, provider in keys}
|
||||
endpoints = (
|
||||
db.query(
|
||||
ProviderEndpoint.provider_id,
|
||||
ProviderEndpoint.api_format,
|
||||
ProviderEndpoint.base_url,
|
||||
)
|
||||
.filter(
|
||||
ProviderEndpoint.provider_id.in_(provider_ids),
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
endpoint_base_url_map: Dict[tuple[str, str], str] = {}
|
||||
for provider_id, api_format, base_url in endpoints:
|
||||
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
|
||||
endpoint_base_url_map[(str(provider_id), fmt)] = base_url
|
||||
|
||||
grouped: Dict[str, List[dict]] = {}
|
||||
for key, endpoint, provider in keys:
|
||||
api_format = endpoint.api_format
|
||||
if api_format not in grouped:
|
||||
grouped[api_format] = []
|
||||
for key, provider in keys:
|
||||
api_formats = key.api_formats or []
|
||||
|
||||
if not api_formats:
|
||||
continue # 跳过没有 API 格式的 Key
|
||||
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(f"解密 Key 失败: key_id={key.id}, error={e}")
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
# 计算健康度指标
|
||||
@@ -541,73 +361,209 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
|
||||
cap_def = get_capability(cap_name)
|
||||
caps_list.append(cap_def.short_name if cap_def else cap_name)
|
||||
|
||||
grouped[api_format].append(
|
||||
{
|
||||
"id": key.id,
|
||||
"name": key.name,
|
||||
"api_key_masked": masked_key,
|
||||
"internal_priority": key.internal_priority,
|
||||
"global_priority": key.global_priority,
|
||||
"rate_multiplier": key.rate_multiplier,
|
||||
"is_active": key.is_active,
|
||||
"circuit_breaker_open": key.circuit_breaker_open,
|
||||
"provider_name": provider.display_name or provider.name,
|
||||
"endpoint_base_url": endpoint.base_url,
|
||||
"api_format": api_format,
|
||||
"capabilities": caps_list,
|
||||
"health_score": key.health_score,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": avg_response_time_ms,
|
||||
"request_count": key.request_count,
|
||||
}
|
||||
)
|
||||
# 构建 Key 信息(基础数据)
|
||||
key_info = {
|
||||
"id": key.id,
|
||||
"name": key.name,
|
||||
"api_key_masked": masked_key,
|
||||
"internal_priority": key.internal_priority,
|
||||
"global_priority": key.global_priority,
|
||||
"rate_multiplier": key.rate_multiplier,
|
||||
"is_active": key.is_active,
|
||||
"provider_name": provider.name,
|
||||
"api_formats": api_formats,
|
||||
"capabilities": caps_list,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": avg_response_time_ms,
|
||||
"request_count": key.request_count,
|
||||
}
|
||||
|
||||
# 将 Key 添加到每个支持的格式分组中,并附加格式特定的健康度数据
|
||||
health_by_format = key.health_by_format or {}
|
||||
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||
provider_id = str(provider.id)
|
||||
for api_format in api_formats:
|
||||
if api_format not in grouped:
|
||||
grouped[api_format] = []
|
||||
# 为每个格式创建副本,设置当前格式
|
||||
format_key_info = key_info.copy()
|
||||
format_key_info["api_format"] = api_format
|
||||
format_key_info["endpoint_base_url"] = endpoint_base_url_map.get(
|
||||
(provider_id, api_format)
|
||||
)
|
||||
# 添加格式特定的健康度数据
|
||||
format_health = health_by_format.get(api_format, {})
|
||||
format_circuit = circuit_by_format.get(api_format, {})
|
||||
format_key_info["health_score"] = float(format_health.get("health_score") or 1.0)
|
||||
format_key_info["circuit_breaker_open"] = bool(format_circuit.get("open", False))
|
||||
grouped[api_format].append(format_key_info)
|
||||
|
||||
# 直接返回分组对象,供前端使用
|
||||
return grouped
|
||||
|
||||
|
||||
# ========== Adapters ==========
|
||||
|
||||
|
||||
def _build_key_response(
|
||||
key: ProviderAPIKey, api_key_plain: str | None = None
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""构建 Key 响应对象的辅助函数"""
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
||||
avg_response_time_ms = (
|
||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
is_adaptive = key.rpm_limit is None
|
||||
key_dict = key.__dict__.copy()
|
||||
key_dict.pop("_sa_instance_state", None)
|
||||
|
||||
# 从 health_by_format 计算汇总字段(便于列表展示)
|
||||
health_by_format = key.health_by_format or {}
|
||||
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||
|
||||
# 计算整体健康度(取所有格式中的最低值)
|
||||
if health_by_format:
|
||||
health_scores = [
|
||||
float(h.get("health_score") or 1.0) for h in health_by_format.values()
|
||||
]
|
||||
min_health_score = min(health_scores) if health_scores else 1.0
|
||||
# 取最大的连续失败次数
|
||||
max_consecutive = max(
|
||||
(int(h.get("consecutive_failures") or 0) for h in health_by_format.values()),
|
||||
default=0,
|
||||
)
|
||||
# 取最近的失败时间
|
||||
failure_times = [
|
||||
h.get("last_failure_at")
|
||||
for h in health_by_format.values()
|
||||
if h.get("last_failure_at")
|
||||
]
|
||||
last_failure = max(failure_times) if failure_times else None
|
||||
else:
|
||||
min_health_score = 1.0
|
||||
max_consecutive = 0
|
||||
last_failure = None
|
||||
|
||||
# 检查是否有任何格式的熔断器打开
|
||||
any_circuit_open = any(c.get("open", False) for c in circuit_by_format.values())
|
||||
|
||||
key_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": api_key_plain,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
|
||||
if is_adaptive
|
||||
else key.rpm_limit
|
||||
),
|
||||
# 汇总字段
|
||||
"health_score": min_health_score,
|
||||
"consecutive_failures": max_consecutive,
|
||||
"last_failure_at": last_failure,
|
||||
"circuit_breaker_open": any_circuit_open,
|
||||
}
|
||||
)
|
||||
|
||||
# 防御性:确保 api_formats 存在(历史数据可能为空/缺失)
|
||||
if "api_formats" not in key_dict or key_dict["api_formats"] is None:
|
||||
key_dict["api_formats"] = []
|
||||
|
||||
return EndpointAPIKeyResponse(**key_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
priority_data: BatchUpdateKeyPriorityRequest
|
||||
class AdminListProviderKeysAdapter(AdminApiAdapter):
|
||||
"""获取 Provider 的所有 Keys"""
|
||||
|
||||
provider_id: str
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||
|
||||
# 获取所有需要更新的 Key ID
|
||||
key_ids = [item.key_id for item in self.priority_data.priorities]
|
||||
|
||||
# 验证所有 Key 都属于该 Endpoint
|
||||
keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
ProviderAPIKey.id.in_(key_ids),
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
)
|
||||
.filter(ProviderAPIKey.provider_id == self.provider_id)
|
||||
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
||||
.offset(self.skip)
|
||||
.limit(self.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(keys) != len(key_ids):
|
||||
found_ids = {k.id for k in keys}
|
||||
missing_ids = set(key_ids) - found_ids
|
||||
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
|
||||
return [_build_key_response(key) for key in keys]
|
||||
|
||||
# 批量更新优先级
|
||||
key_map = {k.id: k for k in keys}
|
||||
updated_count = 0
|
||||
for item in self.priority_data.priorities:
|
||||
key = key_map.get(item.key_id)
|
||||
if key and key.internal_priority != item.internal_priority:
|
||||
key.internal_priority = item.internal_priority
|
||||
key.updated_at = datetime.now(timezone.utc)
|
||||
updated_count += 1
|
||||
|
||||
@dataclass
|
||||
class AdminCreateProviderKeyAdapter(AdminApiAdapter):
|
||||
"""为 Provider 添加 Key"""
|
||||
|
||||
provider_id: str
|
||||
key_data: EndpointAPIKeyCreate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||
|
||||
# 验证 api_formats 必填
|
||||
if not self.key_data.api_formats:
|
||||
raise InvalidRequestException("api_formats 为必填字段")
|
||||
|
||||
# 允许同一个 API Key 在同一 Provider 下添加多次
|
||||
# 用户可以为不同的 API 格式创建独立的配置记录,便于分开管理
|
||||
|
||||
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
new_key = ProviderAPIKey(
|
||||
id=str(uuid.uuid4()),
|
||||
provider_id=self.provider_id,
|
||||
api_formats=self.key_data.api_formats,
|
||||
api_key=encrypted_key,
|
||||
name=self.key_data.name,
|
||||
note=self.key_data.note,
|
||||
rate_multiplier=self.key_data.rate_multiplier,
|
||||
rate_multipliers=self.key_data.rate_multipliers, # 按 API 格式的成本倍率
|
||||
internal_priority=self.key_data.internal_priority,
|
||||
rpm_limit=self.key_data.rpm_limit,
|
||||
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
|
||||
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
|
||||
cache_ttl_minutes=self.key_data.cache_ttl_minutes,
|
||||
max_probe_interval_minutes=self.key_data.max_probe_interval_minutes,
|
||||
request_count=0,
|
||||
success_count=0,
|
||||
error_count=0,
|
||||
total_response_time_ms=0,
|
||||
health_by_format={}, # 按格式存储健康度
|
||||
circuit_breaker_by_format={}, # 按格式存储熔断器状态
|
||||
is_active=True,
|
||||
last_used_at=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
db.add(new_key)
|
||||
db.commit()
|
||||
db.refresh(new_key)
|
||||
|
||||
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
|
||||
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}
|
||||
logger.info(
|
||||
f"[OK] 添加 Key: Provider={self.provider_id}, "
|
||||
f"Formats={self.key_data.api_formats}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}"
|
||||
)
|
||||
|
||||
return _build_key_response(new_key, api_key_plain=self.key_data.api_key)
|
||||
|
||||
@@ -67,8 +67,6 @@ async def list_provider_endpoints(
|
||||
- `custom_path`: 自定义路径
|
||||
- `timeout`: 超时时间(秒)
|
||||
- `max_retries`: 最大重试次数
|
||||
- `max_concurrent`: 最大并发数
|
||||
- `rate_limit`: 速率限制
|
||||
- `is_active`: 是否活跃
|
||||
- `total_keys`: Key 总数
|
||||
- `active_keys`: 活跃 Key 数量
|
||||
@@ -107,8 +105,6 @@ async def create_provider_endpoint(
|
||||
- `headers`: 自定义请求头(可选)
|
||||
- `timeout`: 超时时间(秒,默认 300)
|
||||
- `max_retries`: 最大重试次数(默认 2)
|
||||
- `max_concurrent`: 最大并发数(可选)
|
||||
- `rate_limit`: 速率限制(可选)
|
||||
- `config`: 额外配置(可选)
|
||||
- `proxy`: 代理配置(可选)
|
||||
|
||||
@@ -145,8 +141,6 @@ async def get_endpoint(
|
||||
- `custom_path`: 自定义路径
|
||||
- `timeout`: 超时时间(秒)
|
||||
- `max_retries`: 最大重试次数
|
||||
- `max_concurrent`: 最大并发数
|
||||
- `rate_limit`: 速率限制
|
||||
- `is_active`: 是否活跃
|
||||
- `total_keys`: Key 总数
|
||||
- `active_keys`: 活跃 Key 数量
|
||||
@@ -178,8 +172,6 @@ async def update_endpoint(
|
||||
- `headers`: 自定义请求头
|
||||
- `timeout`: 超时时间(秒)
|
||||
- `max_retries`: 最大重试次数
|
||||
- `max_concurrent`: 最大并发数
|
||||
- `rate_limit`: 速率限制
|
||||
- `is_active`: 是否活跃
|
||||
- `config`: 额外配置
|
||||
- `proxy`: 代理配置(设置为 null 可清除代理)
|
||||
@@ -203,15 +195,15 @@ async def delete_endpoint(
|
||||
"""
|
||||
删除 Endpoint
|
||||
|
||||
删除指定的 Endpoint,同时级联删除所有关联的 API Keys。
|
||||
此操作不可逆,请谨慎使用。
|
||||
删除指定的 Endpoint,会影响该 Provider 在该 API 格式下的路由能力。
|
||||
Key 不会被删除,但包含该 API 格式的 Key 将无法被调度使用(直到重新创建该格式的 Endpoint)。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
- `deleted_keys_count`: 同时删除的 Key 数量
|
||||
- `affected_keys_count`: 受影响的 Key 数量(包含该 API 格式)
|
||||
"""
|
||||
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -241,39 +233,33 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
|
||||
.all()
|
||||
)
|
||||
|
||||
endpoint_ids = [ep.id for ep in endpoints]
|
||||
total_keys_map = {}
|
||||
active_keys_map = {}
|
||||
if endpoint_ids:
|
||||
total_rows = (
|
||||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
|
||||
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
|
||||
.group_by(ProviderAPIKey.endpoint_id)
|
||||
.all()
|
||||
)
|
||||
total_keys_map = {row.endpoint_id: row.total for row in total_rows}
|
||||
|
||||
active_rows = (
|
||||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.group_by(ProviderAPIKey.endpoint_id)
|
||||
.all()
|
||||
)
|
||||
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
|
||||
# Key 是 Provider 级别资源:按 key.api_formats 归类到各 Endpoint.api_format 下
|
||||
keys = (
|
||||
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
|
||||
.filter(ProviderAPIKey.provider_id == self.provider_id)
|
||||
.all()
|
||||
)
|
||||
total_keys_map: dict[str, int] = {}
|
||||
active_keys_map: dict[str, int] = {}
|
||||
for api_formats, is_active in keys:
|
||||
for fmt in (api_formats or []):
|
||||
total_keys_map[fmt] = total_keys_map.get(fmt, 0) + 1
|
||||
if is_active:
|
||||
active_keys_map[fmt] = active_keys_map.get(fmt, 0) + 1
|
||||
|
||||
result: List[ProviderEndpointResponse] = []
|
||||
for endpoint in endpoints:
|
||||
endpoint_format = (
|
||||
endpoint.api_format
|
||||
if isinstance(endpoint.api_format, str)
|
||||
else endpoint.api_format.value
|
||||
)
|
||||
endpoint_dict = {
|
||||
**endpoint.__dict__,
|
||||
"provider_name": provider.name,
|
||||
"api_format": endpoint.api_format,
|
||||
"total_keys": total_keys_map.get(endpoint.id, 0),
|
||||
"active_keys": active_keys_map.get(endpoint.id, 0),
|
||||
"total_keys": total_keys_map.get(endpoint_format, 0),
|
||||
"active_keys": active_keys_map.get(endpoint_format, 0),
|
||||
"proxy": mask_proxy_password(endpoint.proxy),
|
||||
}
|
||||
endpoint_dict.pop("_sa_instance_state", None)
|
||||
@@ -321,8 +307,6 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
||||
headers=self.endpoint_data.headers,
|
||||
timeout=self.endpoint_data.timeout,
|
||||
max_retries=self.endpoint_data.max_retries,
|
||||
max_concurrent=self.endpoint_data.max_concurrent,
|
||||
rate_limit=self.endpoint_data.rate_limit,
|
||||
is_active=True,
|
||||
config=self.endpoint_data.config,
|
||||
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
|
||||
@@ -367,19 +351,23 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
endpoint_obj, provider = endpoint
|
||||
total_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
endpoint_format = (
|
||||
endpoint_obj.api_format
|
||||
if isinstance(endpoint_obj.api_format, str)
|
||||
else endpoint_obj.api_format.value
|
||||
)
|
||||
active_keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.count()
|
||||
keys = (
|
||||
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
|
||||
.filter(ProviderAPIKey.provider_id == endpoint_obj.provider_id)
|
||||
.all()
|
||||
)
|
||||
total_keys = 0
|
||||
active_keys = 0
|
||||
for api_formats, is_active in keys:
|
||||
if endpoint_format in (api_formats or []):
|
||||
total_keys += 1
|
||||
if is_active:
|
||||
active_keys += 1
|
||||
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
@@ -431,19 +419,21 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
||||
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
|
||||
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
|
||||
|
||||
total_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
endpoint_format = (
|
||||
endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
|
||||
)
|
||||
active_keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.count()
|
||||
keys = (
|
||||
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
|
||||
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
|
||||
.all()
|
||||
)
|
||||
total_keys = 0
|
||||
active_keys = 0
|
||||
for api_formats, is_active in keys:
|
||||
if endpoint_format in (api_formats or []):
|
||||
total_keys += 1
|
||||
if is_active:
|
||||
active_keys += 1
|
||||
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
@@ -472,12 +462,26 @@ class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
keys_count = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
endpoint_format = (
|
||||
endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
|
||||
)
|
||||
keys = (
|
||||
db.query(ProviderAPIKey.api_formats)
|
||||
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
|
||||
.all()
|
||||
)
|
||||
affected_keys_count = sum(
|
||||
1 for (api_formats,) in keys if endpoint_format in (api_formats or [])
|
||||
)
|
||||
db.delete(endpoint)
|
||||
db.commit()
|
||||
|
||||
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
|
||||
logger.warning(
|
||||
f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, Format={endpoint_format}, "
|
||||
f"AffectedKeys={affected_keys_count}"
|
||||
)
|
||||
|
||||
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
|
||||
return {
|
||||
"message": f"Endpoint {self.endpoint_id} 已删除",
|
||||
"affected_keys_count": affected_keys_count,
|
||||
}
|
||||
|
||||
@@ -125,7 +125,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
ModelCatalogProviderDetail(
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.name,
|
||||
provider_display_name=provider.display_name,
|
||||
model_id=model.id,
|
||||
target_model=model.provider_model_name,
|
||||
# 显示有效价格
|
||||
|
||||
@@ -452,7 +452,6 @@ class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
|
||||
ModelCatalogProviderDetail(
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.name,
|
||||
provider_display_name=provider.display_name,
|
||||
model_id=model.id,
|
||||
target_model=model.provider_model_name,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
|
||||
@@ -819,7 +819,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
|
||||
"username": user.username if user else None,
|
||||
"email": user.email if user else None,
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.display_name if provider else None,
|
||||
"provider_name": provider.name if provider else None,
|
||||
"endpoint_id": endpoint_id,
|
||||
"endpoint_api_format": (
|
||||
endpoint.api_format if endpoint and endpoint.api_format else None
|
||||
@@ -1369,9 +1369,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||
for model, provider in models:
|
||||
# 检查是否是主模型名称
|
||||
if model.provider_model_name == mapping_name:
|
||||
provider_names.append(
|
||||
provider.display_name or provider.name
|
||||
)
|
||||
provider_names.append(provider.name)
|
||||
continue
|
||||
# 检查是否在映射列表中
|
||||
if model.provider_model_mappings:
|
||||
@@ -1381,9 +1379,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||
if isinstance(a, dict)
|
||||
]
|
||||
if mapping_name in mapping_list:
|
||||
provider_names.append(
|
||||
provider.display_name or provider.name
|
||||
)
|
||||
provider_names.append(provider.name)
|
||||
provider_names = sorted(list(set(provider_names)))
|
||||
|
||||
mappings.append({
|
||||
@@ -1473,7 +1469,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
||||
|
||||
provider_model_mappings.append({
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.display_name or provider.name,
|
||||
"provider_name": provider.name,
|
||||
"global_model_id": global_model_id,
|
||||
"global_model_name": global_model.name,
|
||||
"global_model_display_name": global_model.display_name,
|
||||
|
||||
@@ -13,10 +13,11 @@ from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.handlers.base.chat_adapter_base import get_adapter_class
|
||||
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
|
||||
from src.config.constants import TimeoutDefaults
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.database.database import get_db
|
||||
from src.models.database import Provider, ProviderEndpoint, User
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
||||
from src.utils.auth_utils import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
|
||||
@@ -81,10 +82,13 @@ async def query_available_models(
|
||||
Returns:
|
||||
所有端点的模型列表(合并)
|
||||
"""
|
||||
# 获取提供商及其端点
|
||||
# 获取提供商及其端点和 API Keys
|
||||
provider = (
|
||||
db.query(Provider)
|
||||
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||
.options(
|
||||
joinedload(Provider.endpoints),
|
||||
joinedload(Provider.api_keys),
|
||||
)
|
||||
.filter(Provider.id == request.provider_id)
|
||||
.first()
|
||||
)
|
||||
@@ -95,49 +99,70 @@ async def query_available_models(
|
||||
# 收集所有活跃端点的配置
|
||||
endpoint_configs: list[dict] = []
|
||||
|
||||
# 构建 api_format -> endpoint 映射
|
||||
format_to_endpoint: dict[str, ProviderEndpoint] = {}
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active:
|
||||
format_to_endpoint[endpoint.api_format] = endpoint
|
||||
|
||||
if request.api_key_id:
|
||||
# 指定了特定的 API Key,只使用该 Key 对应的端点
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break
|
||||
if endpoint_configs:
|
||||
break
|
||||
# 指定了特定的 API Key(从 provider.api_keys 查找)
|
||||
api_key = next(
|
||||
(key for key in provider.api_keys if key.id == request.api_key_id),
|
||||
None
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||
|
||||
# 根据 Key 的 api_formats 找对应的 Endpoint
|
||||
key_formats = api_key.api_formats or []
|
||||
for fmt in key_formats:
|
||||
endpoint = format_to_endpoint.get(fmt)
|
||||
if endpoint:
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": fmt,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
|
||||
if not endpoint_configs:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No matching endpoint found for this API Key's formats"
|
||||
)
|
||||
else:
|
||||
# 遍历所有活跃端点,每个端点取第一个可用的 Key
|
||||
# 遍历所有活跃端点,为每个端点找一个支持该格式的 Key
|
||||
for endpoint in provider.endpoints:
|
||||
if not endpoint.is_active or not endpoint.api_keys:
|
||||
if not endpoint.is_active:
|
||||
continue
|
||||
|
||||
# 找第一个可用的 Key
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
continue # 尝试下一个 Key
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break # 只取第一个可用的 Key
|
||||
# 找第一个支持该格式的可用 Key
|
||||
for api_key in provider.api_keys:
|
||||
if not api_key.is_active:
|
||||
continue
|
||||
key_formats = api_key.api_formats or []
|
||||
if endpoint.api_format not in key_formats:
|
||||
continue
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
continue
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break # 只取第一个可用的 Key
|
||||
|
||||
if not endpoint_configs:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
@@ -214,7 +239,6 @@ async def query_available_models(
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -229,17 +253,14 @@ async def test_model(
|
||||
测试模型连接性
|
||||
|
||||
向指定提供商的指定模型发送测试请求,验证模型是否可用
|
||||
|
||||
Args:
|
||||
request: 测试请求
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
# 获取提供商及其端点
|
||||
# 获取提供商及其端点和 Keys
|
||||
provider = (
|
||||
db.query(Provider)
|
||||
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||
.options(
|
||||
joinedload(Provider.endpoints),
|
||||
joinedload(Provider.api_keys),
|
||||
)
|
||||
.filter(Provider.id == request.provider_id)
|
||||
.first()
|
||||
)
|
||||
@@ -247,28 +268,38 @@ async def test_model(
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 找到合适的端点和API Key
|
||||
endpoint_config = None
|
||||
# 构建 api_format -> endpoint 映射
|
||||
format_to_endpoint: dict[str, ProviderEndpoint] = {}
|
||||
for ep in provider.endpoints:
|
||||
if ep.is_active:
|
||||
format_to_endpoint[ep.api_format] = ep
|
||||
|
||||
# 找到合适的端点和 API Key
|
||||
endpoint = None
|
||||
api_key = None
|
||||
|
||||
if request.api_key_id:
|
||||
# 使用指定的API Key
|
||||
for ep in provider.endpoints:
|
||||
for key in ep.api_keys:
|
||||
if key.id == request.api_key_id and key.is_active and ep.is_active:
|
||||
endpoint = ep
|
||||
api_key = key
|
||||
# 使用指定的 API Key
|
||||
api_key = next(
|
||||
(key for key in provider.api_keys if key.id == request.api_key_id and key.is_active),
|
||||
None
|
||||
)
|
||||
if api_key:
|
||||
# 找到该 Key 支持的第一个活跃 Endpoint
|
||||
for fmt in (api_key.api_formats or []):
|
||||
if fmt in format_to_endpoint:
|
||||
endpoint = format_to_endpoint[fmt]
|
||||
break
|
||||
if endpoint:
|
||||
break
|
||||
else:
|
||||
# 使用第一个可用的端点和密钥
|
||||
for ep in provider.endpoints:
|
||||
if not ep.is_active or not ep.api_keys:
|
||||
if not ep.is_active:
|
||||
continue
|
||||
for key in ep.api_keys:
|
||||
if key.is_active:
|
||||
# 找支持该格式的第一个可用 Key
|
||||
for key in provider.api_keys:
|
||||
if not key.is_active:
|
||||
continue
|
||||
if ep.api_format in (key.api_formats or []):
|
||||
endpoint = ep
|
||||
api_key = key
|
||||
break
|
||||
@@ -284,14 +315,14 @@ async def test_model(
|
||||
logger.error(f"[test-model] Failed to decrypt API key: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||
|
||||
# 构建请求配置
|
||||
# 构建请求配置(timeout 从 Provider 读取)
|
||||
endpoint_config = {
|
||||
"api_key": api_key_value,
|
||||
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
"timeout": endpoint.timeout or 30.0,
|
||||
"timeout": provider.timeout or TimeoutDefaults.HTTP_REQUEST,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -304,7 +335,6 @@ async def test_model(
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
"model": request.model_name,
|
||||
}
|
||||
@@ -325,7 +355,6 @@ async def test_model(
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
"model": request.model_name,
|
||||
}
|
||||
@@ -415,7 +444,6 @@ async def test_model(
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
"model": request.model_name,
|
||||
"endpoint": {
|
||||
@@ -433,7 +461,6 @@ async def test_model(
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
"model": request.model_name,
|
||||
"endpoint": {
|
||||
|
||||
@@ -78,7 +78,7 @@ async def get_provider_stats(
|
||||
"""
|
||||
获取提供商统计数据
|
||||
|
||||
获取指定提供商的计费信息、RPM 使用情况和使用统计数据。
|
||||
获取指定提供商的计费信息和使用统计数据。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
@@ -96,10 +96,6 @@ async def get_provider_stats(
|
||||
- `monthly_used_usd`: 月度已使用
|
||||
- `quota_remaining_usd`: 剩余配额
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_info`: RPM 信息
|
||||
- `rpm_limit`: RPM 限制
|
||||
- `rpm_used`: 已使用 RPM
|
||||
- `rpm_reset_at`: RPM 重置时间
|
||||
- `usage_stats`: 使用统计
|
||||
- `total_requests`: 总请求数
|
||||
- `successful_requests`: 成功请求数
|
||||
@@ -165,7 +161,6 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
|
||||
provider.billing_type = config.billing_type
|
||||
provider.monthly_quota_usd = config.monthly_quota_usd
|
||||
provider.quota_reset_day = config.quota_reset_day
|
||||
provider.rpm_limit = config.rpm_limit
|
||||
provider.provider_priority = config.provider_priority
|
||||
|
||||
from dateutil import parser
|
||||
@@ -262,13 +257,6 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
|
||||
provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
|
||||
),
|
||||
},
|
||||
"rpm_info": {
|
||||
"rpm_limit": provider.rpm_limit,
|
||||
"rpm_used": provider.rpm_used,
|
||||
"rpm_reset_at": (
|
||||
provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
|
||||
),
|
||||
},
|
||||
"usage_stats": {
|
||||
"total_requests": total_requests,
|
||||
"successful_requests": total_success,
|
||||
@@ -296,8 +284,6 @@ class AdminProviderResetQuotaAdapter(AdminApiAdapter):
|
||||
|
||||
old_used = provider.monthly_used_usd
|
||||
provider.monthly_used_usd = 0.0
|
||||
provider.rpm_used = 0
|
||||
provider.rpm_reset_at = None
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Manually reset quota for provider {provider.name}")
|
||||
|
||||
@@ -338,27 +338,29 @@ async def import_models_from_upstream(
|
||||
"""
|
||||
从上游提供商导入模型
|
||||
|
||||
从上游提供商导入模型列表。如果全局模型不存在,将自动创建。
|
||||
从上游提供商导入模型列表。导入的模型作为独立的 ProviderModel 存储,
|
||||
不会自动创建 GlobalModel。后续需要手动关联 GlobalModel 才能参与路由。
|
||||
|
||||
**流程说明**:
|
||||
1. 根据 model_ids 检查全局模型是否存在(按 name 匹配)
|
||||
2. 如不存在,自动创建新的 GlobalModel(使用默认免费配置)
|
||||
3. 创建 Model 关联到当前 Provider
|
||||
4. 如模型已关联,则记录到成功列表中
|
||||
1. 检查模型是否已存在于当前 Provider(按 provider_model_name 匹配)
|
||||
2. 创建新的 ProviderModel(global_model_id = NULL)
|
||||
3. 支持设置价格覆盖(tiered_pricing, price_per_request)
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体字段**:
|
||||
- `model_ids`: 模型 ID 数组(必填,每个 ID 长度 1-100 字符)
|
||||
- `tiered_pricing`: 可选的阶梯计费配置(应用于所有导入的模型)
|
||||
- `price_per_request`: 可选的按次计费价格(应用于所有导入的模型)
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 成功导入的模型数组,每项包含:
|
||||
- `model_id`: 模型 ID
|
||||
- `global_model_id`: 全局模型 ID
|
||||
- `global_model_name`: 全局模型名称
|
||||
- `provider_model_id`: 提供商模型 ID
|
||||
- `created_global_model`: 是否新创建了全局模型
|
||||
- `global_model_id`: 全局模型 ID(如果已关联)
|
||||
- `global_model_name`: 全局模型名称(如果已关联)
|
||||
- `created_global_model`: 是否新创建了全局模型(始终为 false)
|
||||
- `errors`: 失败的模型数组,每项包含:
|
||||
- `model_id`: 模型 ID
|
||||
- `error`: 错误信息
|
||||
@@ -638,7 +640,7 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
|
||||
|
||||
@dataclass
|
||||
class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
||||
"""从上游提供商导入模型"""
|
||||
"""从上游提供商导入模型(不创建 GlobalModel,作为独立 ProviderModel)"""
|
||||
|
||||
provider_id: str
|
||||
payload: ImportFromUpstreamRequest
|
||||
@@ -652,16 +654,13 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
||||
success: list[ImportFromUpstreamSuccessItem] = []
|
||||
errors: list[ImportFromUpstreamErrorItem] = []
|
||||
|
||||
# 默认阶梯计费配置(免费)
|
||||
default_tiered_pricing = {
|
||||
"tiers": [
|
||||
{
|
||||
"up_to": None,
|
||||
"input_price_per_1m": 0.0,
|
||||
"output_price_per_1m": 0.0,
|
||||
}
|
||||
]
|
||||
}
|
||||
# 获取价格覆盖配置
|
||||
tiered_pricing = None
|
||||
price_per_request = None
|
||||
if hasattr(self.payload, 'tiered_pricing') and self.payload.tiered_pricing:
|
||||
tiered_pricing = self.payload.tiered_pricing
|
||||
if hasattr(self.payload, 'price_per_request') and self.payload.price_per_request is not None:
|
||||
price_per_request = self.payload.price_per_request
|
||||
|
||||
for model_id in self.payload.model_ids:
|
||||
# 输入验证:检查 model_id 长度
|
||||
@@ -678,56 +677,37 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
||||
# 使用 savepoint 确保单个模型导入的原子性
|
||||
savepoint = db.begin_nested()
|
||||
try:
|
||||
# 1. 检查是否已存在同名的 GlobalModel
|
||||
global_model = (
|
||||
db.query(GlobalModel).filter(GlobalModel.name == model_id).first()
|
||||
)
|
||||
created_global_model = False
|
||||
|
||||
if not global_model:
|
||||
# 2. 创建新的 GlobalModel
|
||||
global_model = GlobalModel(
|
||||
name=model_id,
|
||||
display_name=model_id,
|
||||
default_tiered_pricing=default_tiered_pricing,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(global_model)
|
||||
db.flush()
|
||||
created_global_model = True
|
||||
logger.info(
|
||||
f"Created new GlobalModel: {model_id} during upstream import"
|
||||
)
|
||||
|
||||
# 3. 检查是否已存在关联
|
||||
# 1. 检查是否已存在同名的 ProviderModel
|
||||
existing = (
|
||||
db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == self.provider_id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.provider_model_name == model_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
# 已存在关联,提交 savepoint 并记录成功
|
||||
# 已存在,提交 savepoint 并记录成功
|
||||
savepoint.commit()
|
||||
success.append(
|
||||
ImportFromUpstreamSuccessItem(
|
||||
model_id=model_id,
|
||||
global_model_id=global_model.id,
|
||||
global_model_name=global_model.name,
|
||||
global_model_id=existing.global_model_id or "",
|
||||
global_model_name=existing.global_model.name if existing.global_model else "",
|
||||
provider_model_id=existing.id,
|
||||
created_global_model=created_global_model,
|
||||
created_global_model=False,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# 4. 创建新的 Model 记录
|
||||
# 2. 创建新的 Model 记录(不关联 GlobalModel)
|
||||
new_model = Model(
|
||||
provider_id=self.provider_id,
|
||||
global_model_id=global_model.id,
|
||||
provider_model_name=global_model.name,
|
||||
global_model_id=None, # 独立模型,不关联 GlobalModel
|
||||
provider_model_name=model_id,
|
||||
is_active=True,
|
||||
tiered_pricing=tiered_pricing,
|
||||
price_per_request=price_per_request,
|
||||
)
|
||||
db.add(new_model)
|
||||
db.flush()
|
||||
@@ -737,12 +717,15 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
||||
success.append(
|
||||
ImportFromUpstreamSuccessItem(
|
||||
model_id=model_id,
|
||||
global_model_id=global_model.id,
|
||||
global_model_name=global_model.name,
|
||||
global_model_id="", # 未关联
|
||||
global_model_name="", # 未关联
|
||||
provider_model_id=new_model.id,
|
||||
created_global_model=created_global_model,
|
||||
created_global_model=False,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Created independent ProviderModel: {model_id} for provider {provider.name}"
|
||||
)
|
||||
except Exception as e:
|
||||
# 回滚到 savepoint
|
||||
savepoint.rollback()
|
||||
@@ -753,11 +736,9 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
||||
|
||||
db.commit()
|
||||
logger.info(
|
||||
f"Imported {len(success)} models from upstream to provider {provider.name} by {context.user.username}"
|
||||
f"Imported {len(success)} independent models to provider {provider.name} by {context.user.username}"
|
||||
)
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
if success:
|
||||
await invalidate_models_list_cache()
|
||||
# 不需要清除 /v1/models 缓存,因为独立模型不参与路由
|
||||
|
||||
return ImportFromUpstreamResponse(success=success, errors=errors)
|
||||
|
||||
@@ -41,8 +41,7 @@ async def list_providers(
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 提供商 ID
|
||||
- `name`: 提供商名称(唯一标识)
|
||||
- `display_name`: 显示名称
|
||||
- `name`: 提供商名称(唯一)
|
||||
- `api_format`: API 格式(如 claude、openai、gemini 等)
|
||||
- `base_url`: API 基础 URL
|
||||
- `api_key`: API 密钥(脱敏显示)
|
||||
@@ -63,8 +62,7 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
|
||||
创建一个新的 AI 模型提供商配置。
|
||||
|
||||
**请求体字段**:
|
||||
- `name`: 提供商名称(必填,唯一,用于系统标识)
|
||||
- `display_name`: 显示名称(必填)
|
||||
- `name`: 提供商名称(必填,唯一)
|
||||
- `description`: 描述信息(可选)
|
||||
- `website`: 官网地址(可选)
|
||||
- `billing_type`: 计费类型(可选,pay_as_you_go/subscription/prepaid,默认 pay_as_you_go)
|
||||
@@ -72,16 +70,17 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
|
||||
- `quota_reset_day`: 配额重置日期(1-31)(可选)
|
||||
- `quota_last_reset_at`: 上次配额重置时间(可选)
|
||||
- `quota_expires_at`: 配额过期时间(可选)
|
||||
- `rpm_limit`: 每分钟请求数限制(可选)
|
||||
- `provider_priority`: 提供商优先级(数字越小优先级越高,默认 100)
|
||||
- `is_active`: 是否启用(默认 true)
|
||||
- `concurrent_limit`: 并发限制(可选)
|
||||
- `timeout`: 请求超时(秒,可选)
|
||||
- `max_retries`: 最大重试次数(可选)
|
||||
- `proxy`: 代理配置(可选)
|
||||
- `config`: 额外配置信息(JSON,可选)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 新创建的提供商 ID
|
||||
- `name`: 提供商名称
|
||||
- `display_name`: 显示名称
|
||||
- `message`: 成功提示信息
|
||||
"""
|
||||
adapter = AdminCreateProviderAdapter()
|
||||
@@ -100,7 +99,6 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
|
||||
|
||||
**请求体字段**(所有字段可选):
|
||||
- `name`: 提供商名称
|
||||
- `display_name`: 显示名称
|
||||
- `description`: 描述信息
|
||||
- `website`: 官网地址
|
||||
- `billing_type`: 计费类型(pay_as_you_go/subscription/prepaid)
|
||||
@@ -108,10 +106,12 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
|
||||
- `quota_reset_day`: 配额重置日期(1-31)
|
||||
- `quota_last_reset_at`: 上次配额重置时间
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_limit`: 每分钟请求数限制
|
||||
- `provider_priority`: 提供商优先级
|
||||
- `is_active`: 是否启用
|
||||
- `concurrent_limit`: 并发限制
|
||||
- `timeout`: 请求超时(秒)
|
||||
- `max_retries`: 最大重试次数
|
||||
- `proxy`: 代理配置
|
||||
- `config`: 额外配置信息(JSON)
|
||||
|
||||
**返回字段**:
|
||||
@@ -165,7 +165,6 @@ class AdminListProvidersAdapter(AdminApiAdapter):
|
||||
{
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"api_format": api_format.value if api_format else None,
|
||||
"base_url": base_url,
|
||||
"api_key": "***" if api_key else None,
|
||||
@@ -217,7 +216,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
|
||||
# 创建 Provider 对象
|
||||
provider = Provider(
|
||||
name=validated_data.name,
|
||||
display_name=validated_data.display_name,
|
||||
description=validated_data.description,
|
||||
website=validated_data.website,
|
||||
billing_type=billing_type,
|
||||
@@ -225,10 +223,12 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
|
||||
quota_reset_day=validated_data.quota_reset_day,
|
||||
quota_last_reset_at=validated_data.quota_last_reset_at,
|
||||
quota_expires_at=validated_data.quota_expires_at,
|
||||
rpm_limit=validated_data.rpm_limit,
|
||||
provider_priority=validated_data.provider_priority,
|
||||
is_active=validated_data.is_active,
|
||||
concurrent_limit=validated_data.concurrent_limit,
|
||||
timeout=validated_data.timeout,
|
||||
max_retries=validated_data.max_retries,
|
||||
proxy=validated_data.proxy.model_dump() if validated_data.proxy else None,
|
||||
config=validated_data.config,
|
||||
)
|
||||
|
||||
@@ -248,7 +248,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"message": "提供商创建成功",
|
||||
}
|
||||
except InvalidRequestException:
|
||||
@@ -291,6 +290,9 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
|
||||
if field == "billing_type" and value is not None:
|
||||
# billing_type 需要转换为枚举
|
||||
setattr(provider, field, ProviderBillingType(value))
|
||||
elif field == "proxy" and value is not None:
|
||||
# proxy 需要转换为 dict(如果是 Pydantic 模型)
|
||||
setattr(provider, field, value if isinstance(value, dict) else value.model_dump())
|
||||
else:
|
||||
setattr(provider, field, value)
|
||||
|
||||
|
||||
@@ -48,7 +48,6 @@ async def get_providers_summary(
|
||||
**返回字段**(数组,每项包含):
|
||||
- `id`: 提供商 ID
|
||||
- `name`: 提供商名称
|
||||
- `display_name`: 显示名称
|
||||
- `description`: 描述信息
|
||||
- `website`: 官网地址
|
||||
- `provider_priority`: 优先级
|
||||
@@ -59,9 +58,9 @@ async def get_providers_summary(
|
||||
- `quota_reset_day`: 配额重置日期
|
||||
- `quota_last_reset_at`: 上次配额重置时间
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_limit`: RPM 限制
|
||||
- `rpm_used`: 已使用 RPM
|
||||
- `rpm_reset_at`: RPM 重置时间
|
||||
- `timeout`: 默认请求超时(秒)
|
||||
- `max_retries`: 默认最大重试次数
|
||||
- `proxy`: 默认代理配置
|
||||
- `total_endpoints`: 端点总数
|
||||
- `active_endpoints`: 活跃端点数
|
||||
- `total_keys`: 密钥总数
|
||||
@@ -96,7 +95,6 @@ async def get_provider_summary(
|
||||
**返回字段**:
|
||||
- `id`: 提供商 ID
|
||||
- `name`: 提供商名称
|
||||
- `display_name`: 显示名称
|
||||
- `description`: 描述信息
|
||||
- `website`: 官网地址
|
||||
- `provider_priority`: 优先级
|
||||
@@ -107,9 +105,9 @@ async def get_provider_summary(
|
||||
- `quota_reset_day`: 配额重置日期
|
||||
- `quota_last_reset_at`: 上次配额重置时间
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_limit`: RPM 限制
|
||||
- `rpm_used`: 已使用 RPM
|
||||
- `rpm_reset_at`: RPM 重置时间
|
||||
- `timeout`: 默认请求超时(秒)
|
||||
- `max_retries`: 默认最大重试次数
|
||||
- `proxy`: 默认代理配置
|
||||
- `total_endpoints`: 端点总数
|
||||
- `active_endpoints`: 活跃端点数
|
||||
- `total_keys`: 密钥总数
|
||||
@@ -185,13 +183,13 @@ async def update_provider_settings(
|
||||
"""
|
||||
更新提供商基础配置
|
||||
|
||||
更新提供商的基础配置信息,如显示名称、描述、优先级等。只需传入需要更新的字段。
|
||||
更新提供商的基础配置信息,如名称、描述、优先级等。只需传入需要更新的字段。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体字段**(所有字段可选):
|
||||
- `display_name`: 显示名称
|
||||
- `name`: 提供商名称
|
||||
- `description`: 描述信息
|
||||
- `website`: 官网地址
|
||||
- `provider_priority`: 优先级
|
||||
@@ -199,9 +197,10 @@ async def update_provider_settings(
|
||||
- `billing_type`: 计费类型
|
||||
- `monthly_quota_usd`: 月度配额(美元)
|
||||
- `quota_reset_day`: 配额重置日期
|
||||
- `quota_last_reset_at`: 上次配额重置时间
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_limit`: RPM 限制
|
||||
- `timeout`: 默认请求超时(秒)
|
||||
- `max_retries`: 默认最大重试次数
|
||||
- `proxy`: 默认代理配置
|
||||
|
||||
**返回字段**: 返回更新后的提供商摘要信息(与 GET /summary 接口返回格式相同)
|
||||
"""
|
||||
@@ -215,18 +214,18 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
|
||||
|
||||
total_endpoints = len(endpoints)
|
||||
active_endpoints = sum(1 for e in endpoints if e.is_active)
|
||||
endpoint_ids = [e.id for e in endpoints]
|
||||
|
||||
# Key 统计(合并为单个查询)
|
||||
total_keys = 0
|
||||
active_keys = 0
|
||||
if endpoint_ids:
|
||||
key_stats = db.query(
|
||||
key_stats = (
|
||||
db.query(
|
||||
func.count(ProviderAPIKey.id).label("total"),
|
||||
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
|
||||
).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).first()
|
||||
total_keys = key_stats.total or 0
|
||||
active_keys = int(key_stats.active or 0)
|
||||
)
|
||||
.filter(ProviderAPIKey.provider_id == provider.id)
|
||||
.first()
|
||||
)
|
||||
total_keys = key_stats.total or 0
|
||||
active_keys = int(key_stats.active or 0)
|
||||
|
||||
# Model 统计(合并为单个查询)
|
||||
model_stats = db.query(
|
||||
@@ -238,25 +237,34 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
|
||||
|
||||
api_formats = [e.api_format for e in endpoints]
|
||||
|
||||
# 优化: 一次性加载所有 endpoint 的 keys,避免 N+1 查询
|
||||
all_keys = []
|
||||
if endpoint_ids:
|
||||
all_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
|
||||
)
|
||||
# 优化: 一次性加载 Provider 的 keys,避免 N+1 查询
|
||||
all_keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.provider_id == provider.id).all()
|
||||
|
||||
# 按 endpoint_id 分组 keys
|
||||
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {}
|
||||
# 按 api_formats 分组 keys(通过 api_formats 关联)
|
||||
format_to_endpoint_id: dict[str, str] = {e.api_format: e.id for e in endpoints}
|
||||
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {e.id: [] for e in endpoints}
|
||||
for key in all_keys:
|
||||
if key.endpoint_id not in keys_by_endpoint:
|
||||
keys_by_endpoint[key.endpoint_id] = []
|
||||
keys_by_endpoint[key.endpoint_id].append(key)
|
||||
formats = key.api_formats or []
|
||||
for fmt in formats:
|
||||
endpoint_id = format_to_endpoint_id.get(fmt)
|
||||
if endpoint_id:
|
||||
keys_by_endpoint[endpoint_id].append(key)
|
||||
|
||||
endpoint_health_map: dict[str, float] = {}
|
||||
for endpoint in endpoints:
|
||||
keys = keys_by_endpoint.get(endpoint.id, [])
|
||||
if keys:
|
||||
health_scores = [k.health_score for k in keys if k.health_score is not None]
|
||||
# 从 health_by_format 获取对应格式的健康度
|
||||
api_fmt = endpoint.api_format
|
||||
health_scores = []
|
||||
for k in keys:
|
||||
health_by_format = k.health_by_format or {}
|
||||
if api_fmt in health_by_format:
|
||||
score = health_by_format[api_fmt].get("health_score")
|
||||
if score is not None:
|
||||
health_scores.append(float(score))
|
||||
else:
|
||||
health_scores.append(1.0) # 默认健康度
|
||||
avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0
|
||||
endpoint_health_map[endpoint.id] = avg_health
|
||||
else:
|
||||
@@ -284,7 +292,6 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
|
||||
return ProviderWithEndpointsSummary(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
display_name=provider.display_name,
|
||||
description=provider.description,
|
||||
website=provider.website,
|
||||
provider_priority=provider.provider_priority,
|
||||
@@ -295,9 +302,9 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
|
||||
quota_reset_day=provider.quota_reset_day,
|
||||
quota_last_reset_at=provider.quota_last_reset_at,
|
||||
quota_expires_at=provider.quota_expires_at,
|
||||
rpm_limit=provider.rpm_limit,
|
||||
rpm_used=provider.rpm_used,
|
||||
rpm_reset_at=provider.rpm_reset_at,
|
||||
timeout=provider.timeout,
|
||||
max_retries=provider.max_retries,
|
||||
proxy=provider.proxy,
|
||||
total_endpoints=total_endpoints,
|
||||
active_endpoints=active_endpoints,
|
||||
total_keys=total_keys,
|
||||
@@ -341,7 +348,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
|
||||
if not endpoint_ids:
|
||||
response = ProviderEndpointHealthMonitorResponse(
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.display_name or provider.name,
|
||||
provider_name=provider.name,
|
||||
generated_at=now,
|
||||
endpoints=[],
|
||||
)
|
||||
@@ -416,7 +423,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
|
||||
|
||||
response = ProviderEndpointHealthMonitorResponse(
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.display_name or provider.name,
|
||||
provider_name=provider.name,
|
||||
generated_at=now,
|
||||
endpoints=endpoint_monitors,
|
||||
)
|
||||
|
||||
@@ -730,36 +730,6 @@ class AdminExportConfigAdapter(AdminApiAdapter):
|
||||
)
|
||||
endpoints_data = []
|
||||
for ep in endpoints:
|
||||
# 导出 Endpoint Keys
|
||||
keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
|
||||
)
|
||||
keys_data = []
|
||||
for key in keys:
|
||||
# 解密 API Key
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
except Exception:
|
||||
decrypted_key = ""
|
||||
|
||||
keys_data.append(
|
||||
{
|
||||
"api_key": decrypted_key,
|
||||
"name": key.name,
|
||||
"note": key.note,
|
||||
"rate_multiplier": key.rate_multiplier,
|
||||
"internal_priority": key.internal_priority,
|
||||
"global_priority": key.global_priority,
|
||||
"max_concurrent": key.max_concurrent,
|
||||
"rate_limit": key.rate_limit,
|
||||
"daily_limit": key.daily_limit,
|
||||
"monthly_limit": key.monthly_limit,
|
||||
"allowed_models": key.allowed_models,
|
||||
"capabilities": key.capabilities,
|
||||
"is_active": key.is_active,
|
||||
}
|
||||
)
|
||||
|
||||
endpoints_data.append(
|
||||
{
|
||||
"api_format": ep.api_format,
|
||||
@@ -767,12 +737,44 @@ class AdminExportConfigAdapter(AdminApiAdapter):
|
||||
"headers": ep.headers,
|
||||
"timeout": ep.timeout,
|
||||
"max_retries": ep.max_retries,
|
||||
"max_concurrent": ep.max_concurrent,
|
||||
"rate_limit": ep.rate_limit,
|
||||
"is_active": ep.is_active,
|
||||
"custom_path": ep.custom_path,
|
||||
"config": ep.config,
|
||||
"keys": keys_data,
|
||||
"proxy": ep.proxy,
|
||||
}
|
||||
)
|
||||
|
||||
# 导出 Provider Keys(按 provider_id 归属,包含 api_formats)
|
||||
keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(ProviderAPIKey.provider_id == provider.id)
|
||||
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
keys_data = []
|
||||
for key in keys:
|
||||
# 解密 API Key
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
except Exception:
|
||||
decrypted_key = ""
|
||||
|
||||
keys_data.append(
|
||||
{
|
||||
"api_key": decrypted_key,
|
||||
"name": key.name,
|
||||
"note": key.note,
|
||||
"api_formats": key.api_formats or [],
|
||||
"rate_multiplier": key.rate_multiplier,
|
||||
"rate_multipliers": key.rate_multipliers,
|
||||
"internal_priority": key.internal_priority,
|
||||
"global_priority": key.global_priority,
|
||||
"rpm_limit": key.rpm_limit,
|
||||
"allowed_models": key.allowed_models,
|
||||
"capabilities": key.capabilities,
|
||||
"cache_ttl_minutes": key.cache_ttl_minutes,
|
||||
"max_probe_interval_minutes": key.max_probe_interval_minutes,
|
||||
"is_active": key.is_active,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -804,24 +806,26 @@ class AdminExportConfigAdapter(AdminApiAdapter):
|
||||
providers_data.append(
|
||||
{
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"description": provider.description,
|
||||
"website": provider.website,
|
||||
"billing_type": provider.billing_type.value if provider.billing_type else None,
|
||||
"monthly_quota_usd": provider.monthly_quota_usd,
|
||||
"quota_reset_day": provider.quota_reset_day,
|
||||
"rpm_limit": provider.rpm_limit,
|
||||
"provider_priority": provider.provider_priority,
|
||||
"is_active": provider.is_active,
|
||||
"concurrent_limit": provider.concurrent_limit,
|
||||
"timeout": provider.timeout,
|
||||
"max_retries": provider.max_retries,
|
||||
"proxy": provider.proxy,
|
||||
"config": provider.config,
|
||||
"endpoints": endpoints_data,
|
||||
"api_keys": keys_data,
|
||||
"models": models_data,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"version": "1.0",
|
||||
"version": "2.0",
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"global_models": global_models_data,
|
||||
"providers": providers_data,
|
||||
@@ -850,7 +854,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
|
||||
# 验证配置版本
|
||||
version = payload.get("version")
|
||||
if version != "1.0":
|
||||
if version != "2.0":
|
||||
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
||||
|
||||
# 获取导入选项
|
||||
@@ -939,8 +943,8 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
)
|
||||
elif merge_mode == "overwrite":
|
||||
# 更新现有记录
|
||||
existing_provider.display_name = prov_data.get(
|
||||
"display_name", existing_provider.display_name
|
||||
existing_provider.name = prov_data.get(
|
||||
"name", existing_provider.name
|
||||
)
|
||||
existing_provider.description = prov_data.get("description")
|
||||
existing_provider.website = prov_data.get("website")
|
||||
@@ -954,7 +958,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
existing_provider.quota_reset_day = prov_data.get(
|
||||
"quota_reset_day", 30
|
||||
)
|
||||
existing_provider.rpm_limit = prov_data.get("rpm_limit")
|
||||
existing_provider.provider_priority = prov_data.get(
|
||||
"provider_priority", 100
|
||||
)
|
||||
@@ -962,6 +965,11 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
existing_provider.concurrent_limit = prov_data.get(
|
||||
"concurrent_limit"
|
||||
)
|
||||
existing_provider.timeout = prov_data.get("timeout", existing_provider.timeout)
|
||||
existing_provider.max_retries = prov_data.get(
|
||||
"max_retries", existing_provider.max_retries
|
||||
)
|
||||
existing_provider.proxy = prov_data.get("proxy", existing_provider.proxy)
|
||||
existing_provider.config = prov_data.get("config")
|
||||
existing_provider.updated_at = datetime.now(timezone.utc)
|
||||
stats["providers"]["updated"] += 1
|
||||
@@ -974,16 +982,17 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
new_provider = Provider(
|
||||
id=str(uuid.uuid4()),
|
||||
name=prov_data["name"],
|
||||
display_name=prov_data.get("display_name", prov_data["name"]),
|
||||
description=prov_data.get("description"),
|
||||
website=prov_data.get("website"),
|
||||
billing_type=billing_type,
|
||||
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
|
||||
quota_reset_day=prov_data.get("quota_reset_day", 30),
|
||||
rpm_limit=prov_data.get("rpm_limit"),
|
||||
provider_priority=prov_data.get("provider_priority", 100),
|
||||
is_active=prov_data.get("is_active", True),
|
||||
concurrent_limit=prov_data.get("concurrent_limit"),
|
||||
timeout=prov_data.get("timeout"),
|
||||
max_retries=prov_data.get("max_retries"),
|
||||
proxy=prov_data.get("proxy"),
|
||||
config=prov_data.get("config"),
|
||||
)
|
||||
db.add(new_provider)
|
||||
@@ -1003,7 +1012,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
)
|
||||
|
||||
if existing_ep:
|
||||
endpoint_id = existing_ep.id
|
||||
if merge_mode == "skip":
|
||||
stats["endpoints"]["skipped"] += 1
|
||||
elif merge_mode == "error":
|
||||
@@ -1017,11 +1025,10 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
existing_ep.headers = ep_data.get("headers")
|
||||
existing_ep.timeout = ep_data.get("timeout", 300)
|
||||
existing_ep.max_retries = ep_data.get("max_retries", 2)
|
||||
existing_ep.max_concurrent = ep_data.get("max_concurrent")
|
||||
existing_ep.rate_limit = ep_data.get("rate_limit")
|
||||
existing_ep.is_active = ep_data.get("is_active", True)
|
||||
existing_ep.custom_path = ep_data.get("custom_path")
|
||||
existing_ep.config = ep_data.get("config")
|
||||
existing_ep.proxy = ep_data.get("proxy")
|
||||
existing_ep.updated_at = datetime.now(timezone.utc)
|
||||
stats["endpoints"]["updated"] += 1
|
||||
else:
|
||||
@@ -1033,68 +1040,106 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
headers=ep_data.get("headers"),
|
||||
timeout=ep_data.get("timeout", 300),
|
||||
max_retries=ep_data.get("max_retries", 2),
|
||||
max_concurrent=ep_data.get("max_concurrent"),
|
||||
rate_limit=ep_data.get("rate_limit"),
|
||||
is_active=ep_data.get("is_active", True),
|
||||
custom_path=ep_data.get("custom_path"),
|
||||
config=ep_data.get("config"),
|
||||
proxy=ep_data.get("proxy"),
|
||||
)
|
||||
db.add(new_ep)
|
||||
db.flush()
|
||||
endpoint_id = new_ep.id
|
||||
stats["endpoints"]["created"] += 1
|
||||
|
||||
# 导入 Keys
|
||||
# 获取当前 endpoint 下所有已有的 keys,用于去重
|
||||
existing_keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
|
||||
.all()
|
||||
)
|
||||
# 解密已有 keys 用于比对
|
||||
existing_key_values = set()
|
||||
for ek in existing_keys:
|
||||
try:
|
||||
decrypted = crypto_service.decrypt(ek.api_key)
|
||||
existing_key_values.add(decrypted)
|
||||
except Exception:
|
||||
pass
|
||||
# 导入 Provider Keys(按 provider_id 归属)
|
||||
endpoint_format_rows = (
|
||||
db.query(ProviderEndpoint.api_format)
|
||||
.filter(ProviderEndpoint.provider_id == provider_id)
|
||||
.all()
|
||||
)
|
||||
endpoint_formats: set[str] = set()
|
||||
for (api_format,) in endpoint_format_rows:
|
||||
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
|
||||
endpoint_formats.add(fmt.strip().upper())
|
||||
existing_keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(ProviderAPIKey.provider_id == provider_id)
|
||||
.all()
|
||||
)
|
||||
existing_key_values = set()
|
||||
for ek in existing_keys:
|
||||
try:
|
||||
decrypted = crypto_service.decrypt(ek.api_key)
|
||||
existing_key_values.add(decrypted)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for key_data in ep_data.get("keys", []):
|
||||
if not key_data.get("api_key"):
|
||||
stats["errors"].append(
|
||||
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
|
||||
)
|
||||
continue
|
||||
|
||||
# 检查是否已存在相同的 Key(通过明文比对)
|
||||
if key_data["api_key"] in existing_key_values:
|
||||
stats["keys"]["skipped"] += 1
|
||||
continue
|
||||
|
||||
encrypted_key = crypto_service.encrypt(key_data["api_key"])
|
||||
|
||||
new_key = ProviderAPIKey(
|
||||
id=str(uuid.uuid4()),
|
||||
endpoint_id=endpoint_id,
|
||||
api_key=encrypted_key,
|
||||
name=key_data.get("name"),
|
||||
note=key_data.get("note"),
|
||||
rate_multiplier=key_data.get("rate_multiplier", 1.0),
|
||||
internal_priority=key_data.get("internal_priority", 100),
|
||||
global_priority=key_data.get("global_priority"),
|
||||
max_concurrent=key_data.get("max_concurrent"),
|
||||
rate_limit=key_data.get("rate_limit"),
|
||||
daily_limit=key_data.get("daily_limit"),
|
||||
monthly_limit=key_data.get("monthly_limit"),
|
||||
allowed_models=key_data.get("allowed_models"),
|
||||
capabilities=key_data.get("capabilities"),
|
||||
is_active=key_data.get("is_active", True),
|
||||
for key_data in prov_data.get("api_keys", []):
|
||||
if not key_data.get("api_key"):
|
||||
stats["errors"].append(
|
||||
f"跳过空 API Key (Provider: {prov_data['name']})"
|
||||
)
|
||||
db.add(new_key)
|
||||
# 添加到已有集合,防止同一批导入中重复
|
||||
existing_key_values.add(key_data["api_key"])
|
||||
stats["keys"]["created"] += 1
|
||||
continue
|
||||
|
||||
plaintext_key = key_data["api_key"]
|
||||
if plaintext_key in existing_key_values:
|
||||
stats["keys"]["skipped"] += 1
|
||||
continue
|
||||
|
||||
raw_formats = key_data.get("api_formats") or []
|
||||
if not isinstance(raw_formats, list) or len(raw_formats) == 0:
|
||||
stats["errors"].append(
|
||||
f"跳过无 api_formats 的 Key (Provider: {prov_data['name']})"
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_formats: list[str] = []
|
||||
seen: set[str] = set()
|
||||
missing_formats: list[str] = []
|
||||
for fmt in raw_formats:
|
||||
if not isinstance(fmt, str):
|
||||
continue
|
||||
fmt_upper = fmt.strip().upper()
|
||||
if not fmt_upper or fmt_upper in seen:
|
||||
continue
|
||||
seen.add(fmt_upper)
|
||||
if endpoint_formats and fmt_upper not in endpoint_formats:
|
||||
missing_formats.append(fmt_upper)
|
||||
continue
|
||||
normalized_formats.append(fmt_upper)
|
||||
|
||||
if missing_formats:
|
||||
stats["errors"].append(
|
||||
f"Key (Provider: {prov_data['name']}) 的 api_formats 未配置对应 Endpoint,已跳过: {missing_formats}"
|
||||
)
|
||||
|
||||
if len(normalized_formats) == 0:
|
||||
stats["keys"]["skipped"] += 1
|
||||
continue
|
||||
|
||||
encrypted_key = crypto_service.encrypt(plaintext_key)
|
||||
|
||||
new_key = ProviderAPIKey(
|
||||
id=str(uuid.uuid4()),
|
||||
provider_id=provider_id,
|
||||
api_formats=normalized_formats,
|
||||
api_key=encrypted_key,
|
||||
name=key_data.get("name") or "Imported Key",
|
||||
note=key_data.get("note"),
|
||||
rate_multiplier=key_data.get("rate_multiplier", 1.0),
|
||||
rate_multipliers=key_data.get("rate_multipliers"),
|
||||
internal_priority=key_data.get("internal_priority", 50),
|
||||
global_priority=key_data.get("global_priority"),
|
||||
rpm_limit=key_data.get("rpm_limit"),
|
||||
allowed_models=key_data.get("allowed_models"),
|
||||
capabilities=key_data.get("capabilities"),
|
||||
cache_ttl_minutes=key_data.get("cache_ttl_minutes", 5),
|
||||
max_probe_interval_minutes=key_data.get("max_probe_interval_minutes", 32),
|
||||
is_active=key_data.get("is_active", True),
|
||||
health_by_format={},
|
||||
circuit_breaker_by_format={},
|
||||
)
|
||||
db.add(new_key)
|
||||
existing_key_values.add(plaintext_key)
|
||||
stats["keys"]["created"] += 1
|
||||
|
||||
# 导入 Models
|
||||
for model_data in prov_data.get("models", []):
|
||||
|
||||
@@ -247,7 +247,8 @@ async def get_usage_detail(
|
||||
- `request_headers`: 请求头
|
||||
- `request_body`: 请求体
|
||||
- `provider_request_headers`: 提供商请求头
|
||||
- `response_headers`: 响应头
|
||||
- `response_headers`: 提供商响应头
|
||||
- `client_response_headers`: 返回给客户端的响应头
|
||||
- `response_body`: 响应体
|
||||
- `metadata`: 提供商响应元数据
|
||||
- `tiered_pricing`: 阶梯计费信息(如适用)
|
||||
@@ -916,6 +917,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
|
||||
"request_body": usage_record.get_request_body(),
|
||||
"provider_request_headers": usage_record.provider_request_headers,
|
||||
"response_headers": usage_record.response_headers,
|
||||
"client_response_headers": usage_record.client_response_headers,
|
||||
"response_body": usage_record.get_response_body(),
|
||||
"metadata": usage_record.request_metadata,
|
||||
"tiered_pricing": tiered_pricing_info,
|
||||
|
||||
Reference in New Issue
Block a user