2 Commits

Author SHA1 Message Date
fawney19
b89a4af0cf refactor: 统一 HTTP 客户端超时配置
将 HTTPClientPool 中硬编码的超时参数改为使用可配置的环境变量,提高系统的灵活性和可维护性。

- 添加 HTTP_READ_TIMEOUT 环境变量配置(默认 300 秒)
- 统一所有 HTTP 客户端创建逻辑使用配置化超时
- 改进变量命名清晰性(config -> default_config 或 client_config)
2025-12-30 15:06:55 +08:00
fawney19
a56854af43 feat: 为 GlobalModel 添加关联提供商查询 API
添加新的 API 端点 GET /api/admin/models/global/{global_model_id}/providers,用于获取 GlobalModel 的所有关联提供商(包括非活跃的)。

- 后端:实现 AdminGetGlobalModelProvidersAdapter 适配器
- 前端:使用新 API 替换原有的 ModelCatalog 获取方式
- 数据库:改进初始化时的错误提示和连接异常处理
2025-12-30 14:47:35 +08:00
8 changed files with 191 additions and 55 deletions

View File

@@ -4,7 +4,8 @@ import type {
GlobalModelUpdate, GlobalModelUpdate,
GlobalModelResponse, GlobalModelResponse,
GlobalModelWithStats, GlobalModelWithStats,
GlobalModelListResponse GlobalModelListResponse,
ModelCatalogProviderDetail,
} from './types' } from './types'
/** /**
@@ -83,3 +84,16 @@ export async function batchAssignToProviders(
) )
return response.data return response.data
} }
/**
* 获取 GlobalModel 的所有关联提供商(包括非活跃的)
*/
export async function getGlobalModelProviders(globalModelId: string): Promise<{
providers: ModelCatalogProviderDetail[]
total: number
}> {
const response = await client.get(
`/api/admin/models/global/${globalModelId}/providers`
)
return response.data
}

View File

@@ -20,4 +20,5 @@ export {
updateGlobalModel, updateGlobalModel,
deleteGlobalModel, deleteGlobalModel,
batchAssignToProviders, batchAssignToProviders,
getGlobalModelProviders,
} from './endpoints/global-models' } from './endpoints/global-models'

View File

