mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
422 lines
17 KiB
Python
422 lines
17 KiB
Python
"""
|
||
模型管理服务
|
||
"""
|
||
|
||
import asyncio
|
||
from typing import List, Optional
|
||
|
||
from sqlalchemy import and_
|
||
from sqlalchemy.exc import IntegrityError
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||
from src.core.logger import logger
|
||
from src.models.api import ModelCreate, ModelResponse, ModelUpdate
|
||
from src.models.database import Model, Provider
|
||
from src.api.base.models_service import invalidate_models_list_cache
|
||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||
from src.services.cache.model_cache import ModelCacheService
|
||
|
||
|
||
|
||
class ModelService:
    """Service layer for provider models: CRUD operations plus cache invalidation.

    Every method is static and works on the SQLAlchemy ``Session`` passed in.
    Methods that modify data schedule asynchronous cache invalidation through
    ``asyncio.create_task``, so they must be called from a thread that has a
    running event loop; otherwise ``RuntimeError`` is raised after the commit.
    """

    # Strong references to in-flight background tasks. The event loop keeps
    # only weak references to tasks, so a fire-and-forget task that nothing
    # else references may be garbage-collected before it finishes.
    _background_tasks: set = set()

    @staticmethod
    def _spawn_background(coro):
        """Schedule ``coro`` as a task and keep it referenced until it completes.

        Args:
            coro: The coroutine object to run in the background.

        Returns:
            The created ``asyncio.Task``.

        Raises:
            RuntimeError: If there is no running event loop in this thread
                (propagated unchanged from ``asyncio.create_task``).
        """
        task = asyncio.create_task(coro)
        ModelService._background_tasks.add(task)
        # Drop the reference as soon as the task is done so the set stays small.
        task.add_done_callback(ModelService._background_tasks.discard)
        return task

    @staticmethod
    def create_model(db: Session, provider_id: str, model_data: ModelCreate) -> Model:
        """Create a model under the given provider.

        Args:
            db: Active database session.
            provider_id: ID of the provider that owns the new model.
            model_data: Payload describing the model to create.

        Returns:
            The persisted model. When it has a ``global_model_id`` the
            ``global_model`` relationship is eagerly loaded.

        Raises:
            NotFoundException: If the provider does not exist.
            InvalidRequestException: If a model with the same
                ``provider_model_name`` already exists for this provider, or
                the insert violates a database constraint.
        """
        # The provider must exist.
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            raise NotFoundException(f"提供商 {provider_id} 不存在")

        # Reject a duplicate model name within the same provider.
        existing = (
            db.query(Model)
            .filter(
                and_(
                    Model.provider_id == provider_id,
                    Model.provider_model_name == model_data.provider_model_name,
                )
            )
            .first()
        )
        if existing:
            raise InvalidRequestException(
                f"提供商 {provider.name} 下已存在模型 {model_data.provider_model_name}"
            )

        try:
            model = Model(
                provider_id=provider_id,
                global_model_id=model_data.global_model_id,
                provider_model_name=model_data.provider_model_name,
                provider_model_mappings=model_data.provider_model_mappings,
                price_per_request=model_data.price_per_request,
                tiered_pricing=model_data.tiered_pricing,
                supports_vision=model_data.supports_vision,
                supports_function_calling=model_data.supports_function_calling,
                supports_streaming=model_data.supports_streaming,
                supports_extended_thinking=model_data.supports_extended_thinking,
                is_active=model_data.is_active if model_data.is_active is not None else True,
                config=model_data.config,
            )
            db.add(model)
            db.commit()
            db.refresh(model)

            # Explicitly load the global_model relationship.
            if model.global_model_id:
                from sqlalchemy.orm import joinedload

                model = (
                    db.query(Model)
                    .options(joinedload(Model.global_model))
                    .filter(Model.id == model.id)
                    .first()
                )

            logger.info(
                f"创建模型成功: provider={provider.name}, model={model.provider_model_name}, "
                f"global_model_id={model.global_model_id}"
            )

            # Invalidate the cached /v1/models listing.
            ModelService._spawn_background(invalidate_models_list_cache())

            return model

        except IntegrityError as e:
            db.rollback()
            logger.error(f"创建模型失败: {str(e)}")
            raise InvalidRequestException("创建模型失败,请检查输入数据") from e

    @staticmethod
    def get_model(db: Session, model_id: str) -> Model:  # UUID
        """Fetch one model by ID with its ``global_model`` eagerly loaded.

        Raises:
            NotFoundException: If no model has the given ID.
        """
        from sqlalchemy.orm import joinedload

        model = (
            db.query(Model)
            .options(joinedload(Model.global_model))
            .filter(Model.id == model_id)
            .first()
        )
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")
        return model

    @staticmethod
    def get_models_by_provider(
        db: Session,
        provider_id: str,  # UUID
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
    ) -> List[Model]:
        """List the models of one provider, newest first.

        Args:
            db: Active database session.
            provider_id: Provider whose models are listed.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            is_active: When not ``None``, only models whose ``is_active``
                equals this value are returned.

        Returns:
            Models with ``global_model`` eagerly loaded.
        """
        from sqlalchemy.orm import joinedload

        query = (
            db.query(Model)
            .options(joinedload(Model.global_model))
            .filter(Model.provider_id == provider_id)
        )

        if is_active is not None:
            query = query.filter(Model.is_active == is_active)

        # Order by creation time, newest first.
        query = query.order_by(Model.created_at.desc())

        return query.offset(skip).limit(limit).all()

    @staticmethod
    def get_all_models(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
        category: Optional[str] = None,
    ) -> List[Model]:
        """List models across all providers.

        Args:
            db: Active database session.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
            is_active: When not ``None``, only models whose ``is_active``
                equals this value are returned.
            category: Accepted for interface compatibility; it is currently
                not applied to the query.

        Returns:
            Models ordered by provider ID, then by creation time (newest
            first), with ``global_model`` eagerly loaded.
        """
        from sqlalchemy.orm import joinedload

        # Eager-load global_model: convert_to_response reads it for every row,
        # which would otherwise trigger one lazy load per model.
        query = db.query(Model).options(joinedload(Model.global_model))

        if is_active is not None:
            query = query.filter(Model.is_active == is_active)

        # Order by provider, then by creation time.
        query = query.order_by(Model.provider_id, Model.created_at.desc())

        return query.offset(skip).limit(limit).all()

    @staticmethod
    def update_model(db: Session, model_id: str, model_data: ModelUpdate) -> Model:  # UUID
        """Apply a partial update to a model and invalidate related caches.

        Only fields explicitly set on ``model_data`` are written
        (``model_dump(exclude_unset=True)``).

        Raises:
            NotFoundException: If no model has the given ID.
            InvalidRequestException: If the update violates a database
                constraint.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")

        # Remember the pre-update identity so cache entries keyed by the old
        # values can still be invalidated after the commit.
        old_provider_model_name = model.provider_model_name
        old_provider_model_mappings = model.provider_model_mappings
        old_global_model_id = model.global_model_id

        update_data = model_data.model_dump(exclude_unset=True)

        # Debug logging of the incoming data and the capability flags.
        logger.debug(f"更新模型 {model_id} 收到的数据: {update_data}")
        logger.debug(
            f"更新前的 supports_vision: {model.supports_vision}, "
            f"supports_function_calling: {model.supports_function_calling}, "
            f"supports_extended_thinking: {model.supports_extended_thinking}"
        )

        for field, value in update_data.items():
            setattr(model, field, value)

        logger.debug(
            f"更新后的 supports_vision: {model.supports_vision}, "
            f"supports_function_calling: {model.supports_function_calling}, "
            f"supports_extended_thinking: {model.supports_extended_thinking}"
        )

        try:
            db.commit()
            db.refresh(model)

            # Invalidate the Redis cache in the background (does not block the
            # return). First the entries keyed by the OLD identity.
            ModelService._spawn_background(
                ModelCacheService.invalidate_model_cache(
                    model_id=model.id,
                    provider_id=model.provider_id,
                    global_model_id=old_global_model_id,
                    provider_model_name=old_provider_model_name,
                    provider_model_mappings=old_provider_model_mappings,
                )
            )
            # Then the entries keyed by the NEW identity, if anything changed.
            identity_changed = (
                model.provider_model_name != old_provider_model_name
                or model.provider_model_mappings != old_provider_model_mappings
                or model.global_model_id != old_global_model_id
            )
            if identity_changed:
                ModelService._spawn_background(
                    ModelCacheService.invalidate_model_cache(
                        model_id=model.id,
                        provider_id=model.provider_id,
                        global_model_id=model.global_model_id,
                        provider_model_name=model.provider_model_name,
                        provider_model_mappings=model.provider_model_mappings,
                    )
                )

            # Invalidate the in-memory cache (ModelMapperMiddleware instances).
            if model.provider_id and model.global_model_id:
                cache_service = get_cache_invalidation_service()
                cache_service.on_model_changed(model.provider_id, model.global_model_id)
            # Also notify for the previous global model when the link moved.
            if (
                model.provider_id
                and old_global_model_id
                and old_global_model_id != model.global_model_id
            ):
                get_cache_invalidation_service().on_model_changed(
                    model.provider_id, old_global_model_id
                )

            # Invalidate the cached /v1/models listing.
            ModelService._spawn_background(invalidate_models_list_cache())

            logger.info(
                f"更新模型成功: id={model_id}, 最终 supports_vision: {model.supports_vision}, "
                f"supports_function_calling: {model.supports_function_calling}, "
                f"supports_extended_thinking: {model.supports_extended_thinking}"
            )
            return model
        except IntegrityError as e:
            db.rollback()
            logger.error(f"更新模型失败: {str(e)}")
            raise InvalidRequestException("更新模型失败,请检查输入数据") from e

    @staticmethod
    def delete_model(db: Session, model_id: str):  # UUID
        """Delete a model and invalidate related caches.

        Deleting a Model does not touch its GlobalModel. If no other active
        Model references the same GlobalModel, a warning is logged but the
        deletion still proceeds.

        Raises:
            NotFoundException: If no model has the given ID.
            InvalidRequestException: If the database delete/commit fails.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")

        # Check whether this is the last active implementation of the GlobalModel.
        if model.global_model_id:
            other_implementations = (
                db.query(Model)
                .filter(
                    Model.global_model_id == model.global_model_id,
                    Model.id != model_id,
                    Model.is_active == True,  # noqa: E712 - SQL expression, not a Python comparison
                )
                .count()
            )

            if other_implementations == 0:
                logger.warning(f"警告:删除模型 {model_id}(Provider: {model.provider_id[:8]}...)后,"
                               f"GlobalModel '{model.global_model_id}' 将没有任何活跃的关联提供商")

        # Snapshot what cache invalidation needs; the instance is not usable
        # after it has been deleted.
        cache_info = {
            "model_id": model.id,
            "provider_id": model.provider_id,
            "global_model_id": model.global_model_id,
            "provider_model_name": model.provider_model_name,
            "provider_model_mappings": model.provider_model_mappings,
        }

        # Only the database work is guarded: a failure while scheduling cache
        # invalidation after a successful commit must not be reported as a
        # failed deletion.
        try:
            db.delete(model)
            db.commit()
        except Exception as e:
            db.rollback()
            logger.error(f"删除模型失败: {str(e)}")
            raise InvalidRequestException("删除模型失败") from e

        # Invalidate the Redis cache.
        ModelService._spawn_background(
            ModelCacheService.invalidate_model_cache(
                model_id=cache_info["model_id"],
                provider_id=cache_info["provider_id"],
                global_model_id=cache_info["global_model_id"],
                provider_model_name=cache_info["provider_model_name"],
                provider_model_mappings=cache_info["provider_model_mappings"],
            )
        )

        # Invalidate the in-memory cache.
        if cache_info["provider_id"] and cache_info["global_model_id"]:
            cache_service = get_cache_invalidation_service()
            cache_service.on_model_changed(cache_info["provider_id"], cache_info["global_model_id"])

        # Invalidate the cached /v1/models listing.
        ModelService._spawn_background(invalidate_models_list_cache())

        logger.info(f"删除模型成功: id={model_id}, provider_model_name={cache_info['provider_model_name']}, "
                    f"global_model_id={cache_info['global_model_id'][:8] if cache_info['global_model_id'] else 'None'}...")

    @staticmethod
    def toggle_model_availability(db: Session, model_id: str, is_available: bool) -> Model:  # UUID
        """Set a model's ``is_available`` flag and invalidate related caches.

        Raises:
            NotFoundException: If no model has the given ID.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")

        model.is_available = is_available
        db.commit()
        db.refresh(model)

        # Invalidate the Redis cache.
        ModelService._spawn_background(
            ModelCacheService.invalidate_model_cache(
                model_id=model.id,
                provider_id=model.provider_id,
                global_model_id=model.global_model_id,
                provider_model_name=model.provider_model_name,
                provider_model_mappings=model.provider_model_mappings,
            )
        )

        # Invalidate the in-memory cache (ModelMapperMiddleware instances).
        if model.provider_id and model.global_model_id:
            cache_service = get_cache_invalidation_service()
            cache_service.on_model_changed(model.provider_id, model.global_model_id)

        # Invalidate the cached /v1/models listing.
        ModelService._spawn_background(invalidate_models_list_cache())

        status = "可用" if is_available else "不可用"
        logger.info(f"更新模型可用状态: id={model_id}, status={status}")
        return model

    @staticmethod
    def get_model_by_name(db: Session, provider_id: str, model_name: str) -> Optional[Model]:
        """Return the provider's model whose ``provider_model_name`` matches, or ``None``."""
        return (
            db.query(Model)
            .filter(and_(Model.provider_id == provider_id, Model.provider_model_name == model_name))
            .first()
        )

    @staticmethod
    def batch_create_models(
        db: Session, provider_id: str, models_data: List[ModelCreate]
    ) -> List[Model]:  # UUID
        """Create several models for one provider in a single commit.

        Entries whose ``provider_model_name`` already exists for the provider,
        or that repeat a name seen earlier in the same batch, are skipped with
        a warning.

        Args:
            db: Active database session.
            provider_id: Provider that owns the new models.
            models_data: Payloads describing the models to create.

        Returns:
            The models that were actually created (possibly empty).

        Raises:
            NotFoundException: If the provider does not exist.
            InvalidRequestException: If the commit violates a database
                constraint; the whole batch is rolled back.
        """
        # The provider must exist.
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            raise NotFoundException(f"提供商 {provider_id} 不存在")

        created_models = []
        # Names already queued in this batch, so in-batch duplicates are
        # skipped regardless of whether the session autoflushes.
        seen_names = set()
        for model_data in models_data:
            # Skip names that already exist in the database or in this batch.
            existing = model_data.provider_model_name in seen_names or (
                db.query(Model)
                .filter(
                    and_(
                        Model.provider_id == provider_id,
                        Model.provider_model_name == model_data.provider_model_name,
                    )
                )
                .first()
            )

            if existing:
                logger.warning(f"模型 {model_data.provider_model_name} 已存在,跳过创建")
                continue

            # Field set and is_active default mirror create_model.
            model = Model(
                provider_id=provider_id,
                global_model_id=model_data.global_model_id,
                provider_model_name=model_data.provider_model_name,
                provider_model_mappings=model_data.provider_model_mappings,
                price_per_request=model_data.price_per_request,
                tiered_pricing=model_data.tiered_pricing,
                supports_vision=model_data.supports_vision,
                supports_function_calling=model_data.supports_function_calling,
                supports_streaming=model_data.supports_streaming,
                supports_extended_thinking=model_data.supports_extended_thinking,
                is_active=model_data.is_active if model_data.is_active is not None else True,
                config=model_data.config,
            )
            db.add(model)
            created_models.append(model)
            seen_names.add(model_data.provider_model_name)

        if created_models:
            try:
                db.commit()
                for model in created_models:
                    db.refresh(model)
                logger.info(f"批量创建 {len(created_models)} 个模型成功")

                # Invalidate the cached /v1/models listing.
                ModelService._spawn_background(invalidate_models_list_cache())
            except IntegrityError as e:
                db.rollback()
                logger.error(f"批量创建模型失败: {str(e)}")
                raise InvalidRequestException("批量创建模型失败") from e

        return created_models

    @staticmethod
    def convert_to_response(model: Model) -> ModelResponse:
        """Build the API response object for a model.

        The response carries the model's raw configuration values, the
        "effective" values returned by the model's ``get_effective_*`` methods,
        and the name/display name of the linked GlobalModel when one is set.
        """
        return ModelResponse(
            id=model.id,
            provider_id=model.provider_id,
            global_model_id=model.global_model_id,
            provider_model_name=model.provider_model_name,
            provider_model_mappings=model.provider_model_mappings,
            # Raw configured values (may be empty).
            price_per_request=model.price_per_request,
            tiered_pricing=model.tiered_pricing,
            supports_vision=model.supports_vision,
            supports_function_calling=model.supports_function_calling,
            supports_streaming=model.supports_streaming,
            supports_extended_thinking=model.supports_extended_thinking,
            supports_image_generation=model.supports_image_generation,
            # Effective values (Model settings merged with GlobalModel defaults).
            effective_tiered_pricing=model.get_effective_tiered_pricing(),
            effective_input_price=model.get_effective_input_price(),
            effective_output_price=model.get_effective_output_price(),
            effective_price_per_request=model.get_effective_price_per_request(),
            effective_supports_vision=model.get_effective_supports_vision(),
            effective_supports_function_calling=model.get_effective_supports_function_calling(),
            effective_supports_streaming=model.get_effective_supports_streaming(),
            effective_supports_extended_thinking=model.get_effective_supports_extended_thinking(),
            effective_supports_image_generation=model.get_effective_supports_image_generation(),
            is_active=model.is_active,
            # A missing availability flag is reported as available.
            is_available=model.is_available if model.is_available is not None else True,
            created_at=model.created_at,
            updated_at=model.updated_at,
            # GlobalModel information (when linked).
            global_model_name=model.global_model.name if model.global_model else None,
            global_model_display_name=(
                model.global_model.display_name if model.global_model else None
            ),
        )
|