Files
Aether/src/api/admin/provider_query.py

324 lines
11 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
Provider Query API 端点
用于查询提供商的模型列表等信息
2025-12-10 20:52:44 +08:00
"""
import asyncio
import os
2025-12-10 20:52:44 +08:00
from typing import Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException
2025-12-10 20:52:44 +08:00
from pydantic import BaseModel
from sqlalchemy.orm import Session, joinedload
2025-12-10 20:52:44 +08:00
from src.core.crypto import crypto_service
2025-12-10 20:52:44 +08:00
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderEndpoint, User
2025-12-10 20:52:44 +08:00
from src.utils.auth_utils import get_current_user
# Router for the provider-query endpoints, mounted under /api/admin/provider-query.
router = APIRouter(tags=["Provider Query"], prefix="/api/admin/provider-query")
# ============ Request/Response Models ============
class ModelsQueryRequest(BaseModel):
    """Request body for the provider model-list query."""

    # ID of the provider whose endpoints should be queried.
    provider_id: str
    # Optional ID of one specific API key; when set, only the endpoint that
    # owns this key is queried.
    api_key_id: Optional[str] = None
# ============ API Endpoints ============
async def _fetch_openai_models(
    client: httpx.AsyncClient,
    base_url: str,
    api_key: str,
    api_format: str,
    extra_headers: Optional[dict] = None,
) -> tuple[list, Optional[str]]:
    """Fetch the model list from an OpenAI-compatible endpoint.

    Sends ``GET {base_url}/v1/models`` (or ``{base_url}/models`` when
    ``base_url`` already ends with ``/v1``) with a Bearer token. Every
    returned model dict is tagged with ``api_format``.

    Args:
        client: Shared HTTP client used for the request.
        base_url: Endpoint base URL; a trailing ``/`` is tolerated.
        api_key: Decrypted API key, sent as ``Authorization: Bearer``.
        api_format: Format label copied onto every returned model.
        extra_headers: Optional endpoint-specific headers. Any
            ``Authorization`` entry (case-insensitive) is dropped so it cannot
            replace the key header.

    Returns:
        tuple[list, Optional[str]]: (models, error message). On failure the
        list is empty and the message describes the HTTP status or exception.
    """
    useragent = os.getenv("OPENAI_USER_AGENT") or "codex_cli_rs/0.73.0 (Mac OS 14.8.4; x86_64) Apple_Terminal/453"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "User-Agent": useragent,
    }
    if extra_headers:
        # Never let endpoint headers override the Authorization header.
        headers.update(
            {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
        )

    # Tolerate a trailing slash so we never build "...//v1/models".
    base_url = base_url.rstrip("/")
    if base_url.endswith("/v1"):
        models_url = f"{base_url}/models"
    else:
        models_url = f"{base_url}/v1/models"

    try:
        response = await client.get(models_url, headers=headers)
        logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
        if response.status_code != 200:
            # Log the (truncated) upstream error body for diagnosis.
            error_body = response.text[:500] if response.text else "(empty)"
            error_msg = f"HTTP {response.status_code}: {error_body}"
            logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
            return [], error_msg

        data = response.json()
        # Accept both the standard {"data": [...]} envelope and a bare list;
        # anything else (e.g. "data": null) yields no models instead of a crash.
        raw_models = data.get("data") if isinstance(data, dict) else data
        if not isinstance(raw_models, list):
            raw_models = []
        # Skip non-object entries rather than failing the whole response.
        models = [m for m in raw_models if isinstance(m, dict)]
        for m in models:
            m["api_format"] = api_format
        return models, None
    except Exception as e:
        error_msg = f"Request error: {str(e)}"
        logger.warning(f"Failed to fetch models from {models_url}: {e}")
        return [], error_msg
async def _fetch_claude_models(
    client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
    """Fetch the model list from an Anthropic/Claude-style endpoint.

    Sends ``GET {base_url}/v1/models`` (or ``{base_url}/models`` when
    ``base_url`` already ends with ``/v1``). The key is sent both as
    ``x-api-key`` and as a Bearer token.

    The Anthropic list endpoint is cursor-paginated: a response carries
    ``has_more`` and ``last_id``. Further pages are therefore requested with
    ``after_id`` until ``has_more`` is false, capped at ``max_pages``
    requests. Every returned model dict is tagged with ``api_format``.

    Returns:
        tuple[list, Optional[str]]: (models, error message). If a later page
        fails, the models gathered so far are returned together with the error.
    """
    max_pages = 20  # safety cap in case an upstream never clears has_more
    useragent = os.getenv("CLAUDE_USER_AGENT") or "claude-cli/2.0.62 (external, cli)"
    headers = {
        "x-api-key": api_key,
        "Authorization": f"Bearer {api_key}",
        "anthropic-version": "2023-06-01",
        "User-Agent": useragent,
    }

    # Tolerate a trailing slash so we never build "...//v1/models".
    base_url = base_url.rstrip("/")
    if base_url.endswith("/v1"):
        models_url = f"{base_url}/models"
    else:
        models_url = f"{base_url}/v1/models"

    models: list = []
    after_id: Optional[str] = None
    try:
        for _ in range(max_pages):
            params = {"after_id": after_id} if after_id else None
            response = await client.get(models_url, headers=headers, params=params)
            logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
            if response.status_code != 200:
                error_body = response.text[:500] if response.text else "(empty)"
                error_msg = f"HTTP {response.status_code}: {error_body}"
                logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
                return models, error_msg

            data = response.json()
            # Accept both the {"data": [...]} envelope and a bare list.
            page = data.get("data") if isinstance(data, dict) else data
            if not isinstance(page, list):
                page = []
            for m in page:
                if isinstance(m, dict):
                    m["api_format"] = api_format
                    models.append(m)

            # Follow the cursor only when the upstream says more pages exist.
            if not isinstance(data, dict) or not data.get("has_more"):
                break
            next_id = data.get("last_id")
            if not next_id or next_id == after_id:
                break
            after_id = next_id
        return models, None
    except Exception as e:
        error_msg = f"Request error: {str(e)}"
        logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
        return models, error_msg
async def _fetch_gemini_models(
    client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
    """Fetch the model list from a Gemini-style endpoint.

    Sends ``GET {base_url}/v1beta/models`` (or ``{base_url}/models`` when
    ``base_url`` already ends with ``/v1beta``). The API key is sent as the
    ``key`` query parameter. The request follows ``nextPageToken`` until the
    listing is complete, capped at ``max_pages`` requests.

    The key is passed through ``params=`` rather than formatted into the URL
    string, so the URL written to the logs never contains it. Error text is
    also scrubbed of the key before being logged or returned.

    Returns:
        tuple[list, Optional[str]]: (models, error message). Each model is
        normalized to a dict with ``id`` (leading ``models/`` removed),
        ``owned_by``, ``display_name`` and ``api_format``. If a later page
        fails, the models gathered so far are returned together with the error.
    """
    max_pages = 20  # safety cap against upstreams that keep returning a page token

    def _redact(text: str) -> str:
        # Keep the secret out of logs and client-visible error messages.
        return text.replace(api_key, "***") if api_key else text

    # Accept base URLs that already include the /v1beta segment.
    base_url_clean = base_url.rstrip("/")
    if base_url_clean.endswith("/v1beta"):
        models_url = f"{base_url_clean}/models"
    else:
        models_url = f"{base_url_clean}/v1beta/models"
    useragent = os.getenv("GEMINI_USER_AGENT") or "gemini-cli/0.1.0 (external, cli)"
    headers = {"User-Agent": useragent}

    models: list = []
    page_token: Optional[str] = None
    try:
        for _ in range(max_pages):
            params = {"key": api_key}
            if page_token:
                params["pageToken"] = page_token
            response = await client.get(models_url, headers=headers, params=params)
            logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
            if response.status_code != 200:
                error_body = response.text[:500] if response.text else "(empty)"
                error_msg = _redact(f"HTTP {response.status_code}: {error_body}")
                logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
                return models, error_msg

            data = response.json()
            if not isinstance(data, dict):
                break
            # Convert each entry to the unified model shape.
            for m in data.get("models") or []:
                if not isinstance(m, dict):
                    continue
                models.append(
                    {
                        # Strip only the leading "models/" resource prefix.
                        "id": (m.get("name") or "").removeprefix("models/"),
                        "owned_by": "google",
                        "display_name": m.get("displayName", ""),
                        "api_format": api_format,
                    }
                )

            next_token = data.get("nextPageToken")
            if not next_token or next_token == page_token:
                break
            page_token = next_token
        return models, None
    except Exception as e:
        error_msg = _redact(f"Request error: {str(e)}")
        logger.warning(f"Failed to fetch Gemini models from {models_url}: {_redact(str(e))}")
        return models, error_msg
def _build_endpoint_config(endpoint: ProviderEndpoint, api_key_value: str) -> dict:
    """Bundle the settings needed to query one endpoint's model list."""
    return {
        "api_key": api_key_value,
        "base_url": endpoint.base_url,
        "api_format": endpoint.api_format,
        "extra_headers": endpoint.headers,
    }


