4 Commits

Author SHA1 Message Date
fawney19
394cc536a9 feat: 添加 API 格式访问限制
扩展访问限制功能,支持 API Key 级别的 API 格式限制(OPENAI、CLAUDE、GEMINI)。

- AccessRestrictions 新增 allowed_api_formats 字段
- 新增 is_api_format_allowed() 方法检查格式权限
- models.py 添加 _filter_formats_by_restrictions() 函数过滤 API 格式
- 在所有模型列表和查询端点应用格式限制检查
- 添加 _build_empty_list_response() 统一空响应构建逻辑
2025-12-30 17:50:39 +08:00
fawney19
e20a09f15a feat: 添加模型列表访问限制功能
实现 API Key 和 User 级别的模型访问权限控制,支持按 Provider 和模型名称限制。

- 新增 AccessRestrictions 类处理访问限制合并逻辑(API Key 优先于 User)
- models_service 支持根据限制过滤模型列表
- models.py 在列表查询时构建并应用访问限制
- 优化缓存策略:仅无限制请求使用缓存,有限制的请求旁路缓存
- 修复 logger 配置:enqueue 改为 False 避免 macOS 信号量泄漏
2025-12-30 16:57:59 +08:00
fawney19
b89a4af0cf refactor: 统一 HTTP 客户端超时配置
将 HTTPClientPool 中硬编码的超时参数改为使用可配置的环境变量,提高系统的灵活性和可维护性。

- 添加 HTTP_READ_TIMEOUT 环境变量配置(默认 300 秒)
- 统一所有 HTTP 客户端创建逻辑使用配置化超时
- 改进变量命名清晰性(config -> default_config 或 client_config)
2025-12-30 15:06:55 +08:00
fawney19
a56854af43 feat: 为 GlobalModel 添加关联提供商查询 API
添加新的 API 端点 GET /api/admin/models/global/{global_model_id}/providers,用于获取 GlobalModel 的所有关联提供商(包括非活跃的)。

- 后端:实现 AdminGetGlobalModelProvidersAdapter 适配器
- 前端:使用新 API 替换原有的 ModelCatalog 获取方式
- 数据库:改进初始化时的错误提示和连接异常处理
2025-12-30 14:47:35 +08:00
11 changed files with 405 additions and 78 deletions

View File

@@ -4,7 +4,8 @@ import type {
GlobalModelUpdate, GlobalModelUpdate,
GlobalModelResponse, GlobalModelResponse,
GlobalModelWithStats, GlobalModelWithStats,
GlobalModelListResponse GlobalModelListResponse,
ModelCatalogProviderDetail,
} from './types' } from './types'
/** /**
@@ -83,3 +84,16 @@ export async function batchAssignToProviders(
) )
return response.data return response.data
} }
/**
 * Fetch every provider linked to a GlobalModel, including inactive ones.
 *
 * @param globalModelId - ID of the GlobalModel whose providers are requested.
 * @returns The response payload: the provider detail list and a total count.
 */
export async function getGlobalModelProviders(globalModelId: string): Promise<{
  providers: ModelCatalogProviderDetail[]
  total: number
}> {
  const url = `/api/admin/models/global/${globalModelId}/providers`
  const { data } = await client.get(url)
  return data
}

View File

@@ -20,4 +20,5 @@ export {
updateGlobalModel, updateGlobalModel,
deleteGlobalModel, deleteGlobalModel,
batchAssignToProviders, batchAssignToProviders,
getGlobalModelProviders,
} from './endpoints/global-models' } from './endpoints/global-models'

View File

