Files
Aether/src/api/admin/provider_query.py

209 lines
7.1 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
Provider Query API 端点
用于查询提供商的模型列表等信息
2025-12-10 20:52:44 +08:00
"""
import asyncio
2025-12-10 20:52:44 +08:00
from typing import Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException
2025-12-10 20:52:44 +08:00
from pydantic import BaseModel
from sqlalchemy.orm import Session, joinedload
2025-12-10 20:52:44 +08:00
from src.api.handlers.base.chat_adapter_base import get_adapter_class
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
from src.core.crypto import crypto_service
2025-12-10 20:52:44 +08:00
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderEndpoint, User
2025-12-10 20:52:44 +08:00
from src.utils.auth_utils import get_current_user
# Router for admin-side provider queries; all routes below are mounted
# under /api/admin/provider-query.
router = APIRouter(
    prefix="/api/admin/provider-query",
    tags=["Provider Query"],
)
2025-12-10 20:52:44 +08:00
# ============ Request/Response Models ============
class ModelsQueryRequest(BaseModel):
    """Request body for listing a provider's available models.

    When ``api_key_id`` is set, only the endpoint that owns that key is
    queried; otherwise every active endpoint of the provider is queried.
    """

    # ID of the provider whose models should be listed.
    provider_id: str
    # Optional ID of one specific API key to query with.
    api_key_id: Optional[str] = None
# ============ API Endpoints ============
def _get_adapter_for_format(api_format: str):
    """Return the adapter class registered for ``api_format``, or ``None``.

    The chat adapter registry is consulted first; the CLI adapter registry
    is only consulted when the chat registry yields nothing.
    """
    for lookup in (get_adapter_class, get_cli_adapter_class):
        found = lookup(api_format)
        if found:
            return found
    return None
2025-12-10 20:52:44 +08:00
# Upper bound on simultaneous upstream model-list requests per query, to
# avoid tripping upstream rate limits.
_MAX_CONCURRENT_MODEL_REQUESTS = 5

# Timeout (seconds) of the shared HTTP client used to query endpoints.
_MODEL_QUERY_TIMEOUT = 30.0


def _make_endpoint_config(endpoint: ProviderEndpoint, api_key_value: str) -> dict:
    """Build the per-endpoint request config consumed by ``_fetch_endpoint_models``."""
    return {
        "api_key": api_key_value,
        "base_url": endpoint.base_url,
        "api_format": endpoint.api_format,
        "extra_headers": endpoint.headers,
    }


def _collect_endpoint_configs(
    provider: Provider, api_key_id: Optional[str]
) -> list[dict]:
    """Select the endpoints (and decrypted keys) to query for ``provider``.

    If ``api_key_id`` is given, only the endpoint owning that key is used;
    the endpoint/key ``is_active`` flags are not consulted in that case.
    Otherwise each active endpoint contributes its first active key that
    decrypts successfully.

    Raises:
        HTTPException: 404 if ``api_key_id`` matches no key of the provider,
            500 if that key cannot be decrypted, 400 if no active key of any
            active endpoint is usable.
    """
    if api_key_id:
        for endpoint in provider.endpoints:
            for api_key in endpoint.api_keys:
                if api_key.id != api_key_id:
                    continue
                try:
                    api_key_value = crypto_service.decrypt(api_key.api_key)
                except Exception as e:
                    logger.error(f"Failed to decrypt API key: {e}")
                    raise HTTPException(
                        status_code=500, detail="Failed to decrypt API key"
                    ) from e
                return [_make_endpoint_config(endpoint, api_key_value)]
        raise HTTPException(status_code=404, detail="API Key not found")

    configs: list[dict] = []
    for endpoint in provider.endpoints:
        if not endpoint.is_active:
            continue
        for api_key in endpoint.api_keys:
            if not api_key.is_active:
                continue
            try:
                api_key_value = crypto_service.decrypt(api_key.api_key)
            except Exception as e:
                logger.error(f"Failed to decrypt API key: {e}")
                continue  # try the endpoint's next key
            configs.append(_make_endpoint_config(endpoint, api_key_value))
            break  # one key per endpoint is enough
    if not configs:
        raise HTTPException(
            status_code=400, detail="No active API Key found for this provider"
        )
    return configs


def _redact_secret(message: Optional[str], secret: str) -> Optional[str]:
    """Return ``message`` with every occurrence of ``secret`` replaced by ``***``.

    httpx ``HTTPStatusError`` messages embed the request URL, and some
    upstreams may carry the API key in the URL query, so error text is
    scrubbed before it is logged or returned to the client.
    """
    if not isinstance(message, str) or not secret:
        return message
    return message.replace(secret, "***")


async def _fetch_endpoint_models(
    client: httpx.AsyncClient, config: dict
) -> tuple[list, Optional[str]]:
    """Fetch the model list for a single endpoint config.

    Adapter/upstream failures never propagate: they are logged and returned
    as ``([], "<api_format>: <message>")`` so one failing endpoint does not
    abort the whole query.

    Returns:
        ``(models, error)``. Each model dict is tagged with the endpoint's
        ``api_format`` when the adapter did not set one.
    """
    api_format = config["api_format"]
    base_url = config["base_url"]
    if not base_url:
        # Nothing to query; skip, but leave a trace for diagnosis.
        logger.warning(f"Skipping {api_format} endpoint without base_url")
        return [], None
    base_url = base_url.rstrip("/")
    api_key_value = config["api_key"]

    try:
        adapter_class = _get_adapter_for_format(api_format)
        if not adapter_class:
            return [], f"Unknown API format: {api_format}"
        models, error = await adapter_class.fetch_models(
            client, base_url, api_key_value, config.get("extra_headers")
        )
        # Guard against an adapter returning None for models so its own
        # error message is not replaced by a "not iterable" TypeError.
        models = list(models or [])
        for model in models:
            model.setdefault("api_format", api_format)
    except Exception as e:
        message = _redact_secret(str(e), api_key_value)
        logger.error(f"Error fetching models from {api_format} endpoint: {message}")
        return [], f"{api_format}: {message}"
    return models, _redact_secret(error, api_key_value)


def _dedupe_models(models: list) -> list:
    """Drop models without an ``id`` and repeats of (id, api_format), keeping the first."""
    seen: set[tuple[str, str]] = set()
    unique: list = []
    for model in models:
        model_id = model.get("id")
        if not model_id:
            continue
        # Tuple key: unlike "id:format" strings, ids containing ':' cannot collide.
        key = (str(model_id), str(model.get("api_format", "")))
        if key in seen:
            continue
        seen.add(key)
        unique.append(model)
    return unique


@router.post("/models")
async def query_available_models(
    request: ModelsQueryRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Query the models available from a provider.

    The selected endpoints are queried concurrently (at most
    ``_MAX_CONCURRENT_MODEL_REQUESTS`` at a time), each through the adapter
    matching its API format:

    - OPENAI/OPENAI_CLI: OpenAIChatAdapter.fetch_models
    - CLAUDE/CLAUDE_CLI: ClaudeChatAdapter.fetch_models
    - GEMINI/GEMINI_CLI: GeminiChatAdapter.fetch_models

    Args:
        request: Provider ID and optional API key ID to query with.
        db: Database session (dependency).
        current_user: Authenticated user (dependency; not otherwise used).

    Returns:
        ``{"success", "data": {"models", "error"}, "provider"}`` where
        ``models`` is the de-duplicated union of all endpoints' models and
        ``error`` joins per-endpoint errors with "; " (or is ``None``).

    Raises:
        HTTPException: 404 if the provider or requested key does not exist,
            500 if the requested key cannot be decrypted, 400 if the
            provider has no usable active key.
    """
    # NOTE(review): this /api/admin route depends only on get_current_user;
    # confirm an admin-role check is enforced elsewhere, since it decrypts
    # and uses stored provider API keys.
    # NOTE(review): db.query is synchronous inside an async def and presumably
    # blocks the event loop for its duration.
    provider = (
        db.query(Provider)
        .options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
        .filter(Provider.id == request.provider_id)
        .first()
    )
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    endpoint_configs = _collect_endpoint_configs(provider, request.api_key_id)

    semaphore = asyncio.Semaphore(_MAX_CONCURRENT_MODEL_REQUESTS)

    async def fetch_with_semaphore(
        client: httpx.AsyncClient, config: dict
    ) -> tuple[list, Optional[str]]:
        async with semaphore:
            return await _fetch_endpoint_models(client, config)

    async with httpx.AsyncClient(timeout=_MODEL_QUERY_TIMEOUT) as client:
        results = await asyncio.gather(
            *(fetch_with_semaphore(client, config) for config in endpoint_configs)
        )

    all_models: list = []
    errors: list[str] = []
    for models, error in results:
        all_models.extend(models)
        if error:
            errors.append(error)

    unique_models = _dedupe_models(all_models)

    error = "; ".join(errors) if errors else None
    if not unique_models and not error:
        error = "No models returned from any endpoint"

    return {
        "success": len(unique_models) > 0,
        "data": {"models": unique_models, "error": error},
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }