mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
feat: add internal model list query interface with configurable User-Agent headers
This commit is contained in:
@@ -4,7 +4,6 @@ Provider Query API 端点
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -12,6 +11,8 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from src.api.handlers.base.chat_adapter_base import get_adapter_class
|
||||||
|
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
|
||||||
from src.core.crypto import crypto_service
|
from src.core.crypto import crypto_service
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database.database import get_db
|
from src.database.database import get_db
|
||||||
@@ -34,151 +35,19 @@ class ModelsQueryRequest(BaseModel):
|
|||||||
# ============ API Endpoints ============
|
# ============ API Endpoints ============
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_openai_models(
|
def _get_adapter_for_format(api_format: str):
|
||||||
client: httpx.AsyncClient,
|
"""根据 API 格式获取对应的 Adapter 类"""
|
||||||
base_url: str,
|
# 先检查 Chat Adapter 注册表
|
||||||
api_key: str,
|
adapter_class = get_adapter_class(api_format)
|
||||||
api_format: str,
|
if adapter_class:
|
||||||
extra_headers: Optional[dict] = None,
|
return adapter_class
|
||||||
) -> tuple[list, Optional[str]]:
|
|
||||||
"""获取 OpenAI 格式的模型列表
|
|
||||||
|
|
||||||
Returns:
|
# 再检查 CLI Adapter 注册表
|
||||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
cli_adapter_class = get_cli_adapter_class(api_format)
|
||||||
"""
|
if cli_adapter_class:
|
||||||
useragent = os.getenv("OPENAI_USER_AGENT") or "codex_cli_rs/0.73.0 (Mac OS 14.8.4; x86_64) Apple_Terminal/453"
|
return cli_adapter_class
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
"User-Agent": useragent,
|
|
||||||
}
|
|
||||||
if extra_headers:
|
|
||||||
# 防止 extra_headers 覆盖 Authorization
|
|
||||||
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
|
||||||
headers.update(safe_headers)
|
|
||||||
|
|
||||||
# 构建 /v1/models URL
|
return None
|
||||||
if base_url.endswith("/v1"):
|
|
||||||
models_url = f"{base_url}/models"
|
|
||||||
else:
|
|
||||||
models_url = f"{base_url}/v1/models"
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.get(models_url, headers=headers)
|
|
||||||
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
|
||||||
if response.status_code == 200:
|
|
||||||
data = response.json()
|
|
||||||
models = []
|
|
||||||
if "data" in data:
|
|
||||||
models = data["data"]
|
|
||||||
elif isinstance(data, list):
|
|
||||||
models = data
|
|
||||||
# 为每个模型添加 api_format 字段
|
|
||||||
for m in models:
|
|
||||||
m["api_format"] = api_format
|
|
||||||
return models, None
|
|
||||||
else:
|
|
||||||
# 记录详细的错误信息
|
|
||||||
error_body = response.text[:500] if response.text else "(empty)"
|
|
||||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
|
||||||
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
|
||||||
return [], error_msg
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Request error: {str(e)}"
|
|
||||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
|
||||||
return [], error_msg
|
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_claude_models(
|
|
||||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
|
||||||
) -> tuple[list, Optional[str]]:
|
|
||||||
"""获取 Claude 格式的模型列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
|
||||||
"""
|
|
||||||
useragent = os.getenv("CLAUDE_USER_AGENT") or "claude-cli/2.0.62 (external, cli)"
|
|
||||||
headers = {
|
|
||||||
"x-api-key": api_key,
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
"anthropic-version": "2023-06-01",
|
|
||||||
"User-Agent": useragent,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 构建 /v1/models URL
|
|
||||||
if base_url.endswith("/v1"):
|
|
||||||
models_url = f"{base_url}/models"
|
|
||||||
else:
|
|
||||||
models_url = f"{base_url}/v1/models"
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.get(models_url, headers=headers)
|
|
||||||
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
|
||||||
if response.status_code == 200:
|
|
||||||
data = response.json()
|
|
||||||
models = []
|
|
||||||
if "data" in data:
|
|
||||||
models = data["data"]
|
|
||||||
elif isinstance(data, list):
|
|
||||||
models = data
|
|
||||||
# 为每个模型添加 api_format 字段
|
|
||||||
for m in models:
|
|
||||||
m["api_format"] = api_format
|
|
||||||
return models, None
|
|
||||||
else:
|
|
||||||
error_body = response.text[:500] if response.text else "(empty)"
|
|
||||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
|
||||||
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
|
||||||
return [], error_msg
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Request error: {str(e)}"
|
|
||||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
|
||||||
return [], error_msg
|
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_gemini_models(
|
|
||||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
|
||||||
) -> tuple[list, Optional[str]]:
|
|
||||||
"""获取 Gemini 格式的模型列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
|
||||||
"""
|
|
||||||
# 兼容 base_url 已包含 /v1beta 的情况
|
|
||||||
base_url_clean = base_url.rstrip("/")
|
|
||||||
if base_url_clean.endswith("/v1beta"):
|
|
||||||
models_url = f"{base_url_clean}/models?key={api_key}"
|
|
||||||
else:
|
|
||||||
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
|
||||||
useragent = os.getenv("GEMINI_USER_AGENT") or "gemini-cli/0.1.0 (external, cli)"
|
|
||||||
headers = {
|
|
||||||
"User-Agent": useragent,
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
response = await client.get(models_url, headers=headers)
|
|
||||||
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
|
||||||
if response.status_code == 200:
|
|
||||||
data = response.json()
|
|
||||||
if "models" in data:
|
|
||||||
# 转换为统一格式
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": m.get("name", "").replace("models/", ""),
|
|
||||||
"owned_by": "google",
|
|
||||||
"display_name": m.get("displayName", ""),
|
|
||||||
"api_format": api_format,
|
|
||||||
}
|
|
||||||
for m in data["models"]
|
|
||||||
], None
|
|
||||||
return [], None
|
|
||||||
else:
|
|
||||||
error_body = response.text[:500] if response.text else "(empty)"
|
|
||||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
|
||||||
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
|
||||||
return [], error_msg
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Request error: {str(e)}"
|
|
||||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
|
||||||
return [], error_msg
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/models")
|
@router.post("/models")
|
||||||
@@ -190,10 +59,10 @@ async def query_available_models(
|
|||||||
"""
|
"""
|
||||||
查询提供商可用模型
|
查询提供商可用模型
|
||||||
|
|
||||||
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
|
遍历所有活跃端点,根据端点的 API 格式选择正确的 Adapter 进行请求:
|
||||||
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
|
- OPENAI/OPENAI_CLI: 使用 OpenAIChatAdapter.fetch_models
|
||||||
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
|
- CLAUDE/CLAUDE_CLI: 使用 ClaudeChatAdapter.fetch_models
|
||||||
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
|
- GEMINI/GEMINI_CLI: 使用 GeminiChatAdapter.fetch_models
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 查询请求
|
request: 查询请求
|
||||||
@@ -275,16 +144,15 @@ async def query_available_models(
|
|||||||
base_url = base_url.rstrip("/")
|
base_url = base_url.rstrip("/")
|
||||||
api_format = config["api_format"]
|
api_format = config["api_format"]
|
||||||
api_key_value = config["api_key"]
|
api_key_value = config["api_key"]
|
||||||
extra_headers = config["extra_headers"]
|
extra_headers = config.get("extra_headers")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
|
# 获取对应的 Adapter 类并调用 fetch_models
|
||||||
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
|
adapter_class = _get_adapter_for_format(api_format)
|
||||||
elif api_format in ["GEMINI", "GEMINI_CLI"]:
|
if not adapter_class:
|
||||||
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
|
return [], f"Unknown API format: {api_format}"
|
||||||
else:
|
return await adapter_class.fetch_models(
|
||||||
return await _fetch_openai_models(
|
client, base_url, api_key_value, extra_headers
|
||||||
client, base_url, api_key_value, api_format, extra_headers
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ Chat Adapter 通用基类
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
@@ -620,6 +621,39 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||||
return tiers[-1] if tiers else None
|
return tiers[-1] if tiers else None
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 模型列表查询 - 子类应覆盖此方法
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_models(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[list, Optional[str]]:
|
||||||
|
"""
|
||||||
|
查询上游 API 支持的模型列表
|
||||||
|
|
||||||
|
这是 Aether 内部发起的请求(非用户透传),用于:
|
||||||
|
- 管理后台查询提供商支持的模型
|
||||||
|
- 自动发现可用模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: httpx 异步客户端
|
||||||
|
base_url: API 基础 URL
|
||||||
|
api_key: API 密钥(已解密)
|
||||||
|
extra_headers: 端点配置的额外请求头
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(models, error): 模型列表和错误信息
|
||||||
|
- models: 模型信息列表,每个模型至少包含 id 字段
|
||||||
|
- error: 错误信息,成功时为 None
|
||||||
|
"""
|
||||||
|
# 默认实现返回空列表,子类应覆盖
|
||||||
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ CLI Adapter 通用基类
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
@@ -580,6 +581,39 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
|
|
||||||
return tiers[-1] if tiers else None
|
return tiers[-1] if tiers else None
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 模型列表查询 - 子类应覆盖此方法
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_models(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[list, Optional[str]]:
|
||||||
|
"""
|
||||||
|
查询上游 API 支持的模型列表
|
||||||
|
|
||||||
|
这是 Aether 内部发起的请求(非用户透传),用于:
|
||||||
|
- 管理后台查询提供商支持的模型
|
||||||
|
- 自动发现可用模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: httpx 异步客户端
|
||||||
|
base_url: API 基础 URL
|
||||||
|
api_key: API 密钥(已解密)
|
||||||
|
extra_headers: 端点配置的额外请求头
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(models, error): 模型列表和错误信息
|
||||||
|
- models: 模型信息列表,每个模型至少包含 id 字段
|
||||||
|
- error: 错误信息,成功时为 None
|
||||||
|
"""
|
||||||
|
# 默认实现返回空列表,子类应覆盖
|
||||||
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ Claude Chat Adapter - 基于 ChatAdapterBase 的 Claude Chat API 适配器
|
|||||||
处理 /v1/messages 端点的 Claude Chat 格式请求。
|
处理 /v1/messages 端点的 Claude Chat 格式请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
@@ -155,6 +156,59 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
|||||||
"thinking_enabled": bool(request_obj.thinking),
|
"thinking_enabled": bool(request_obj.thinking),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_models(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[list, Optional[str]]:
|
||||||
|
"""查询 Claude API 支持的模型列表"""
|
||||||
|
headers = {
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
}
|
||||||
|
if extra_headers:
|
||||||
|
# 防止 extra_headers 覆盖认证头
|
||||||
|
safe_headers = {
|
||||||
|
k: v for k, v in extra_headers.items()
|
||||||
|
if k.lower() not in ("x-api-key", "authorization", "anthropic-version")
|
||||||
|
}
|
||||||
|
headers.update(safe_headers)
|
||||||
|
|
||||||
|
# 构建 /v1/models URL
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
models_url = f"{base_url}/models"
|
||||||
|
else:
|
||||||
|
models_url = f"{base_url}/v1/models"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
models = []
|
||||||
|
if "data" in data:
|
||||||
|
models = data["data"]
|
||||||
|
elif isinstance(data, list):
|
||||||
|
models = data
|
||||||
|
# 为每个模型添加 api_format 字段
|
||||||
|
for m in models:
|
||||||
|
m["api_format"] = cls.FORMAT_ID
|
||||||
|
return models, None
|
||||||
|
else:
|
||||||
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
|
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
||||||
|
return [], error_msg
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Request error: {str(e)}"
|
||||||
|
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||||
|
return [], error_msg
|
||||||
|
|
||||||
|
|
||||||
def build_claude_adapter(x_app_header: Optional[str]):
|
def build_claude_adapter(x_app_header: Optional[str]):
|
||||||
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
||||||
|
|||||||
@@ -4,13 +4,15 @@ Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
|||||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||||
from src.api.handlers.claude.adapter import ClaudeCapabilityDetector
|
from src.api.handlers.claude.adapter import ClaudeCapabilityDetector, ClaudeChatAdapter
|
||||||
|
from src.config.settings import config
|
||||||
|
|
||||||
|
|
||||||
@register_cli_adapter
|
@register_cli_adapter
|
||||||
@@ -99,5 +101,30 @@ class ClaudeCliAdapter(CliAdapterBase):
|
|||||||
"system_present": bool(payload.get("system")),
|
"system_present": bool(payload.get("system")),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 模型列表查询
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_models(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[list, Optional[str]]:
|
||||||
|
"""查询 Claude API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
|
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
|
||||||
|
cli_headers = {"User-Agent": config.internal_user_agent_claude}
|
||||||
|
if extra_headers:
|
||||||
|
cli_headers.update(extra_headers)
|
||||||
|
models, error = await ClaudeChatAdapter.fetch_models(
|
||||||
|
client, base_url, api_key, cli_headers
|
||||||
|
)
|
||||||
|
# 更新 api_format 为 CLI 格式
|
||||||
|
for m in models:
|
||||||
|
m["api_format"] = cls.FORMAT_ID
|
||||||
|
return models, error
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ClaudeCliAdapter"]
|
__all__ = ["ClaudeCliAdapter"]
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ Gemini Chat Adapter
|
|||||||
处理 Gemini API 格式的请求适配
|
处理 Gemini API 格式的请求适配
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
@@ -151,6 +152,53 @@ class GeminiChatAdapter(ChatAdapterBase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_models(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[list, Optional[str]]:
|
||||||
|
"""查询 Gemini API 支持的模型列表"""
|
||||||
|
# 兼容 base_url 已包含 /v1beta 的情况
|
||||||
|
base_url_clean = base_url.rstrip("/")
|
||||||
|
if base_url_clean.endswith("/v1beta"):
|
||||||
|
models_url = f"{base_url_clean}/models?key={api_key}"
|
||||||
|
else:
|
||||||
|
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
||||||
|
|
||||||
|
headers: Dict[str, str] = {}
|
||||||
|
if extra_headers:
|
||||||
|
headers.update(extra_headers)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if "models" in data:
|
||||||
|
# 转换为统一格式
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": m.get("name", "").replace("models/", ""),
|
||||||
|
"owned_by": "google",
|
||||||
|
"display_name": m.get("displayName", ""),
|
||||||
|
"api_format": cls.FORMAT_ID,
|
||||||
|
}
|
||||||
|
for m in data["models"]
|
||||||
|
], None
|
||||||
|
return [], None
|
||||||
|
else:
|
||||||
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
|
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
||||||
|
return [], error_msg
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Request error: {str(e)}"
|
||||||
|
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||||
|
return [], error_msg
|
||||||
|
|
||||||
|
|
||||||
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
|
|||||||
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||||
|
from src.api.handlers.gemini.adapter import GeminiChatAdapter
|
||||||
|
from src.config.settings import config
|
||||||
|
|
||||||
|
|
||||||
@register_cli_adapter
|
@register_cli_adapter
|
||||||
@@ -95,6 +98,31 @@ class GeminiCliAdapter(CliAdapterBase):
|
|||||||
"safety_settings_count": len(payload.get("safety_settings") or []),
|
"safety_settings_count": len(payload.get("safety_settings") or []),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 模型列表查询
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_models(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[list, Optional[str]]:
|
||||||
|
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
|
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
|
||||||
|
cli_headers = {"User-Agent": config.internal_user_agent_gemini}
|
||||||
|
if extra_headers:
|
||||||
|
cli_headers.update(extra_headers)
|
||||||
|
models, error = await GeminiChatAdapter.fetch_models(
|
||||||
|
client, base_url, api_key, cli_headers
|
||||||
|
)
|
||||||
|
# 更新 api_format 为 CLI 格式
|
||||||
|
for m in models:
|
||||||
|
m["api_format"] = cls.FORMAT_ID
|
||||||
|
return models, error
|
||||||
|
|
||||||
|
|
||||||
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
|
|||||||
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
@@ -105,5 +106,53 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_models(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[list, Optional[str]]:
|
||||||
|
"""查询 OpenAI 兼容 API 支持的模型列表"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
}
|
||||||
|
if extra_headers:
|
||||||
|
# 防止 extra_headers 覆盖 Authorization
|
||||||
|
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
||||||
|
headers.update(safe_headers)
|
||||||
|
|
||||||
|
# 构建 /v1/models URL
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
models_url = f"{base_url}/models"
|
||||||
|
else:
|
||||||
|
models_url = f"{base_url}/v1/models"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
models = []
|
||||||
|
if "data" in data:
|
||||||
|
models = data["data"]
|
||||||
|
elif isinstance(data, list):
|
||||||
|
models = data
|
||||||
|
# 为每个模型添加 api_format 字段
|
||||||
|
for m in models:
|
||||||
|
m["api_format"] = cls.FORMAT_ID
|
||||||
|
return models, None
|
||||||
|
else:
|
||||||
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
|
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
||||||
|
return [], error_msg
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Request error: {str(e)}"
|
||||||
|
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||||
|
return [], error_msg
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAIChatAdapter"]
|
__all__ = ["OpenAIChatAdapter"]
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
|||||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Type
|
from typing import Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||||
|
from src.api.handlers.openai.adapter import OpenAIChatAdapter
|
||||||
|
from src.config.settings import config
|
||||||
|
|
||||||
|
|
||||||
@register_cli_adapter
|
@register_cli_adapter
|
||||||
@@ -40,5 +43,30 @@ class OpenAICliAdapter(CliAdapterBase):
|
|||||||
return authorization.replace("Bearer ", "")
|
return authorization.replace("Bearer ", "")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 模型列表查询
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_models(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
) -> Tuple[list, Optional[str]]:
|
||||||
|
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
|
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
|
||||||
|
cli_headers = {"User-Agent": config.internal_user_agent_openai}
|
||||||
|
if extra_headers:
|
||||||
|
cli_headers.update(extra_headers)
|
||||||
|
models, error = await OpenAIChatAdapter.fetch_models(
|
||||||
|
client, base_url, api_key, cli_headers
|
||||||
|
)
|
||||||
|
# 更新 api_format 为 CLI 格式
|
||||||
|
for m in models:
|
||||||
|
m["api_format"] = cls.FORMAT_ID
|
||||||
|
return models, error
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAICliAdapter"]
|
__all__ = ["OpenAICliAdapter"]
|
||||||
|
|||||||
@@ -144,6 +144,18 @@ class Config:
|
|||||||
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
||||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||||
|
|
||||||
|
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
||||||
|
# 可通过环境变量覆盖默认值
|
||||||
|
self.internal_user_agent_claude = os.getenv(
|
||||||
|
"CLAUDE_USER_AGENT", "claude-cli/1.0"
|
||||||
|
)
|
||||||
|
self.internal_user_agent_openai = os.getenv(
|
||||||
|
"OPENAI_USER_AGENT", "openai-cli/1.0"
|
||||||
|
)
|
||||||
|
self.internal_user_agent_gemini = os.getenv(
|
||||||
|
"GEMINI_USER_AGENT", "gemini-cli/1.0"
|
||||||
|
)
|
||||||
|
|
||||||
# 验证连接池配置
|
# 验证连接池配置
|
||||||
self._validate_pool_config()
|
self._validate_pool_config()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user