@@ -737,6 +737,7 @@ import {
updateGlobalModel, updateGlobalModel,
deleteGlobalModel, deleteGlobalModel,
batchAssignToProviders, batchAssignToProviders,
getGlobalModelProviders,
type GlobalModelResponse, type GlobalModelResponse,
} from '@/api/global-models' } from '@/api/global-models'
import { log } from '@/utils/logger' import { log } from '@/utils/logger'
@@ -1080,18 +1081,11 @@ async function selectModel(model: GlobalModelResponse) {
async function loadModelProviders(_globalModelId: string) { async function loadModelProviders(_globalModelId: string) {
loadingModelProviders.value = true loadingModelProviders.value = true
try { try {
// 使用 ModelCatalog API 获取详细的关联提供商信息 // 使用新的 API 获取所有关联提供商(包括非活跃的)
const { getModelCatalog } = await import('@/api/endpoints') const response = await getGlobalModelProviders(_globalModelId)
const catalogResponse = await getModelCatalog()
// 查找当前 GlobalModel 对应的 catalog item // 转换为展示格式
const catalogItem = catalogResponse.models.find( selectedModelProviders.value = response.providers.map(p => ({
m => m.global_model_name === selectedModel.value?.name
)
if (catalogItem) {
// 转换为展示格式,包含完整的模型实现信息
selectedModelProviders.value = catalogItem.providers.map(p => ({
id: p.provider_id, id: p.provider_id,
model_id: p.model_id, model_id: p.model_id,
display_name: p.provider_display_name || p.provider_name, display_name: p.provider_display_name || p.provider_name,
@@ -1113,9 +1107,6 @@ async function loadModelProviders(_globalModelId: string) {
supports_function_calling: p.supports_function_calling, supports_function_calling: p.supports_function_calling,
supports_streaming: p.supports_streaming supports_streaming: p.supports_streaming
})) }))
} else {
selectedModelProviders.value = []
}
} catch (err: any) { } catch (err: any) {
log.error('加载关联提供商失败:', err) log.error('加载关联提供商失败:', err)
showError(parseApiError(err, '加载关联提供商失败'), '错误') showError(parseApiError(err, '加载关联提供商失败'), '错误')

View File

@@ -5,7 +5,7 @@ GlobalModel Admin API
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Optional
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -19,9 +19,11 @@ from src.models.pydantic_models import (
BatchAssignToProvidersResponse, BatchAssignToProvidersResponse,
GlobalModelCreate, GlobalModelCreate,
GlobalModelListResponse, GlobalModelListResponse,
GlobalModelProvidersResponse,
GlobalModelResponse, GlobalModelResponse,
GlobalModelUpdate, GlobalModelUpdate,
GlobalModelWithStats, GlobalModelWithStats,
ModelCatalogProviderDetail,
) )
from src.services.model.global_model import GlobalModelService from src.services.model.global_model import GlobalModelService
@@ -108,6 +110,17 @@ async def batch_assign_to_providers(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{global_model_id}/providers", response_model=GlobalModelProvidersResponse)
async def get_global_model_providers(
    request: Request,
    global_model_id: str,
    db: Session = Depends(get_db),
) -> GlobalModelProvidersResponse:
    """List every provider linked to a GlobalModel, including inactive ones.

    Args:
        request: Incoming HTTP request, forwarded to the pipeline.
        global_model_id: Identifier of the GlobalModel, taken from the URL path.
        db: Database session injected through ``Depends(get_db)``.

    Returns:
        The result of running ``AdminGetGlobalModelProvidersAdapter`` through
        the shared pipeline.
    """
    adapter = AdminGetGlobalModelProvidersAdapter(global_model_id=global_model_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ========== Adapters ========== # ========== Adapters ==========
@@ -275,3 +288,61 @@ class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}") logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
return BatchAssignToProvidersResponse(**result) return BatchAssignToProvidersResponse(**result)
@dataclass
class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
    """Admin adapter that lists every provider linked to a GlobalModel.

    The query applies no activity filter, so inactive Model rows are included;
    each returned entry carries its own ``is_active`` flag.
    """

    global_model_id: str

    async def handle(self, context):  # type: ignore[override]
        """Collect provider details for all Models bound to the GlobalModel.

        Args:
            context: Request context; only ``context.db`` is used here.

        Returns:
            A ``GlobalModelProvidersResponse`` with one entry per Model that has
            a provider, plus the number of entries.
        """
        # Imported locally, as in the original block.
        from sqlalchemy.orm import joinedload

        from src.models.database import Model

        # NOTE(review): presumably raises when the id is unknown; confirm in
        # GlobalModelService.get_global_model.
        global_model = GlobalModelService.get_global_model(context.db, self.global_model_id)

        # Fetch all linked Models (inactive ones too), eager-loading the
        # provider and global_model relationships used below.
        models = (
            context.db.query(Model)
            .options(joinedload(Model.provider), joinedload(Model.global_model))
            .filter(Model.global_model_id == global_model.id)
            .all()
        )

        provider_entries = []
        for model in models:
            provider = model.provider
            if not provider:
                # A Model without a provider cannot be reported; skip it.
                continue

            effective_tiered = model.get_effective_tiered_pricing()
            # Without tiered pricing the entry is reported as a single tier.
            tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1

            provider_entries.append(
                ModelCatalogProviderDetail(
                    provider_id=provider.id,
                    provider_name=provider.name,
                    provider_display_name=provider.display_name,
                    model_id=model.id,
                    target_model=model.provider_model_name,
                    input_price_per_1m=model.get_effective_input_price(),
                    output_price_per_1m=model.get_effective_output_price(),
                    cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
                    cache_read_price_per_1m=model.get_effective_cache_read_price(),
                    cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
                    price_per_request=model.get_effective_price_per_request(),
                    effective_tiered_pricing=effective_tiered,
                    tier_count=tier_count,
                    supports_vision=model.get_effective_supports_vision(),
                    supports_function_calling=model.get_effective_supports_function_calling(),
                    supports_streaming=model.get_effective_supports_streaming(),
                    is_active=bool(model.is_active),
                )
            )

        return GlobalModelProvidersResponse(
            providers=provider_entries,
            total=len(provider_entries),
        )

View File

@@ -18,7 +18,15 @@ from sqlalchemy.orm import Session, joinedload
from src.config.constants import CacheTTL from src.config.constants import CacheTTL
from src.core.cache_service import CacheService from src.core.cache_service import CacheService
from src.core.logger import logger from src.core.logger import logger
from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint from src.models.database import (
ApiKey,
GlobalModel,
Model,
Provider,
ProviderAPIKey,
ProviderEndpoint,
User,
)
# 缓存 key 前缀 # 缓存 key 前缀
_CACHE_KEY_PREFIX = "models:list" _CACHE_KEY_PREFIX = "models:list"
@@ -82,6 +90,7 @@ class ModelInfo:
created_at: Optional[str] # ISO 格式 created_at: Optional[str] # ISO 格式
created_timestamp: int # Unix 时间戳 created_timestamp: int # Unix 时间戳
provider_name: str provider_name: str
provider_id: str = "" # Provider ID用于权限过滤
# 能力配置 # 能力配置
streaming: bool = True streaming: bool = True
vision: bool = False vision: bool = False
@@ -99,6 +108,92 @@ class ModelInfo:
output_modalities: Optional[list[str]] = None output_modalities: Optional[list[str]] = None
@dataclass
class AccessRestrictions:
    """Access restrictions that apply to an API Key or a User.

    For every field, ``None`` means "not restricted". A list (even an empty
    one) restricts access to exactly its members.
    """

    allowed_providers: Optional[list[str]] = None  # permitted Provider IDs
    allowed_models: Optional[list[str]] = None  # permitted model names
    allowed_api_formats: Optional[list[str]] = None  # permitted API formats

    @classmethod
    def from_api_key_and_user(
        cls, api_key: Optional["ApiKey"], user: Optional["User"]
    ) -> "AccessRestrictions":
        """Merge the restrictions of an API Key and its User.

        Each field is resolved independently:
        - a non-``None`` value on the API Key wins;
        - otherwise the User's value is used (providers and models only);
        - if neither sets a value the field stays ``None`` (unrestricted).

        Args:
            api_key: API Key record, or ``None``.
            user: User record, or ``None``.

        Returns:
            A new ``AccessRestrictions`` holding the merged values.
        """
        # The API Key's settings take priority.
        allowed_providers = api_key.allowed_providers if api_key else None
        allowed_models = api_key.allowed_models if api_key else None
        allowed_api_formats = api_key.allowed_api_formats if api_key else None

        # Fall back to the User only for fields the API Key left unset.
        # The User has no allowed_api_formats field, so that one has no fallback.
        if user:
            if allowed_providers is None:
                allowed_providers = user.allowed_providers
            if allowed_models is None:
                allowed_models = user.allowed_models

        return cls(
            allowed_providers=allowed_providers,
            allowed_models=allowed_models,
            allowed_api_formats=allowed_api_formats,
        )

    def is_api_format_allowed(self, api_format: str) -> bool:
        """Return whether ``api_format`` passes the API-format restriction.

        Args:
            api_format: API format name, e.g. "OPENAI", "CLAUDE", "GEMINI".
                The comparison is an exact, case-sensitive membership test.

        Returns:
            True when no format restriction is set or the format is listed.
        """
        allowed = self.allowed_api_formats
        return allowed is None or api_format in allowed

    def is_model_allowed(self, model_id: str, provider_id: str) -> bool:
        """Return whether a model served by a provider may be accessed.

        Args:
            model_id: Model ID, checked against ``allowed_models``.
            provider_id: Provider ID, checked against ``allowed_providers``.

        Returns:
            True only when both the provider check and the model check pass;
            an unset (``None``) restriction always passes.
        """
        providers = self.allowed_providers
        if providers is not None and provider_id not in providers:
            return False

        models = self.allowed_models
        return models is None or model_id in models
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]: def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
""" """
返回有可用端点的 Provider IDs 返回有可用端点的 Provider IDs
@@ -218,6 +313,7 @@ def _extract_model_info(model: Any) -> ModelInfo:
) )
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0 created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
provider_name: str = model.provider.name if model.provider else "unknown" provider_name: str = model.provider.name if model.provider else "unknown"
provider_id: str = model.provider_id or ""
# 从 GlobalModel.config 提取配置信息 # 从 GlobalModel.config 提取配置信息
config: dict = {} config: dict = {}
@@ -233,6 +329,7 @@ def _extract_model_info(model: Any) -> ModelInfo:
created_at=created_at, created_at=created_at,
created_timestamp=created_timestamp, created_timestamp=created_timestamp,
provider_name=provider_name, provider_name=provider_name,
provider_id=provider_id,
# 能力配置 # 能力配置
streaming=config.get("streaming", True), streaming=config.get("streaming", True),
vision=config.get("vision", False), vision=config.get("vision", False),
@@ -255,6 +352,7 @@ async def list_available_models(
db: Session, db: Session,
available_provider_ids: set[str], available_provider_ids: set[str],
api_formats: Optional[list[str]] = None, api_formats: Optional[list[str]] = None,
restrictions: Optional[AccessRestrictions] = None,
) -> list[ModelInfo]: ) -> list[ModelInfo]:
""" """
获取可用模型列表(已去重,带缓存) 获取可用模型列表(已去重,带缓存)
@@ -263,6 +361,7 @@ async def list_available_models(
db: 数据库会话 db: 数据库会话
available_provider_ids: 有可用端点的 Provider ID 集合 available_provider_ids: 有可用端点的 Provider ID 集合
api_formats: API 格式列表,用于检查 Key 的 allowed_models api_formats: API 格式列表,用于检查 Key 的 allowed_models
restrictions: API Key/User 的访问限制
Returns: Returns:
去重后的 ModelInfo 列表,按创建时间倒序 去重后的 ModelInfo 列表,按创建时间倒序
@@ -270,8 +369,16 @@ async def list_available_models(
if not available_provider_ids: if not available_provider_ids:
return [] return []
# 缓存策略:只有完全无访问限制时才使用缓存
# - restrictions is None: 未传入限制对象
# - restrictions 的两个字段都为 None: 传入了限制对象但无实际限制
# 以上两种情况返回的结果相同,可以共享全局缓存
use_cache = restrictions is None or (
restrictions.allowed_providers is None and restrictions.allowed_models is None
)
# 尝试从缓存获取 # 尝试从缓存获取
if api_formats: if api_formats and use_cache:
cached = await _get_cached_models(api_formats) cached = await _get_cached_models(api_formats)
if cached is not None: if cached is not None:
return cached return cached
@@ -306,14 +413,19 @@ async def list_available_models(
if available_model_ids is not None and info.id not in available_model_ids: if available_model_ids is not None and info.id not in available_model_ids:
continue continue
# 检查 API Key/User 访问限制
if restrictions is not None:
if not restrictions.is_model_allowed(info.id, info.provider_id):
continue
if info.id in seen_model_ids: if info.id in seen_model_ids:
continue continue
seen_model_ids.add(info.id) seen_model_ids.add(info.id)
result.append(info) result.append(info)
# 写入缓存 # 只有无限制的情况才写入缓存
if api_formats: if api_formats and use_cache:
await _set_cached_models(api_formats, result) await _set_cached_models(api_formats, result)
return result return result
@@ -324,6 +436,7 @@ def find_model_by_id(
model_id: str, model_id: str,
available_provider_ids: set[str], available_provider_ids: set[str],
api_formats: Optional[list[str]] = None, api_formats: Optional[list[str]] = None,
restrictions: Optional[AccessRestrictions] = None,
) -> Optional[ModelInfo]: ) -> Optional[ModelInfo]:
""" """
按 ID 查找模型 按 ID 查找模型
@@ -338,6 +451,7 @@ def find_model_by_id(
model_id: 模型 ID model_id: 模型 ID
available_provider_ids: 有可用端点的 Provider ID 集合 available_provider_ids: 有可用端点的 Provider ID 集合
api_formats: API 格式列表,用于检查 Key 的 allowed_models api_formats: API 格式列表,用于检查 Key 的 allowed_models
restrictions: API Key/User 的访问限制
Returns: Returns:
ModelInfo 或 None ModelInfo 或 None
@@ -353,6 +467,11 @@ def find_model_by_id(
if available_model_ids is not None and model_id not in available_model_ids: if available_model_ids is not None and model_id not in available_model_ids:
return None return None
# 快速检查:如果 restrictions 明确限制了模型列表且目标模型不在其中,直接返回 None
if restrictions is not None and restrictions.allowed_models is not None:
if model_id not in restrictions.allowed_models:
return None
# 先按 GlobalModel.name 查找 # 先按 GlobalModel.name 查找
models_by_global = ( models_by_global = (
db.query(Model) db.query(Model)
@@ -368,8 +487,19 @@ def find_model_by_id(
.all() .all()
) )
def is_model_accessible(m: Model) -> bool:
"""检查模型是否可访问"""
if m.provider_id not in available_provider_ids:
return False
# 检查 API Key/User 访问限制
if restrictions is not None:
provider_id = m.provider_id or ""
if not restrictions.is_model_allowed(model_id, provider_id):
return False
return True
model = next( model = next(
(m for m in models_by_global if m.provider_id in available_provider_ids), (m for m in models_by_global if is_model_accessible(m)),
None, None,
) )
@@ -393,7 +523,7 @@ def find_model_by_id(
) )
model = next( model = next(
(m for m in models_by_provider_name if m.provider_id in available_provider_ids), (m for m in models_by_provider_name if is_model_accessible(m)),
None, None,
) )

View File

@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.api.base.models_service import ( from src.api.base.models_service import (
AccessRestrictions,
ModelInfo, ModelInfo,
find_model_by_id, find_model_by_id,
get_available_provider_ids, get_available_provider_ids,
@@ -103,6 +104,35 @@ def _get_formats_for_api(api_format: str) -> list[str]:
return _OPENAI_FORMATS return _OPENAI_FORMATS
def _build_empty_list_response(api_format: str) -> dict:
    """Build the empty model-list payload for the given API format.

    Args:
        api_format: "claude" and "gemini" get their own payload shape; any
            other value gets the OpenAI-style list payload.

    Returns:
        A fresh dict describing an empty model list in that format.
    """
    if api_format == "claude":
        return {"data": [], "has_more": False, "first_id": None, "last_id": None}
    if api_format == "gemini":
        return {"models": []}
    return {"object": "list", "data": []}
def _filter_formats_by_restrictions(
    formats: list[str], restrictions: AccessRestrictions, api_format: str
) -> Tuple[list[str], Optional[dict]]:
    """Narrow a list of API formats to those the restrictions allow.

    Args:
        formats: Candidate API formats for the current endpoint.
        restrictions: Access restrictions; only ``allowed_api_formats`` is read.
        api_format: Endpoint format name, used for the log message and to pick
            the shape of the empty response.

    Returns:
        ``(filtered_formats, empty_response)``. ``empty_response`` is ``None``
        unless filtering removed every format, in which case the list is empty
        and the second item is the format-specific empty list payload.
    """
    allowed = restrictions.allowed_api_formats
    if allowed is None:
        # No format restriction: hand the input list back untouched.
        return formats, None

    filtered = [f for f in formats if f in allowed]
    if filtered:
        return filtered, None

    logger.info(f"[Models] API Key 不允许访问格式 {api_format}")
    return [], _build_empty_list_response(api_format)
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]: def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
""" """
认证 API Key 认证 API Key
@@ -375,22 +405,24 @@ async def list_models(
logger.info(f"[Models] GET /v1/models | format={api_format}") logger.info(f"[Models] GET /v1/models | format={api_format}")
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response(api_format) return _build_auth_error_response(api_format)
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats = _get_formats_for_api(api_format) formats = _get_formats_for_api(api_format)
formats, empty_response = _filter_formats_by_restrictions(formats, restrictions, api_format)
if empty_response is not None:
return empty_response
available_provider_ids = get_available_provider_ids(db, formats) available_provider_ids = get_available_provider_ids(db, formats)
if not available_provider_ids: if not available_provider_ids:
if api_format == "claude": return _build_empty_list_response(api_format)
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
elif api_format == "gemini":
return {"models": []}
else:
return {"object": "list", "data": []}
models = await list_available_models(db, available_provider_ids, formats) models = await list_available_models(db, available_provider_ids, formats, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型") logger.debug(f"[Models] 返回 {len(models)} 个模型")
if api_format == "claude": if api_format == "claude":
@@ -419,14 +451,21 @@ async def retrieve_model(
logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}") logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response(api_format) return _build_auth_error_response(api_format)
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats = _get_formats_for_api(api_format) formats = _get_formats_for_api(api_format)
formats, _ = _filter_formats_by_restrictions(formats, restrictions, api_format)
if not formats:
return _build_404_response(model_id, api_format)
available_provider_ids = get_available_provider_ids(db, formats) available_provider_ids = get_available_provider_ids(db, formats)
model_info = find_model_by_id(db, model_id, available_provider_ids, formats) model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions)
if not model_info: if not model_info:
return _build_404_response(model_id, api_format) return _build_404_response(model_id, api_format)
@@ -455,15 +494,25 @@ async def list_models_gemini(
api_key = _extract_api_key_from_request(request, gemini_def) api_key = _extract_api_key_from_request(request, gemini_def)
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response("gemini") return _build_auth_error_response("gemini")
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) # 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats, empty_response = _filter_formats_by_restrictions(
_GEMINI_FORMATS, restrictions, "gemini"
)
if empty_response is not None:
return empty_response
available_provider_ids = get_available_provider_ids(db, formats)
if not available_provider_ids: if not available_provider_ids:
return {"models": []} return {"models": []}
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS) models = await list_available_models(db, available_provider_ids, formats, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型") logger.debug(f"[Models] 返回 {len(models)} 个模型")
response = _build_gemini_list_response(models, page_size, page_token) response = _build_gemini_list_response(models, page_size, page_token)
logger.debug(f"[Models] Gemini 响应: {response}") logger.debug(f"[Models] Gemini 响应: {response}")
@@ -486,12 +535,22 @@ async def get_model_gemini(
api_key = _extract_api_key_from_request(request, gemini_def) api_key = _extract_api_key_from_request(request, gemini_def)
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response("gemini") return _build_auth_error_response("gemini")
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) # 构建访问限制
model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS) restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats, _ = _filter_formats_by_restrictions(_GEMINI_FORMATS, restrictions, "gemini")
if not formats:
return _build_404_response(model_id, "gemini")
available_provider_ids = get_available_provider_ids(db, formats)
model_info = find_model_by_id(
db, model_id, available_provider_ids, formats, restrictions
)
if not model_info: if not model_info:
return _build_404_response(model_id, "gemini") return _build_404_response(model_id, "gemini")

View File

@@ -9,6 +9,7 @@ from urllib.parse import quote, urlparse
import httpx import httpx
from src.config import config
from src.core.logger import logger from src.core.logger import logger
@@ -83,10 +84,10 @@ class HTTPClientPool:
http2=False, # 暂时禁用HTTP/2以提高兼容性 http2=False, # 暂时禁用HTTP/2以提高兼容性
verify=True, # 启用SSL验证 verify=True, # 启用SSL验证
timeout=httpx.Timeout( timeout=httpx.Timeout(
connect=10.0, # 连接超时 connect=config.http_connect_timeout,
read=300.0, # 读取超时(5分钟,适合流式响应) read=config.http_read_timeout,
write=60.0, # 写入超时(60秒,支持大请求体) write=config.http_write_timeout,
pool=5.0, # 连接池超时 pool=config.http_pool_timeout,
), ),
limits=httpx.Limits( limits=httpx.Limits(
max_connections=100, # 最大连接数 max_connections=100, # 最大连接数
@@ -111,15 +112,20 @@ class HTTPClientPool:
""" """
if name not in cls._clients: if name not in cls._clients:
# 合并默认配置和自定义配置 # 合并默认配置和自定义配置
config = { default_config = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"timeout": httpx.Timeout(10.0, read=300.0), "timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
"follow_redirects": True, "follow_redirects": True,
} }
config.update(kwargs) default_config.update(kwargs)
cls._clients[name] = httpx.AsyncClient(**config) cls._clients[name] = httpx.AsyncClient(**default_config)
logger.debug(f"创建命名HTTP客户端: {name}") logger.debug(f"创建命名HTTP客户端: {name}")
return cls._clients[name] return cls._clients[name]
@@ -151,14 +157,19 @@ class HTTPClientPool:
async with HTTPClientPool.get_temp_client() as client: async with HTTPClientPool.get_temp_client() as client:
response = await client.get('https://example.com') response = await client.get('https://example.com')
""" """
config = { default_config = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"timeout": httpx.Timeout(10.0), "timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
} }
config.update(kwargs) default_config.update(kwargs)
client = httpx.AsyncClient(**config) client = httpx.AsyncClient(**default_config)
try: try:
yield client yield client
finally: finally:
@@ -182,25 +193,30 @@ class HTTPClientPool:
Returns: Returns:
配置好的 httpx.AsyncClient 实例 配置好的 httpx.AsyncClient 实例
""" """
config: Dict[str, Any] = { client_config: Dict[str, Any] = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"follow_redirects": True, "follow_redirects": True,
} }
if timeout: if timeout:
config["timeout"] = timeout client_config["timeout"] = timeout
else: else:
config["timeout"] = httpx.Timeout(10.0, read=300.0) client_config["timeout"] = httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
)
# 添加代理配置 # 添加代理配置
proxy_url = build_proxy_url(proxy_config) if proxy_config else None proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url: if proxy_url:
config["proxy"] = proxy_url client_config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}") logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
config.update(kwargs) client_config.update(kwargs)
return httpx.AsyncClient(**config) return httpx.AsyncClient(**client_config)
# 便捷访问函数 # 便捷访问函数

View File

@@ -148,6 +148,7 @@ class Config:
# HTTP 请求超时配置(秒) # HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0")) self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
self.http_read_timeout = float(os.getenv("HTTP_READ_TIMEOUT", "300.0"))
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0")) self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0")) self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))