@@ -737,6 +737,7 @@ import {
updateGlobalModel, updateGlobalModel,
deleteGlobalModel, deleteGlobalModel,
batchAssignToProviders, batchAssignToProviders,
getGlobalModelProviders,
type GlobalModelResponse, type GlobalModelResponse,
} from '@/api/global-models' } from '@/api/global-models'
import { log } from '@/utils/logger' import { log } from '@/utils/logger'
@@ -1080,42 +1081,32 @@ async function selectModel(model: GlobalModelResponse) {
async function loadModelProviders(_globalModelId: string) { async function loadModelProviders(_globalModelId: string) {
loadingModelProviders.value = true loadingModelProviders.value = true
try { try {
// 使用 ModelCatalog API 获取详细的关联提供商信息 // 使用新的 API 获取所有关联提供商(包括非活跃的)
const { getModelCatalog } = await import('@/api/endpoints') const response = await getGlobalModelProviders(_globalModelId)
const catalogResponse = await getModelCatalog()
// 查找当前 GlobalModel 对应的 catalog item // 转换为展示格式
const catalogItem = catalogResponse.models.find( selectedModelProviders.value = response.providers.map(p => ({
m => m.global_model_name === selectedModel.value?.name id: p.provider_id,
) model_id: p.model_id,
display_name: p.provider_display_name || p.provider_name,
if (catalogItem) { identifier: p.provider_name,
// 转换为展示格式,包含完整的模型实现信息 provider_type: 'API',
selectedModelProviders.value = catalogItem.providers.map(p => ({ target_model: p.target_model,
id: p.provider_id, is_active: p.is_active,
model_id: p.model_id, // 价格信息
display_name: p.provider_display_name || p.provider_name, input_price_per_1m: p.input_price_per_1m,
identifier: p.provider_name, output_price_per_1m: p.output_price_per_1m,
provider_type: 'API', cache_creation_price_per_1m: p.cache_creation_price_per_1m,
target_model: p.target_model, cache_read_price_per_1m: p.cache_read_price_per_1m,
is_active: p.is_active, cache_1h_creation_price_per_1m: p.cache_1h_creation_price_per_1m,
// 价格信息 price_per_request: p.price_per_request,
input_price_per_1m: p.input_price_per_1m, effective_tiered_pricing: p.effective_tiered_pricing,
output_price_per_1m: p.output_price_per_1m, tier_count: p.tier_count,
cache_creation_price_per_1m: p.cache_creation_price_per_1m, // 能力信息
cache_read_price_per_1m: p.cache_read_price_per_1m, supports_vision: p.supports_vision,
cache_1h_creation_price_per_1m: p.cache_1h_creation_price_per_1m, supports_function_calling: p.supports_function_calling,
price_per_request: p.price_per_request, supports_streaming: p.supports_streaming
effective_tiered_pricing: p.effective_tiered_pricing, }))
tier_count: p.tier_count,
// 能力信息
supports_vision: p.supports_vision,
supports_function_calling: p.supports_function_calling,
supports_streaming: p.supports_streaming
}))
} else {
selectedModelProviders.value = []
}
} catch (err: any) { } catch (err: any) {
log.error('加载关联提供商失败:', err) log.error('加载关联提供商失败:', err)
showError(parseApiError(err, '加载关联提供商失败'), '错误') showError(parseApiError(err, '加载关联提供商失败'), '错误')

View File

@@ -5,7 +5,7 @@ GlobalModel Admin API
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Optional
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -19,9 +19,11 @@ from src.models.pydantic_models import (
BatchAssignToProvidersResponse, BatchAssignToProvidersResponse,
GlobalModelCreate, GlobalModelCreate,
GlobalModelListResponse, GlobalModelListResponse,
GlobalModelProvidersResponse,
GlobalModelResponse, GlobalModelResponse,
GlobalModelUpdate, GlobalModelUpdate,
GlobalModelWithStats, GlobalModelWithStats,
ModelCatalogProviderDetail,
) )
from src.services.model.global_model import GlobalModelService from src.services.model.global_model import GlobalModelService
@@ -108,6 +110,17 @@ async def batch_assign_to_providers(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{global_model_id}/providers", response_model=GlobalModelProvidersResponse)
async def get_global_model_providers(
request: Request,
global_model_id: str,
db: Session = Depends(get_db),
) -> GlobalModelProvidersResponse:
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
adapter = AdminGetGlobalModelProvidersAdapter(global_model_id=global_model_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ========== Adapters ========== # ========== Adapters ==========
@@ -275,3 +288,61 @@ class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}") logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
return BatchAssignToProvidersResponse(**result) return BatchAssignToProvidersResponse(**result)
@dataclass
class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
global_model_id: str
async def handle(self, context): # type: ignore[override]
from sqlalchemy.orm import joinedload
from src.models.database import Model
global_model = GlobalModelService.get_global_model(context.db, self.global_model_id)
# 获取所有关联的 Model包括非活跃的
models = (
context.db.query(Model)
.options(joinedload(Model.provider), joinedload(Model.global_model))
.filter(Model.global_model_id == global_model.id)
.all()
)
provider_entries = []
for model in models:
provider = model.provider
if not provider:
continue
effective_tiered = model.get_effective_tiered_pricing()
tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1
provider_entries.append(
ModelCatalogProviderDetail(
provider_id=provider.id,
provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id,
target_model=model.provider_model_name,
input_price_per_1m=model.get_effective_input_price(),
output_price_per_1m=model.get_effective_output_price(),
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
cache_read_price_per_1m=model.get_effective_cache_read_price(),
cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
price_per_request=model.get_effective_price_per_request(),
effective_tiered_pricing=effective_tiered,
tier_count=tier_count,
supports_vision=model.get_effective_supports_vision(),
supports_function_calling=model.get_effective_supports_function_calling(),
supports_streaming=model.get_effective_supports_streaming(),
is_active=bool(model.is_active),
)
)
return GlobalModelProvidersResponse(
providers=provider_entries,
total=len(provider_entries),
)

View File

@@ -9,6 +9,7 @@ from urllib.parse import quote, urlparse
import httpx import httpx
from src.config import config
from src.core.logger import logger from src.core.logger import logger
@@ -83,10 +84,10 @@ class HTTPClientPool:
http2=False, # 暂时禁用HTTP/2以提高兼容性 http2=False, # 暂时禁用HTTP/2以提高兼容性
verify=True, # 启用SSL验证 verify=True, # 启用SSL验证
timeout=httpx.Timeout( timeout=httpx.Timeout(
connect=10.0, # 连接超时 connect=config.http_connect_timeout,
read=300.0, # 读取超时(5分钟,适合流式响应) read=config.http_read_timeout,
write=60.0, # 写入超时(60秒,支持大请求体) write=config.http_write_timeout,
pool=5.0, # 连接池超时 pool=config.http_pool_timeout,
), ),
limits=httpx.Limits( limits=httpx.Limits(
max_connections=100, # 最大连接数 max_connections=100, # 最大连接数
@@ -111,15 +112,20 @@ class HTTPClientPool:
""" """
if name not in cls._clients: if name not in cls._clients:
# 合并默认配置和自定义配置 # 合并默认配置和自定义配置
config = { default_config = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"timeout": httpx.Timeout(10.0, read=300.0), "timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
"follow_redirects": True, "follow_redirects": True,
} }
config.update(kwargs) default_config.update(kwargs)
cls._clients[name] = httpx.AsyncClient(**config) cls._clients[name] = httpx.AsyncClient(**default_config)
logger.debug(f"创建命名HTTP客户端: {name}") logger.debug(f"创建命名HTTP客户端: {name}")
return cls._clients[name] return cls._clients[name]
@@ -151,14 +157,19 @@ class HTTPClientPool:
async with HTTPClientPool.get_temp_client() as client: async with HTTPClientPool.get_temp_client() as client:
response = await client.get('https://example.com') response = await client.get('https://example.com')
""" """
config = { default_config = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"timeout": httpx.Timeout(10.0), "timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
} }
config.update(kwargs) default_config.update(kwargs)
client = httpx.AsyncClient(**config) client = httpx.AsyncClient(**default_config)
try: try:
yield client yield client
finally: finally:
@@ -182,25 +193,30 @@ class HTTPClientPool:
Returns: Returns:
配置好的 httpx.AsyncClient 实例 配置好的 httpx.AsyncClient 实例
""" """
config: Dict[str, Any] = { client_config: Dict[str, Any] = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"follow_redirects": True, "follow_redirects": True,
} }
if timeout: if timeout:
config["timeout"] = timeout client_config["timeout"] = timeout
else: else:
config["timeout"] = httpx.Timeout(10.0, read=300.0) client_config["timeout"] = httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
)
# 添加代理配置 # 添加代理配置
proxy_url = build_proxy_url(proxy_config) if proxy_config else None proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url: if proxy_url:
config["proxy"] = proxy_url client_config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}") logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
config.update(kwargs) client_config.update(kwargs)
return httpx.AsyncClient(**config) return httpx.AsyncClient(**client_config)
# 便捷访问函数 # 便捷访问函数

View File

@@ -148,6 +148,7 @@ class Config:
# HTTP 请求超时配置(秒) # HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0")) self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
self.http_read_timeout = float(os.getenv("HTTP_READ_TIMEOUT", "300.0"))
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0")) self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0")) self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))