def _collect_endpoint_configs(provider: Provider, api_key_id: Optional[str]) -> list[dict]:
    """Choose which (endpoint, decrypted key) pairs to query for ``provider``.

    With ``api_key_id``, only the endpoint that owns that key is used. Without
    it, each active endpoint contributes its first active key that decrypts
    successfully.

    Raises:
        HTTPException: 500 if the explicitly requested key cannot be
            decrypted; 404 if ``api_key_id`` matches no key of this provider;
            400 if no active endpoint has a usable active key.
    """
    if api_key_id:
        # Explicit key requested.
        # NOTE(review): active flags are not checked on this path (unchanged
        # behavior) — confirm that is intended.
        for endpoint in provider.endpoints:
            for api_key in endpoint.api_keys:
                if api_key.id != api_key_id:
                    continue
                try:
                    api_key_value = crypto_service.decrypt(api_key.api_key)
                except Exception as e:
                    logger.error(f"Failed to decrypt API key: {e}")
                    raise HTTPException(status_code=500, detail="Failed to decrypt API key") from e
                return [_build_endpoint_config(endpoint, api_key_value)]
        raise HTTPException(status_code=404, detail="API Key not found")

    configs: list[dict] = []
    for endpoint in provider.endpoints:
        if not endpoint.is_active or not endpoint.api_keys:
            continue
        for api_key in endpoint.api_keys:
            if not api_key.is_active:
                continue
            try:
                api_key_value = crypto_service.decrypt(api_key.api_key)
            except Exception as e:
                logger.error(f"Failed to decrypt API key: {e}")
                continue  # try the endpoint's next key
            configs.append(_build_endpoint_config(endpoint, api_key_value))
            break  # only the first usable key per endpoint
    if not configs:
        raise HTTPException(status_code=400, detail="No active API Key found for this provider")
    return configs