View File

@@ -96,13 +96,15 @@ if not DISABLE_FILE_LOG:
log_dir.mkdir(exist_ok=True) log_dir.mkdir(exist_ok=True)
# 文件日志通用配置 # 文件日志通用配置
# 注意: enqueue=False 使用同步模式,避免 multiprocessing 信号量泄漏
# 在 macOS 上,进程异常退出时 POSIX 信号量不会自动释放,导致资源耗尽
file_log_config = { file_log_config = {
"format": FILE_FORMAT, "format": FILE_FORMAT,
"filter": _log_filter, "filter": _log_filter,
"rotation": "100 MB", "rotation": "100 MB",
"retention": "30 days", "retention": "30 days",
"compression": "gz", "compression": "gz",
"enqueue": True, "enqueue": False,
"encoding": "utf-8", "encoding": "utf-8",
"catch": True, "catch": True,
} }

View File

@@ -360,6 +360,9 @@ def init_db():
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh 注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
""" """
import sys
from sqlalchemy.exc import OperationalError
logger.info("初始化数据库...") logger.info("初始化数据库...")
# 确保引擎已创建 # 确保引擎已创建
@@ -382,6 +385,38 @@ def init_db():
db.commit() db.commit()
logger.info("数据库初始化完成") logger.info("数据库初始化完成")
except OperationalError as e:
db.rollback()
# 提取数据库连接信息用于提示
db_url = config.database_url
# 隐藏密码,只显示 host:port/database
if "@" in db_url:
db_info = db_url.split("@")[-1]
else:
db_info = db_url
import os
# 直接打印到 stderr确保消息显示
print("", file=sys.stderr)
print("=" * 60, file=sys.stderr)
print("数据库连接失败", file=sys.stderr)
print("=" * 60, file=sys.stderr)
print("", file=sys.stderr)
print(f"无法连接到数据库: {db_info}", file=sys.stderr)
print("", file=sys.stderr)
print("请检查以下事项:", file=sys.stderr)
print(" 1. PostgreSQL 服务是否正在运行", file=sys.stderr)
print(" 2. 数据库连接配置是否正确 (DATABASE_URL)", file=sys.stderr)
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
print("", file=sys.stderr)
print("如果使用 Docker请先运行:", file=sys.stderr)
print(" docker-compose up -d postgres redis", file=sys.stderr)
print("", file=sys.stderr)
print("=" * 60, file=sys.stderr)
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈
os._exit(1)
except Exception as e: except Exception as e:
logger.error(f"数据库初始化失败: {e}") logger.error(f"数据库初始化失败: {e}")
db.rollback() db.rollback()

View File

@@ -274,6 +274,13 @@ class GlobalModelListResponse(BaseModel):
total: int total: int
class GlobalModelProvidersResponse(BaseModel):
    """Response listing the providers linked to a GlobalModel."""

    providers: List[ModelCatalogProviderDetail]  # one detail entry per linked provider
    total: int  # count reported alongside the list
class BatchAssignToProvidersRequest(BaseModel): class BatchAssignToProvidersRequest(BaseModel):
"""批量为 Provider 添加 GlobalModel 实现""" """批量为 Provider 添加 GlobalModel 实现"""