View File

@@ -360,6 +360,9 @@ def init_db():
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh 注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
""" """
import sys
from sqlalchemy.exc import OperationalError
logger.info("初始化数据库...") logger.info("初始化数据库...")
# 确保引擎已创建 # 确保引擎已创建
@@ -382,6 +385,38 @@ def init_db():
db.commit() db.commit()
logger.info("数据库初始化完成") logger.info("数据库初始化完成")
except OperationalError as e:
db.rollback()
# 提取数据库连接信息用于提示
db_url = config.database_url
# 隐藏密码,只显示 host:port/database
if "@" in db_url:
db_info = db_url.split("@")[-1]
else:
db_info = db_url
import os
# 直接打印到 stderr确保消息显示
print("", file=sys.stderr)
print("=" * 60, file=sys.stderr)
print("数据库连接失败", file=sys.stderr)
print("=" * 60, file=sys.stderr)
print("", file=sys.stderr)
print(f"无法连接到数据库: {db_info}", file=sys.stderr)
print("", file=sys.stderr)
print("请检查以下事项:", file=sys.stderr)
print(" 1. PostgreSQL 服务是否正在运行", file=sys.stderr)
print(" 2. 数据库连接配置是否正确 (DATABASE_URL)", file=sys.stderr)
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
print("", file=sys.stderr)
print("如果使用 Docker请先运行:", file=sys.stderr)
print(" docker-compose up -d postgres redis", file=sys.stderr)
print("", file=sys.stderr)
print("=" * 60, file=sys.stderr)
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈
os._exit(1)
except Exception as e: except Exception as e:
logger.error(f"数据库初始化失败: {e}") logger.error(f"数据库初始化失败: {e}")
db.rollback() db.rollback()

View File

@@ -274,6 +274,13 @@ class GlobalModelListResponse(BaseModel):
total: int total: int
class GlobalModelProvidersResponse(BaseModel):
"""GlobalModel 关联提供商列表响应"""
providers: List[ModelCatalogProviderDetail]
total: int
class BatchAssignToProvidersRequest(BaseModel): class BatchAssignToProvidersRequest(BaseModel):
"""批量为 Provider 添加 GlobalModel 实现""" """批量为 Provider 添加 GlobalModel 实现"""