mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
Compare commits
2 Commits
4a35d78c8d
...
v0.1.29
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b89a4af0cf | ||
|
|
a56854af43 |
@@ -4,7 +4,8 @@ import type {
|
||||
GlobalModelUpdate,
|
||||
GlobalModelResponse,
|
||||
GlobalModelWithStats,
|
||||
GlobalModelListResponse
|
||||
GlobalModelListResponse,
|
||||
ModelCatalogProviderDetail,
|
||||
} from './types'
|
||||
|
||||
/**
|
||||
@@ -83,3 +84,16 @@ export async function batchAssignToProviders(
|
||||
)
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 GlobalModel 的所有关联提供商(包括非活跃的)
|
||||
*/
|
||||
export async function getGlobalModelProviders(globalModelId: string): Promise<{
|
||||
providers: ModelCatalogProviderDetail[]
|
||||
total: number
|
||||
}> {
|
||||
const response = await client.get(
|
||||
`/api/admin/models/global/${globalModelId}/providers`
|
||||
)
|
||||
return response.data
|
||||
}
|
||||
|
||||
@@ -20,4 +20,5 @@ export {
|
||||
updateGlobalModel,
|
||||
deleteGlobalModel,
|
||||
batchAssignToProviders,
|
||||
getGlobalModelProviders,
|
||||
} from './endpoints/global-models'
|
||||
|
||||
@@ -737,6 +737,7 @@ import {
|
||||
updateGlobalModel,
|
||||
deleteGlobalModel,
|
||||
batchAssignToProviders,
|
||||
getGlobalModelProviders,
|
||||
type GlobalModelResponse,
|
||||
} from '@/api/global-models'
|
||||
import { log } from '@/utils/logger'
|
||||
@@ -1080,42 +1081,32 @@ async function selectModel(model: GlobalModelResponse) {
|
||||
async function loadModelProviders(_globalModelId: string) {
|
||||
loadingModelProviders.value = true
|
||||
try {
|
||||
// 使用 ModelCatalog API 获取详细的关联提供商信息
|
||||
const { getModelCatalog } = await import('@/api/endpoints')
|
||||
const catalogResponse = await getModelCatalog()
|
||||
// 使用新的 API 获取所有关联提供商(包括非活跃的)
|
||||
const response = await getGlobalModelProviders(_globalModelId)
|
||||
|
||||
// 查找当前 GlobalModel 对应的 catalog item
|
||||
const catalogItem = catalogResponse.models.find(
|
||||
m => m.global_model_name === selectedModel.value?.name
|
||||
)
|
||||
|
||||
if (catalogItem) {
|
||||
// 转换为展示格式,包含完整的模型实现信息
|
||||
selectedModelProviders.value = catalogItem.providers.map(p => ({
|
||||
id: p.provider_id,
|
||||
model_id: p.model_id,
|
||||
display_name: p.provider_display_name || p.provider_name,
|
||||
identifier: p.provider_name,
|
||||
provider_type: 'API',
|
||||
target_model: p.target_model,
|
||||
is_active: p.is_active,
|
||||
// 价格信息
|
||||
input_price_per_1m: p.input_price_per_1m,
|
||||
output_price_per_1m: p.output_price_per_1m,
|
||||
cache_creation_price_per_1m: p.cache_creation_price_per_1m,
|
||||
cache_read_price_per_1m: p.cache_read_price_per_1m,
|
||||
cache_1h_creation_price_per_1m: p.cache_1h_creation_price_per_1m,
|
||||
price_per_request: p.price_per_request,
|
||||
effective_tiered_pricing: p.effective_tiered_pricing,
|
||||
tier_count: p.tier_count,
|
||||
// 能力信息
|
||||
supports_vision: p.supports_vision,
|
||||
supports_function_calling: p.supports_function_calling,
|
||||
supports_streaming: p.supports_streaming
|
||||
}))
|
||||
} else {
|
||||
selectedModelProviders.value = []
|
||||
}
|
||||
// 转换为展示格式
|
||||
selectedModelProviders.value = response.providers.map(p => ({
|
||||
id: p.provider_id,
|
||||
model_id: p.model_id,
|
||||
display_name: p.provider_display_name || p.provider_name,
|
||||
identifier: p.provider_name,
|
||||
provider_type: 'API',
|
||||
target_model: p.target_model,
|
||||
is_active: p.is_active,
|
||||
// 价格信息
|
||||
input_price_per_1m: p.input_price_per_1m,
|
||||
output_price_per_1m: p.output_price_per_1m,
|
||||
cache_creation_price_per_1m: p.cache_creation_price_per_1m,
|
||||
cache_read_price_per_1m: p.cache_read_price_per_1m,
|
||||
cache_1h_creation_price_per_1m: p.cache_1h_creation_price_per_1m,
|
||||
price_per_request: p.price_per_request,
|
||||
effective_tiered_pricing: p.effective_tiered_pricing,
|
||||
tier_count: p.tier_count,
|
||||
// 能力信息
|
||||
supports_vision: p.supports_vision,
|
||||
supports_function_calling: p.supports_function_calling,
|
||||
supports_streaming: p.supports_streaming
|
||||
}))
|
||||
} catch (err: any) {
|
||||
log.error('加载关联提供商失败:', err)
|
||||
showError(parseApiError(err, '加载关联提供商失败'), '错误')
|
||||
|
||||
@@ -5,7 +5,7 @@ GlobalModel Admin API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -19,9 +19,11 @@ from src.models.pydantic_models import (
|
||||
BatchAssignToProvidersResponse,
|
||||
GlobalModelCreate,
|
||||
GlobalModelListResponse,
|
||||
GlobalModelProvidersResponse,
|
||||
GlobalModelResponse,
|
||||
GlobalModelUpdate,
|
||||
GlobalModelWithStats,
|
||||
ModelCatalogProviderDetail,
|
||||
)
|
||||
from src.services.model.global_model import GlobalModelService
|
||||
|
||||
@@ -108,6 +110,17 @@ async def batch_assign_to_providers(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{global_model_id}/providers", response_model=GlobalModelProvidersResponse)
|
||||
async def get_global_model_providers(
|
||||
request: Request,
|
||||
global_model_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelProvidersResponse:
|
||||
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
|
||||
adapter = AdminGetGlobalModelProvidersAdapter(global_model_id=global_model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ========== Adapters ==========
|
||||
|
||||
|
||||
@@ -275,3 +288,61 @@ class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
|
||||
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
|
||||
|
||||
return BatchAssignToProvidersResponse(**result)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
|
||||
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
|
||||
|
||||
global_model_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from src.models.database import Model
|
||||
|
||||
global_model = GlobalModelService.get_global_model(context.db, self.global_model_id)
|
||||
|
||||
# 获取所有关联的 Model(包括非活跃的)
|
||||
models = (
|
||||
context.db.query(Model)
|
||||
.options(joinedload(Model.provider), joinedload(Model.global_model))
|
||||
.filter(Model.global_model_id == global_model.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_entries = []
|
||||
for model in models:
|
||||
provider = model.provider
|
||||
if not provider:
|
||||
continue
|
||||
|
||||
effective_tiered = model.get_effective_tiered_pricing()
|
||||
tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1
|
||||
|
||||
provider_entries.append(
|
||||
ModelCatalogProviderDetail(
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.name,
|
||||
provider_display_name=provider.display_name,
|
||||
model_id=model.id,
|
||||
target_model=model.provider_model_name,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
cache_read_price_per_1m=model.get_effective_cache_read_price(),
|
||||
cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
|
||||
price_per_request=model.get_effective_price_per_request(),
|
||||
effective_tiered_pricing=effective_tiered,
|
||||
tier_count=tier_count,
|
||||
supports_vision=model.get_effective_supports_vision(),
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=bool(model.is_active),
|
||||
)
|
||||
)
|
||||
|
||||
return GlobalModelProvidersResponse(
|
||||
providers=provider_entries,
|
||||
total=len(provider_entries),
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from urllib.parse import quote, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from src.config import config
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
@@ -83,10 +84,10 @@ class HTTPClientPool:
|
||||
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
||||
verify=True, # 启用SSL验证
|
||||
timeout=httpx.Timeout(
|
||||
connect=10.0, # 连接超时
|
||||
read=300.0, # 读取超时(5分钟,适合流式响应)
|
||||
write=60.0, # 写入超时(60秒,支持大请求体)
|
||||
pool=5.0, # 连接池超时
|
||||
connect=config.http_connect_timeout,
|
||||
read=config.http_read_timeout,
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_connections=100, # 最大连接数
|
||||
@@ -111,15 +112,20 @@ class HTTPClientPool:
|
||||
"""
|
||||
if name not in cls._clients:
|
||||
# 合并默认配置和自定义配置
|
||||
config = {
|
||||
default_config = {
|
||||
"http2": False,
|
||||
"verify": True,
|
||||
"timeout": httpx.Timeout(10.0, read=300.0),
|
||||
"timeout": httpx.Timeout(
|
||||
connect=config.http_connect_timeout,
|
||||
read=config.http_read_timeout,
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
),
|
||||
"follow_redirects": True,
|
||||
}
|
||||
config.update(kwargs)
|
||||
default_config.update(kwargs)
|
||||
|
||||
cls._clients[name] = httpx.AsyncClient(**config)
|
||||
cls._clients[name] = httpx.AsyncClient(**default_config)
|
||||
logger.debug(f"创建命名HTTP客户端: {name}")
|
||||
|
||||
return cls._clients[name]
|
||||
@@ -151,14 +157,19 @@ class HTTPClientPool:
|
||||
async with HTTPClientPool.get_temp_client() as client:
|
||||
response = await client.get('https://example.com')
|
||||
"""
|
||||
config = {
|
||||
default_config = {
|
||||
"http2": False,
|
||||
"verify": True,
|
||||
"timeout": httpx.Timeout(10.0),
|
||||
"timeout": httpx.Timeout(
|
||||
connect=config.http_connect_timeout,
|
||||
read=config.http_read_timeout,
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
),
|
||||
}
|
||||
config.update(kwargs)
|
||||
default_config.update(kwargs)
|
||||
|
||||
client = httpx.AsyncClient(**config)
|
||||
client = httpx.AsyncClient(**default_config)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
@@ -182,25 +193,30 @@ class HTTPClientPool:
|
||||
Returns:
|
||||
配置好的 httpx.AsyncClient 实例
|
||||
"""
|
||||
config: Dict[str, Any] = {
|
||||
client_config: Dict[str, Any] = {
|
||||
"http2": False,
|
||||
"verify": True,
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
if timeout:
|
||||
config["timeout"] = timeout
|
||||
client_config["timeout"] = timeout
|
||||
else:
|
||||
config["timeout"] = httpx.Timeout(10.0, read=300.0)
|
||||
client_config["timeout"] = httpx.Timeout(
|
||||
connect=config.http_connect_timeout,
|
||||
read=config.http_read_timeout,
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
)
|
||||
|
||||
# 添加代理配置
|
||||
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||
if proxy_url:
|
||||
config["proxy"] = proxy_url
|
||||
client_config["proxy"] = proxy_url
|
||||
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
|
||||
|
||||
config.update(kwargs)
|
||||
return httpx.AsyncClient(**config)
|
||||
client_config.update(kwargs)
|
||||
return httpx.AsyncClient(**client_config)
|
||||
|
||||
|
||||
# 便捷访问函数
|
||||
|
||||
@@ -148,6 +148,7 @@ class Config:
|
||||
|
||||
# HTTP 请求超时配置(秒)
|
||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||
self.http_read_timeout = float(os.getenv("HTTP_READ_TIMEOUT", "300.0"))
|
||||
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
||||
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
|
||||
|
||||
|
||||
@@ -360,6 +360,9 @@ def init_db():
|
||||
|
||||
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
|
||||
"""
|
||||
import sys
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
logger.info("初始化数据库...")
|
||||
|
||||
# 确保引擎已创建
|
||||
@@ -382,6 +385,38 @@ def init_db():
|
||||
db.commit()
|
||||
logger.info("数据库初始化完成")
|
||||
|
||||
except OperationalError as e:
|
||||
db.rollback()
|
||||
# 提取数据库连接信息用于提示
|
||||
db_url = config.database_url
|
||||
# 隐藏密码,只显示 host:port/database
|
||||
if "@" in db_url:
|
||||
db_info = db_url.split("@")[-1]
|
||||
else:
|
||||
db_info = db_url
|
||||
|
||||
import os
|
||||
|
||||
# 直接打印到 stderr,确保消息显示
|
||||
print("", file=sys.stderr)
|
||||
print("=" * 60, file=sys.stderr)
|
||||
print("数据库连接失败", file=sys.stderr)
|
||||
print("=" * 60, file=sys.stderr)
|
||||
print("", file=sys.stderr)
|
||||
print(f"无法连接到数据库: {db_info}", file=sys.stderr)
|
||||
print("", file=sys.stderr)
|
||||
print("请检查以下事项:", file=sys.stderr)
|
||||
print(" 1. PostgreSQL 服务是否正在运行", file=sys.stderr)
|
||||
print(" 2. 数据库连接配置是否正确 (DATABASE_URL)", file=sys.stderr)
|
||||
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
|
||||
print("", file=sys.stderr)
|
||||
print("如果使用 Docker,请先运行:", file=sys.stderr)
|
||||
print(" docker-compose up -d postgres redis", file=sys.stderr)
|
||||
print("", file=sys.stderr)
|
||||
print("=" * 60, file=sys.stderr)
|
||||
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈
|
||||
os._exit(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库初始化失败: {e}")
|
||||
db.rollback()
|
||||
|
||||
@@ -274,6 +274,13 @@ class GlobalModelListResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class GlobalModelProvidersResponse(BaseModel):
|
||||
"""GlobalModel 关联提供商列表响应"""
|
||||
|
||||
providers: List[ModelCatalogProviderDetail]
|
||||
total: int
|
||||
|
||||
|
||||
class BatchAssignToProvidersRequest(BaseModel):
|
||||
"""批量为 Provider 添加 GlobalModel 实现"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user