async def _fetch_endpoint_models(
    client: httpx.AsyncClient, config: dict
) -> tuple[list, Optional[str]]:
    """Dispatch one endpoint config to the fetcher matching its API format.

    Endpoints without a base URL are skipped, returning ``([], None)``.
    Unknown formats fall back to the OpenAI-compatible fetcher.
    """
    base_url = config["base_url"]
    if not base_url:
        return [], None
    base_url = base_url.rstrip("/")
    api_format = config["api_format"]
    api_key_value = config["api_key"]
    extra_headers = config["extra_headers"]
    try:
        if api_format in ["CLAUDE", "CLAUDE_CLI"]:
            return await _fetch_claude_models(client, base_url, api_key_value, api_format)
        elif api_format in ["GEMINI", "GEMINI_CLI"]:
            return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
        else:
            return await _fetch_openai_models(
                client, base_url, api_key_value, api_format, extra_headers
            )
    except Exception as e:
        # Defensive: keep one failing endpoint from aborting the gather().
        logger.error(f"Error fetching models from {api_format} endpoint: {e}")
        return [], f"{api_format}: {str(e)}"


def _dedupe_models(models: list) -> list:
    """Drop models without an ``id`` and keep only the first model per ``id``."""
    seen_ids: set[str] = set()
    unique_models: list = []
    for model in models:
        model_id = model.get("id")
        if model_id and model_id not in seen_ids:
            seen_ids.add(model_id)
            unique_models.append(model)
    return unique_models


@router.post("/models")
async def query_available_models(
    request: ModelsQueryRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query the models available from a provider.

    Queries every selected endpoint concurrently. The request style is chosen
    from the endpoint's API format:
    - OPENAI/OPENAI_CLI (and any other format): /v1/models (Bearer token)
    - CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
    - GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)

    If ``api_key_id`` is given, only that key's endpoint is queried. Otherwise
    every active endpoint is queried with its first active, decryptable key.

    Args:
        request: Query request (provider ID and optional API key ID).

    Returns:
        A dict with ``success``, ``data`` and ``provider``. ``data`` holds
        ``models``, merged across endpoints and de-duplicated by ``id`` with
        the first occurrence kept, plus an ``error`` summary.

    Raises:
        HTTPException: 404 if the provider or requested key is not found,
            500 if the requested key cannot be decrypted, 400 if no usable
            active key exists.
    """
    # NOTE(review): current_user is only used to require authentication; no
    # admin-role check is visible here despite the /api/admin prefix. Confirm
    # that authorization is enforced elsewhere.
    provider = (
        db.query(Provider)
        .options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
        .filter(Provider.id == request.provider_id)
        .first()
    )
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    endpoint_configs = _collect_endpoint_configs(provider, request.api_key_id)

    # Query all endpoints concurrently over one shared client.
    async with httpx.AsyncClient(timeout=30.0) as client:
        results = await asyncio.gather(
            *(_fetch_endpoint_models(client, c) for c in endpoint_configs)
        )

    all_models: list = []
    errors: list[str] = []
    for models, error in results:
        all_models.extend(models)
        if error:
            errors.append(error)

    unique_models = _dedupe_models(all_models)
    error = "; ".join(errors) if errors else None
    if not unique_models and not error:
        error = "No models returned from any endpoint"

    return {
        "success": len(unique_models) > 0,
        "data": {"models": unique_models, "error": error},
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }