diff --git a/src/api/base/pagination.py b/src/api/base/pagination.py
index b5df15a..1e300d4 100644
--- a/src/api/base/pagination.py
+++ b/src/api/base/pagination.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import asdict, dataclass
-from typing import List, Sequence, Tuple, TypeVar
+from typing import Any, List, Sequence, Tuple, TypeVar

 from sqlalchemy.orm import Query

@@ -40,10 +40,10 @@ def paginate_sequence(
     return sliced, meta


-def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict:
+def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra: Any) -> dict:
     """
     Build the standard paginated response payload.
     """
-    payload = {"items": items, "meta": meta.to_dict()}
+    payload: dict = {"items": items, "meta": meta.to_dict()}
     payload.update(extra)
     return payload
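For context, `build_pagination_payload` is typically called right after `paginate_sequence`. A minimal usage sketch; the `paginate_sequence` keyword arguments and the `PaginationMeta` internals are not shown in this diff, so treat them as assumptions:

```python
# Hypothetical usage sketch; paginate_sequence's exact signature is assumed.
from src.api.base.pagination import build_pagination_payload, paginate_sequence

rows = [{"id": i} for i in range(57)]
sliced, meta = paginate_sequence(rows, page=2, page_size=20)  # assumed kwargs
payload = build_pagination_payload(sliced, meta, request_id="req-123")
# payload == {"items": [...], "meta": {...}, "request_id": "req-123"}
```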
diff --git a/src/api/handlers/base/chat_handler_base.py b/src/api/handlers/base/chat_handler_base.py
index e599ac9..c48568c 100644
--- a/src/api/handlers/base/chat_handler_base.py
+++ b/src/api/handlers/base/chat_handler_base.py
@@ -263,7 +263,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         mapping = await mapper.get_mapping(source_model, provider_id)

         if mapping and mapping.model:
-            mapped_name = str(mapping.model.provider_model_name)
+            # Use select_provider_model_name to support the alias feature.
+            # Pass api_key.id as the affinity_key so the same user keeps selecting the same alias.
+            affinity_key = self.api_key.id if self.api_key else None
+            mapped_name = mapping.model.select_provider_model_name(affinity_key)
             logger.debug(f"[Chat] Model mapping: {source_model} -> {mapped_name}")
             return mapped_name

diff --git a/src/api/handlers/base/cli_handler_base.py b/src/api/handlers/base/cli_handler_base.py
index 49099da..04a0326 100644
--- a/src/api/handlers/base/cli_handler_base.py
+++ b/src/api/handlers/base/cli_handler_base.py
@@ -190,14 +190,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
         """
         Get the actual model name after model mapping.

-        Lookup order: mapping → alias → direct GlobalModel match
+        Lookup logic:
+        1. Match directly on GlobalModel.name
+        2. Find this Provider's Model implementation
+        3. Use provider_model_name / provider_model_aliases to select the final name

         Args:
-            source_model: model name requested by the user (may be an alias)
+            source_model: model name requested by the user (must be a GlobalModel.name)
             provider_id: Provider ID

         Returns:
-            the mapped provider_model_name, or None if no mapping was found
+            the mapped Provider model name, or None if no mapping was found
         """
         from src.services.model.mapper import ModelMapperMiddleware

@@ -207,7 +210,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
         logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")

         if mapping and mapping.model:
-            mapped_name = str(mapping.model.provider_model_name)
+            # Use select_provider_model_name to support the alias feature.
+            # Pass api_key.id as the affinity_key so the same user keeps selecting the same alias.
+            affinity_key = self.api_key.id if self.api_key else None
+            mapped_name = mapping.model.select_provider_model_name(affinity_key)
             logger.debug(f"[CLI] Model mapping: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
             return mapped_name
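`select_provider_model_name` itself is not part of this diff. A minimal sketch of what an affinity-keyed selection could look like, assuming aliases are stored as `[{'name': ..., 'priority': ...}]` (the format the `provider_model_aliases` schema documents later in this diff) and that a missing affinity key falls back to the primary name:

```python
# Hypothetical sketch, not the project's implementation: stable alias selection
# keyed on an affinity value (e.g. api_key.id), so one user always gets the
# same provider-side name while different users spread across the aliases.
import hashlib
from typing import Any, Dict, List, Optional


def select_provider_model_name(
    primary_name: str,
    aliases: Optional[List[Dict[str, Any]]],
    affinity_key: Optional[str],
) -> str:
    candidates = [primary_name] + [a["name"] for a in (aliases or [])]
    if affinity_key is None or len(candidates) == 1:
        return primary_name
    # Stable hash: the same affinity_key always maps to the same candidate.
    digest = hashlib.sha256(affinity_key.encode("utf-8")).digest()
    return candidates[int.from_bytes(digest[:4], "big") % len(candidates)]
```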
diff --git a/src/api/handlers/base/parsers.py b/src/api/handlers/base/parsers.py
index 0ec7636..d268478 100644
--- a/src/api/handlers/base/parsers.py
+++ b/src/api/handlers/base/parsers.py
@@ -5,7 +5,7 @@ No longer routed through the Protocol abstraction layer.
 """

-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Type

 from src.api.handlers.base.response_parser import (
     ParsedChunk,
@@ -60,7 +60,7 @@ def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[s
 class OpenAIResponseParser(ResponseParser):
     """OpenAI format response parser"""

-    def __init__(self):
+    def __init__(self) -> None:
         from src.api.handlers.openai.stream_parser import OpenAIStreamParser

         self._parser = OpenAIStreamParser()
@@ -146,7 +146,7 @@ class OpenAIResponseParser(ResponseParser):
         if choices:
             message = choices[0].get("message", {})
             content = message.get("content")
-            if content:
+            if isinstance(content, str):
                 return content
         return ""

@@ -158,7 +158,7 @@ class OpenAICliResponseParser(OpenAIResponseParser):
     """OpenAI CLI format response parser"""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.name = "OPENAI_CLI"
         self.api_format = "OPENAI_CLI"

@@ -167,7 +167,7 @@ class ClaudeResponseParser(ResponseParser):
     """Claude format response parser"""

-    def __init__(self):
+    def __init__(self) -> None:
         from src.api.handlers.claude.stream_parser import ClaudeStreamParser

         self._parser = ClaudeStreamParser()
@@ -291,7 +291,7 @@ class ClaudeCliResponseParser(ClaudeResponseParser):
     """Claude CLI format response parser"""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.name = "CLAUDE_CLI"
         self.api_format = "CLAUDE_CLI"

@@ -300,7 +300,7 @@ class GeminiResponseParser(ResponseParser):
     """Gemini format response parser"""

-    def __init__(self):
+    def __init__(self) -> None:
         from src.api.handlers.gemini.stream_parser import GeminiStreamParser

         self._parser = GeminiStreamParser()
@@ -443,20 +443,20 @@
         Uses enhanced error-detection logic; supports errors nested inside chunks
         """
-        return self._parser.is_error_event(response)
+        return bool(self._parser.is_error_event(response))


 class GeminiCliResponseParser(GeminiResponseParser):
     """Gemini CLI format response parser"""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.name = "GEMINI_CLI"
         self.api_format = "GEMINI_CLI"


 # Parser registry
-_PARSERS = {
+_PARSERS: Dict[str, Type[ResponseParser]] = {
     "CLAUDE": ClaudeResponseParser,
     "CLAUDE_CLI": ClaudeCliResponseParser,
     "OPENAI": OpenAIResponseParser,
@@ -498,6 +498,5 @@ __all__ = [
     "GeminiResponseParser",
     "GeminiCliResponseParser",
     "get_parser_for_format",
-    "get_parser_from_protocol",
     "is_cli_format",
 ]
diff --git a/src/api/handlers/claude/stream_parser.py b/src/api/handlers/claude/stream_parser.py
index 6fe507a..88ac955 100644
--- a/src/api/handlers/claude/stream_parser.py
+++ b/src/api/handlers/claude/stream_parser.py
@@ -108,7 +108,10 @@ class ClaudeStreamParser:
             return None

         try:
-            return json.loads(line)
+            result = json.loads(line)
+            if isinstance(result, dict):
+                return result
+            return None
         except json.JSONDecodeError:
             return None

@@ -147,7 +150,8 @@ class ClaudeStreamParser:
         Returns:
             The event type string
         """
-        return event.get("type")
+        event_type = event.get("type")
+        return str(event_type) if event_type is not None else None

     def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
         """
@@ -164,7 +168,8 @@ class ClaudeStreamParser:
         delta = event.get("delta", {})

         if delta.get("type") == self.DELTA_TEXT:
-            return delta.get("text")
+            text = delta.get("text")
+            return str(text) if text is not None else None

         return None

@@ -219,7 +224,8 @@ class ClaudeStreamParser:
             return None

         message = event.get("message", {})
-        return message.get("id")
+        msg_id = message.get("id")
+        return str(msg_id) if msg_id is not None else None

     def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
         """
@@ -235,7 +241,8 @@ class ClaudeStreamParser:
             return None

         delta = event.get("delta", {})
-        return delta.get("stop_reason")
+        reason = delta.get("stop_reason")
+        return str(reason) if reason is not None else None


 __all__ = ["ClaudeStreamParser"]
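The `json.loads` guard added above reappears verbatim in the Gemini and OpenAI stream parsers below. As a side note, the shared pattern could be captured in one helper; this is a hypothetical refactor sketch, not code from the repository:

```python
# Hypothetical helper (not in the codebase): json.loads may legally return any
# JSON type (str, int, list, None), so callers annotated as returning
# Optional[Dict[str, Any]] must discard non-dict results explicitly.
import json
from typing import Any, Dict, Optional


def loads_dict(line: str) -> Optional[Dict[str, Any]]:
    try:
        result = json.loads(line)
    except json.JSONDecodeError:
        return None
    return result if isinstance(result, dict) else None
```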
"RECITATION" FINISH_REASON_OTHER = "OTHER" - def __init__(self): + def __init__(self) -> None: self._buffer = "" self._in_array = False self._brace_depth = 0 - def reset(self): + def reset(self) -> None: """重置解析器状态""" self._buffer = "" self._in_array = False self._brace_depth = 0 - def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]: + def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]: """ 解析流式数据块 @@ -111,7 +111,10 @@ class GeminiStreamParser: return None try: - return json.loads(line.strip().rstrip(",")) + result = json.loads(line.strip().rstrip(",")) + if isinstance(result, dict): + return result + return None except json.JSONDecodeError: return None @@ -216,7 +219,8 @@ class GeminiStreamParser: """ candidates = event.get("candidates", []) if candidates: - return candidates[0].get("finishReason") + reason = candidates[0].get("finishReason") + return str(reason) if reason is not None else None return None def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]: @@ -285,7 +289,8 @@ class GeminiStreamParser: Returns: 模型版本,如果没有返回 None """ - return event.get("modelVersion") + version = event.get("modelVersion") + return str(version) if version is not None else None def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]: """ @@ -301,7 +306,10 @@ class GeminiStreamParser: if not candidates: return None - return candidates[0].get("safetyRatings") + ratings = candidates[0].get("safetyRatings") + if isinstance(ratings, list): + return ratings + return None __all__ = ["GeminiStreamParser"] diff --git a/src/api/handlers/openai/stream_parser.py b/src/api/handlers/openai/stream_parser.py index 2df7f71..cc639b0 100644 --- a/src/api/handlers/openai/stream_parser.py +++ b/src/api/handlers/openai/stream_parser.py @@ -78,7 +78,10 @@ class OpenAIStreamParser: return None try: - return json.loads(line) + result = json.loads(line) + if isinstance(result, dict): + return result + return None except json.JSONDecodeError: return None @@ -116,7 +119,8 @@ class OpenAIStreamParser: """ choices = chunk.get("choices", []) if choices: - return choices[0].get("finish_reason") + reason = choices[0].get("finish_reason") + return str(reason) if reason is not None else None return None def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]: @@ -156,7 +160,10 @@ class OpenAIStreamParser: return None delta = choices[0].get("delta", {}) - return delta.get("tool_calls") + tool_calls = delta.get("tool_calls") + if isinstance(tool_calls, list): + return tool_calls + return None def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]: """ @@ -175,7 +182,8 @@ class OpenAIStreamParser: return None delta = choices[0].get("delta", {}) - return delta.get("role") + role = delta.get("role") + return str(role) if role is not None else None __all__ = ["OpenAIStreamParser"] diff --git a/src/api/public/capabilities.py b/src/api/public/capabilities.py index 6aae312..72fd44e 100644 --- a/src/api/public/capabilities.py +++ b/src/api/public/capabilities.py @@ -61,15 +61,18 @@ async def get_model_supported_capabilities( 获取指定模型支持的能力列表 Args: - model_name: 模型名称(如 claude-sonnet-4-20250514) + model_name: 模型名称(如 claude-sonnet-4-20250514,必须是 GlobalModel.name) Returns: 模型支持的能力列表,以及每个能力的详细定义 """ - from src.services.model.mapping_resolver import get_model_mapping_resolver + from src.models.database import GlobalModel - mapping_resolver = get_model_mapping_resolver() - global_model = await mapping_resolver.get_global_model_by_request(db, model_name, 
diff --git a/src/api/handlers/openai/stream_parser.py b/src/api/handlers/openai/stream_parser.py
index 2df7f71..cc639b0 100644
--- a/src/api/handlers/openai/stream_parser.py
+++ b/src/api/handlers/openai/stream_parser.py
@@ -78,7 +78,10 @@ class OpenAIStreamParser:
             return None

         try:
-            return json.loads(line)
+            result = json.loads(line)
+            if isinstance(result, dict):
+                return result
+            return None
         except json.JSONDecodeError:
             return None

@@ -116,7 +119,8 @@ class OpenAIStreamParser:
         """
         choices = chunk.get("choices", [])
         if choices:
-            return choices[0].get("finish_reason")
+            reason = choices[0].get("finish_reason")
+            return str(reason) if reason is not None else None
         return None

     def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
@@ -156,7 +160,10 @@ class OpenAIStreamParser:
             return None

         delta = choices[0].get("delta", {})
-        return delta.get("tool_calls")
+        tool_calls = delta.get("tool_calls")
+        if isinstance(tool_calls, list):
+            return tool_calls
+        return None

     def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
         """
@@ -175,7 +182,8 @@ class OpenAIStreamParser:
             return None

         delta = choices[0].get("delta", {})
-        return delta.get("role")
+        role = delta.get("role")
+        return str(role) if role is not None else None


 __all__ = ["OpenAIStreamParser"]
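For contrast with the Gemini shape above, the OpenAI extractors read `choices[0]`. A small illustration against the helpers in this file (`extract_text_delta`'s body is not shown in the diff, so the `delta.content` expectation is an assumption):

```python
from src.api.handlers.openai.stream_parser import OpenAIStreamParser

parser = OpenAIStreamParser()
chunk = {
    "choices": [
        {"delta": {"role": "assistant", "content": "Hello"}, "finish_reason": None}
    ]
}
print(parser.extract_role(chunk))        # "assistant"
print(parser.extract_text_delta(chunk))  # "Hello" (assumed to read delta.content)

final = {"choices": [{"delta": {}, "finish_reason": "stop"}]}
print(parser.extract_finish_reason(final))  # "stop"
```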
diff --git a/src/api/public/capabilities.py b/src/api/public/capabilities.py
index 6aae312..72fd44e 100644
--- a/src/api/public/capabilities.py
+++ b/src/api/public/capabilities.py
@@ -61,15 +61,18 @@ async def get_model_supported_capabilities(
     Get the list of capabilities supported by the given model.

     Args:
-        model_name: model name (e.g. claude-sonnet-4-20250514)
+        model_name: model name (e.g. claude-sonnet-4-20250514; must be a GlobalModel.name)

     Returns:
         The capabilities the model supports, plus the detailed definition of each capability
     """
-    from src.services.model.mapping_resolver import get_model_mapping_resolver
+    from src.models.database import GlobalModel

-    mapping_resolver = get_model_mapping_resolver()
-    global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
+    global_model = (
+        db.query(GlobalModel)
+        .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
+        .first()
+    )

     if not global_model:
         return {
diff --git a/src/api/user_me/routes.py b/src/api/user_me/routes.py
index 4458468..31f4ba6 100644
--- a/src/api/user_me/routes.py
+++ b/src/api/user_me/routes.py
@@ -7,7 +7,7 @@ from typing import Optional
 from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from pydantic import ValidationError
 from sqlalchemy import and_, func
-from sqlalchemy.orm import Session, joinedload
+from sqlalchemy.orm import Session

 from src.api.base.adapter import ApiAdapter, ApiMode
 from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
@@ -713,7 +713,7 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
     async def handle(self, context):  # type: ignore[override]
         from sqlalchemy.orm import selectinload

-        from src.models.database import Model, ModelMapping, ProviderEndpoint
+        from src.models.database import Model, ProviderEndpoint

         db = context.db

@@ -765,53 +765,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
                 }
             )

-        # Query the aliases/mappings of the GlobalModels behind all of this Provider's Models
-        provider_model_global_ids = {
-            m.global_model_id for m in provider.models if m.global_model_id
-        }
-        if provider_model_global_ids:
-            # Query global aliases + Provider-specific mappings
-            alias_mappings = (
-                db.query(ModelMapping)
-                .options(joinedload(ModelMapping.target_global_model))
-                .filter(
-                    ModelMapping.target_global_model_id.in_(provider_model_global_ids),
-                    ModelMapping.is_active == True,
-                    (ModelMapping.provider_id == provider.id)
-                    | (ModelMapping.provider_id == None),
-                )
-                .all()
-            )
-            for alias_obj in alias_mappings:
-                # Find this Provider's Model implementation for this alias
-                model = next(
-                    (
-                        m
-                        for m in provider.models
-                        if m.global_model_id == alias_obj.target_global_model_id
-                    ),
-                    None,
-                )
-                if model:
-                    models_data.append(
-                        {
-                            "id": alias_obj.id,
-                            "name": alias_obj.source_model,
-                            "display_name": (
-                                alias_obj.target_global_model.display_name
-                                if alias_obj.target_global_model
-                                else alias_obj.source_model
-                            ),
-                            "input_price_per_1m": model.input_price_per_1m,
-                            "output_price_per_1m": model.output_price_per_1m,
-                            "cache_creation_price_per_1m": model.cache_creation_price_per_1m,
-                            "cache_read_price_per_1m": model.cache_read_price_per_1m,
-                            "supports_vision": model.supports_vision,
-                            "supports_function_calling": model.supports_function_calling,
-                            "supports_streaming": model.supports_streaming,
-                        }
-                    )
-
         result.append(
             {
                 "id": provider.id,
diff --git a/src/config/constants.py b/src/config/constants.py
index f8034df..7b8eb1e 100644
--- a/src/config/constants.py
+++ b/src/config/constants.py
@@ -14,7 +14,6 @@ class CacheTTL:
     # Provider/Model cache - configuration changes are infrequent
     PROVIDER = 300  # 5 minutes
     MODEL = 300  # 5 minutes
-    MODEL_MAPPING = 300  # 5 minutes

     # Cache affinity - matches the default of provider_api_key.cache_ttl_minutes
     CACHE_AFFINITY = 300  # 5 minutes
@@ -33,9 +32,6 @@ class CacheSize:
     # Default LRU cache size
     DEFAULT = 1000

-    # ModelMapping cache (may hold many aliases)
-    MODEL_MAPPING = 2000
-

 # ==============================================================================
 # Concurrency and rate-limiting constants
diff --git a/src/core/cache_utils.py b/src/core/cache_utils.py
index 879fa79..8a2e389 100644
--- a/src/core/cache_utils.py
+++ b/src/core/cache_utils.py
@@ -115,7 +115,7 @@ class SyncLRUCache:
         """Delete a cached value (by index)"""
         self.delete(key)

-    def keys(self):
+    def keys(self) -> list:
         """Return all unexpired keys"""
         with self._lock:
             now = time.time()
diff --git a/src/core/logger.py b/src/core/logger.py
index d6cf5d5..ca224d9 100644
--- a/src/core/logger.py
+++ b/src/core/logger.py
@@ -67,7 +67,7 @@ FILE_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:
 logger.remove()


-def _log_filter(record):
+def _log_filter(record: dict) -> bool:  # type: ignore[type-arg]
     return "watchfiles" not in record["name"]


@@ -76,7 +76,7 @@ if IS_DOCKER:
         sys.stdout,
         format=CONSOLE_FORMAT_PROD,
         level=LOG_LEVEL,
-        filter=_log_filter,
+        filter=_log_filter,  # type: ignore[arg-type]
         colorize=False,
     )
 else:
@@ -84,7 +84,7 @@ else:
         sys.stdout,
         format=CONSOLE_FORMAT_DEV,
         level=LOG_LEVEL,
-        filter=_log_filter,
+        filter=_log_filter,  # type: ignore[arg-type]
         colorize=True,
     )

@@ -97,7 +97,7 @@ if not DISABLE_FILE_LOG:
         log_dir / "app.log",
         format=FILE_FORMAT,
         level="DEBUG",
-        filter=_log_filter,
+        filter=_log_filter,  # type: ignore[arg-type]
         rotation="00:00",
         retention="30 days",
         compression="gz",
@@ -110,7 +110,7 @@ if not DISABLE_FILE_LOG:
         log_dir / "error.log",
         format=FILE_FORMAT,
         level="ERROR",
-        filter=_log_filter,
+        filter=_log_filter,  # type: ignore[arg-type]
         rotation="00:00",
         retention="30 days",
         compression="gz",
diff --git a/src/core/provider_health.py b/src/core/provider_health.py
index 942d7f4..149a1e1 100644
--- a/src/core/provider_health.py
+++ b/src/core/provider_health.py
@@ -6,7 +6,7 @@
 import time
 from collections import defaultdict
 from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Any, Dict, Optional


 class ProviderHealthTracker:
@@ -32,7 +32,7 @@ class ProviderHealthTracker:
         # Stores priority adjustments
         self.priority_adjustments: Dict[str, int] = {}

-    def record_success(self, provider_name: str):
+    def record_success(self, provider_name: str) -> None:
         """Record a successful request"""
         current_time = time.time()

@@ -47,7 +47,7 @@ class ProviderHealthTracker:
         if self.priority_adjustments.get(provider_name, 0) < 0:
             self.priority_adjustments[provider_name] += 1

-    def record_failure(self, provider_name: str):
+    def record_failure(self, provider_name: str) -> None:
         """Record a failed request"""
         current_time = time.time()

@@ -93,7 +93,7 @@ class ProviderHealthTracker:
             "status": self._get_status_label(failure_rate, recent_failures),
         }

-    def _cleanup_old_records(self, provider_name: str, current_time: float):
+    def _cleanup_old_records(self, provider_name: str, current_time: float) -> None:
         """Clean up records outside the time window"""
         # Clean up failure records
         self.failures[provider_name] = [
@@ -130,7 +130,7 @@ class ProviderHealthTracker:
         adjustment = self.get_priority_adjustment(provider_name)
         return adjustment > -3

-    def reset_provider_health(self, provider_name: str):
+    def reset_provider_health(self, provider_name: str) -> None:
         """Reset a provider's health state (manual admin operation)"""
         self.failures[provider_name] = []
         self.successes[provider_name] = []
@@ -146,7 +146,7 @@ class SimpleProviderSelector:
     def __init__(self, health_tracker: ProviderHealthTracker):
         self.health_tracker = health_tracker

-    def select_provider(self, providers: list, specified_provider: Optional[str] = None):
+    def select_provider(self, providers: list, specified_provider: Optional[str] = None) -> Any:
         """
         Select a provider
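A short usage sketch of the tracker annotated above. The diff shows that a success claws a negative adjustment back toward zero and that the health check uses `adjustment > -3`; how quickly failures push the adjustment down is not shown, so treat the printed value as illustrative:

```python
from src.core.provider_health import ProviderHealthTracker

tracker = ProviderHealthTracker()
tracker.record_failure("openai-main")
tracker.record_failure("openai-main")
tracker.record_success("openai-main")  # moves a negative adjustment +1 toward 0

# Negative adjustments demote a provider; per this diff, anything above -3
# still counts as healthy.
print(tracker.get_priority_adjustment("openai-main"))
tracker.reset_provider_health("openai-main")  # admin escape hatch: clears records
```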
diff --git a/src/database/async_utils.py b/src/database/async_utils.py
index a4ee86d..5afbcf8 100644
--- a/src/database/async_utils.py
+++ b/src/database/async_utils.py
@@ -5,12 +5,12 @@

 import asyncio
 from functools import wraps
-from typing import Any, Callable, TypeVar
+from typing import Any, Callable, Coroutine, TypeVar

 T = TypeVar("T")


-async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
+async def run_in_executor(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
     """
     Run a synchronous function in the thread pool to avoid blocking the event loop

@@ -21,7 +21,7 @@ async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
     return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


-def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
+def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
     """
     Decorator: wrap a synchronous database function as an async one

@@ -35,7 +35,7 @@ def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
     """

     @wraps(func)
-    async def wrapper(*args, **kwargs):
+    async def wrapper(*args: Any, **kwargs: Any) -> T:
         return await run_in_executor(func, *args, **kwargs)

     return wrapper
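The tightened return annotation (`Coroutine[Any, Any, T]` instead of `Any`) means type checkers now see through the decorator. A minimal, self-contained sketch; the `add` function is illustrative, standing in for a blocking database helper:

```python
import asyncio

from src.database.async_utils import async_wrap_sync_db


@async_wrap_sync_db
def add(a: int, b: int) -> int:  # stand-in for a blocking database helper
    return a + b


async def main() -> None:
    # With Callable[..., Coroutine[Any, Any, T]], mypy infers `total: int`;
    # the previous Callable[..., Any] erased the awaited result to Any.
    total = await add(1, 2)
    assert total == 3


asyncio.run(main())
```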
diff --git a/src/database/database.py b/src/database/database.py
index 5cfbf01..7397d46 100644
--- a/src/database/database.py
+++ b/src/database/database.py
@@ -347,7 +347,7 @@ def init_default_models(db: Session):
     """Initialize default model configuration"""

     # Note: as a relay proxy service, model configurations are no longer preset
-    # Model configuration should be managed dynamically via the Model and ModelMapping tables
+    # Model configuration should be managed dynamically via the GlobalModel and Model tables
     # This function is kept for possible future default-model initialization
     pass

diff --git a/src/middleware/__init__.py b/src/middleware/__init__.py
index 9662090..a2f63f5 100644
--- a/src/middleware/__init__.py
+++ b/src/middleware/__init__.py
@@ -2,4 +2,4 @@
 Middleware module
 """

-__all__ = []
+__all__: list[str] = []
diff --git a/src/models/api.py b/src/models/api.py
index 0b0e9ed..41dd9c2 100644
--- a/src/models/api.py
+++ b/src/models/api.py
@@ -334,7 +334,6 @@ class ProviderResponse(BaseModel):
     updated_at: datetime
     models_count: int = 0
     active_models_count: int = 0
-    model_mappings_count: int = 0
     api_keys_count: int = 0

     class Config:
@@ -346,7 +345,11 @@ class ModelCreate(BaseModel):
    """Create-model request - price and capability fields are optional; GlobalModel defaults apply when empty"""

     provider_model_name: str = Field(
-        ..., min_length=1, max_length=200, description="Model name on the Provider side"
+        ..., min_length=1, max_length=200, description="Primary model name on the Provider side"
     )
+    provider_model_aliases: Optional[List[dict]] = Field(
+        None,
+        description="List of model-name aliases, format: [{'name': 'alias1', 'priority': 1}, ...]",
+    )
     global_model_id: str = Field(..., description="Associated GlobalModel ID (required)")
     # Per-request billing config - optional; GlobalModel defaults apply when empty
@@ -374,6 +377,10 @@ class ModelUpdate(BaseModel):
     """Update-model request"""

     provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
+    provider_model_aliases: Optional[List[dict]] = Field(
+        None,
+        description="List of model-name aliases, format: [{'name': 'alias1', 'priority': 1}, ...]",
+    )
     global_model_id: Optional[str] = None
     # Per-request billing config
     price_per_request: Optional[float] = Field(None, ge=0, description="Fixed fee per request")
@@ -398,6 +405,7 @@ class ModelResponse(BaseModel):
     provider_id: str
     global_model_id: Optional[str]
     provider_model_name: str
+    provider_model_aliases: Optional[List[dict]] = None

     # Per-request billing config
     price_per_request: Optional[float] = None
@@ -465,54 +473,6 @@ class ModelDetailResponse(BaseModel):
         from_attributes = True


-# ========== Model mappings ==========
-class ModelMappingCreate(BaseModel):
-    """Create-model-mapping request (maps a source model to a target model)"""
-
-    source_model: str = Field(..., min_length=1, max_length=200, description="Source model name or alias")
-    target_global_model_id: str = Field(..., description="Target GlobalModel ID")
-    provider_id: Optional[str] = Field(None, description="Provider ID (empty means a global alias)")
-    mapping_type: str = Field(
-        "alias",
-        description="Mapping type: alias=billed by target model (alias), mapping=billed by source model (downgrade mapping)",
-    )
-    is_active: bool = Field(True, description="Whether enabled")
-
-
-class ModelMappingUpdate(BaseModel):
-    """Update-model-mapping request"""
-
-    source_model: Optional[str] = Field(
-        None, min_length=1, max_length=200, description="Source model name or alias"
-    )
-    target_global_model_id: Optional[str] = Field(None, description="Target GlobalModel ID")
-    provider_id: Optional[str] = Field(None, description="Provider ID (empty means a global alias)")
-    mapping_type: Optional[str] = Field(
-        None, description="Mapping type: alias=billed by target model (alias), mapping=billed by source model (downgrade mapping)"
-    )
-    is_active: Optional[bool] = None
-
-
-class ModelMappingResponse(BaseModel):
-    """Model mapping response"""
-
-    id: str
-    source_model: str
-    target_global_model_id: str
-    target_global_model_name: Optional[str]
-    target_global_model_display_name: Optional[str]
-    provider_id: Optional[str]
-    provider_name: Optional[str]
-    scope: str = Field(..., description="global or provider")
-    mapping_type: str = Field(..., description="Mapping type: alias or mapping")
-    is_active: bool
-    created_at: datetime
-    updated_at: datetime
-
-    class Config:
-        from_attributes = True
-
-
 # ========== System settings ==========
 class SystemSettingsRequest(BaseModel):
     """System settings request"""
@@ -558,7 +518,6 @@ class PublicProviderResponse(BaseModel):
     # Statistics
     models_count: int
     active_models_count: int
-    mappings_count: int
     endpoints_count: int  # Total number of endpoints
     active_endpoints_count: int  # Number of active endpoints

@@ -587,19 +546,6 @@ class PublicModelResponse(BaseModel):
     is_active: bool = True


-class PublicModelMappingResponse(BaseModel):
-    """Public model mapping info response"""
-
-    id: str
-    source_model: str
-    target_global_model_id: str
-    target_global_model_name: Optional[str]
-    target_global_model_display_name: Optional[str]
-    provider_id: Optional[str] = None
-    scope: str = Field(..., description="global or provider")
-    is_active: bool
-
-
 class ProviderStatsResponse(BaseModel):
     """Provider statistics response"""

@@ -607,7 +553,6 @@ class ProviderStatsResponse(BaseModel):
     active_providers: int
     total_models: int
     active_models: int
-    total_mappings: int
     supported_formats: List[str]

diff --git a/src/utils/sse_parser.py b/src/utils/sse_parser.py
index 1bacdc5..54ae426 100644
--- a/src/utils/sse_parser.py
+++ b/src/utils/sse_parser.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union


 class SSEEventParser:
@@ -8,7 +8,7 @@ class SSEEventParser:
         self._reset_buffer()

     def _reset_buffer(self) -> None:
-        self._buffer: Dict[str, Optional[str] | List[str]] = {
+        self._buffer: Dict[str, Union[Optional[str], List[str]]] = {
             "event": None,
             "data": [],
             "id": None,
@@ -17,16 +17,19 @@ class SSEEventParser:

     def _finalize_event(self) -> Optional[Dict[str, Optional[str]]]:
         data_lines = self._buffer.get("data", [])
-        if not data_lines:
+        if not isinstance(data_lines, list) or not data_lines:
             self._reset_buffer()
             return None

         data_str = "\n".join(data_lines)
-        event = {
-            "event": self._buffer.get("event"),
+        event_val = self._buffer.get("event")
+        id_val = self._buffer.get("id")
+        retry_val = self._buffer.get("retry")
+        event: Dict[str, Optional[str]] = {
+            "event": event_val if isinstance(event_val, str) else None,
             "data": data_str,
-            "id": self._buffer.get("id"),
-            "retry": self._buffer.get("retry"),
+            "id": id_val if isinstance(id_val, str) else None,
+            "retry": retry_val if isinstance(retry_val, str) else None,
         }

         self._reset_buffer()
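For reference, here is the shape `_finalize_event` produces. The parser's public feed method is not part of this diff, so this only illustrates the buffer and output data shapes under the new isinstance guards:

```python
# SSE wire input for one event:
#   event: message_delta
#   data: {"partial":
#   data: true}
#   id: 42
#
# Buffer state just before _finalize_event (values accumulated per field):
buffer = {"event": "message_delta", "data": ['{"partial":', "true}"], "id": "42", "retry": None}

# Finalized event: multi-line data fields are newline-joined per the SSE spec,
# and any non-str field value is coerced to None by the isinstance checks.
expected = {"event": "message_delta", "data": '{"partial":\ntrue}', "id": "42", "retry": None}
```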