refactor(backend): update handlers, utilities and core modules after models restructure

This commit is contained in:
fawney19
2025-12-15 14:30:53 +08:00
parent 03ee6c16d9
commit 88e37594cf
19 changed files with 121 additions and 186 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import List, Sequence, Tuple, TypeVar from typing import Any, List, Sequence, Tuple, TypeVar
from sqlalchemy.orm import Query from sqlalchemy.orm import Query
@@ -40,10 +40,10 @@ def paginate_sequence(
return sliced, meta return sliced, meta
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict: def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra: Any) -> dict:
""" """
构建标准分页响应 payload。 构建标准分页响应 payload。
""" """
payload = {"items": items, "meta": meta.to_dict()} payload: dict = {"items": items, "meta": meta.to_dict()}
payload.update(extra) payload.update(extra)
return payload return payload

View File

@@ -263,7 +263,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
mapping = await mapper.get_mapping(source_model, provider_id) mapping = await mapper.get_mapping(source_model, provider_id)
if mapping and mapping.model: if mapping and mapping.model:
mapped_name = str(mapping.model.provider_model_name) # 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}") logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
return mapped_name return mapped_name

View File

@@ -190,14 +190,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
""" """
获取模型映射后的实际模型名 获取模型映射后的实际模型名
按优先级查找:映射 → 别名 → 直接匹配 GlobalModel 查找逻辑:
1. 直接通过 GlobalModel.name 匹配
2. 查找该 Provider 的 Model 实现
3. 使用 provider_model_name / provider_model_aliases 选择最终名称
Args: Args:
source_model: 用户请求的模型名(可能是别名) source_model: 用户请求的模型名(必须是 GlobalModel.name)
provider_id: Provider ID provider_id: Provider ID
Returns: Returns:
映射后的 provider_model_name,如果没有找到映射则返回 None 映射后的 Provider 模型名,如果没有找到映射则返回 None
""" """
from src.services.model.mapper import ModelMapperMiddleware from src.services.model.mapper import ModelMapperMiddleware
@@ -207,7 +210,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}") logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
if mapping and mapping.model: if mapping and mapping.model:
mapped_name = str(mapping.model.provider_model_name) # 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)") logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name return mapped_name

View File

@@ -5,7 +5,7 @@
不再经过 Protocol 抽象层。 不再经过 Protocol 抽象层。
""" """
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple, Type
from src.api.handlers.base.response_parser import ( from src.api.handlers.base.response_parser import (
ParsedChunk, ParsedChunk,
@@ -60,7 +60,7 @@ def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[s
class OpenAIResponseParser(ResponseParser): class OpenAIResponseParser(ResponseParser):
"""OpenAI 格式响应解析器""" """OpenAI 格式响应解析器"""
def __init__(self): def __init__(self) -> None:
from src.api.handlers.openai.stream_parser import OpenAIStreamParser from src.api.handlers.openai.stream_parser import OpenAIStreamParser
self._parser = OpenAIStreamParser() self._parser = OpenAIStreamParser()
@@ -146,7 +146,7 @@ class OpenAIResponseParser(ResponseParser):
if choices: if choices:
message = choices[0].get("message", {}) message = choices[0].get("message", {})
content = message.get("content") content = message.get("content")
if content: if isinstance(content, str):
return content return content
return "" return ""
@@ -158,7 +158,7 @@ class OpenAIResponseParser(ResponseParser):
class OpenAICliResponseParser(OpenAIResponseParser): class OpenAICliResponseParser(OpenAIResponseParser):
"""OpenAI CLI 格式响应解析器""" """OpenAI CLI 格式响应解析器"""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.name = "OPENAI_CLI" self.name = "OPENAI_CLI"
self.api_format = "OPENAI_CLI" self.api_format = "OPENAI_CLI"
@@ -167,7 +167,7 @@ class OpenAICliResponseParser(OpenAIResponseParser):
class ClaudeResponseParser(ResponseParser): class ClaudeResponseParser(ResponseParser):
"""Claude 格式响应解析器""" """Claude 格式响应解析器"""
def __init__(self): def __init__(self) -> None:
from src.api.handlers.claude.stream_parser import ClaudeStreamParser from src.api.handlers.claude.stream_parser import ClaudeStreamParser
self._parser = ClaudeStreamParser() self._parser = ClaudeStreamParser()
@@ -291,7 +291,7 @@ class ClaudeResponseParser(ResponseParser):
class ClaudeCliResponseParser(ClaudeResponseParser): class ClaudeCliResponseParser(ClaudeResponseParser):
"""Claude CLI 格式响应解析器""" """Claude CLI 格式响应解析器"""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.name = "CLAUDE_CLI" self.name = "CLAUDE_CLI"
self.api_format = "CLAUDE_CLI" self.api_format = "CLAUDE_CLI"
@@ -300,7 +300,7 @@ class ClaudeCliResponseParser(ClaudeResponseParser):
class GeminiResponseParser(ResponseParser): class GeminiResponseParser(ResponseParser):
"""Gemini 格式响应解析器""" """Gemini 格式响应解析器"""
def __init__(self): def __init__(self) -> None:
from src.api.handlers.gemini.stream_parser import GeminiStreamParser from src.api.handlers.gemini.stream_parser import GeminiStreamParser
self._parser = GeminiStreamParser() self._parser = GeminiStreamParser()
@@ -443,20 +443,20 @@ class GeminiResponseParser(ResponseParser):
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误 使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
""" """
return self._parser.is_error_event(response) return bool(self._parser.is_error_event(response))
class GeminiCliResponseParser(GeminiResponseParser): class GeminiCliResponseParser(GeminiResponseParser):
"""Gemini CLI 格式响应解析器""" """Gemini CLI 格式响应解析器"""
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self.name = "GEMINI_CLI" self.name = "GEMINI_CLI"
self.api_format = "GEMINI_CLI" self.api_format = "GEMINI_CLI"
# 解析器注册表 # 解析器注册表
_PARSERS = { _PARSERS: Dict[str, Type[ResponseParser]] = {
"CLAUDE": ClaudeResponseParser, "CLAUDE": ClaudeResponseParser,
"CLAUDE_CLI": ClaudeCliResponseParser, "CLAUDE_CLI": ClaudeCliResponseParser,
"OPENAI": OpenAIResponseParser, "OPENAI": OpenAIResponseParser,
@@ -498,6 +498,5 @@ __all__ = [
"GeminiResponseParser", "GeminiResponseParser",
"GeminiCliResponseParser", "GeminiCliResponseParser",
"get_parser_for_format", "get_parser_for_format",
"get_parser_from_protocol",
"is_cli_format", "is_cli_format",
] ]

View File

@@ -108,7 +108,10 @@ class ClaudeStreamParser:
return None return None
try: try:
return json.loads(line) result = json.loads(line)
if isinstance(result, dict):
return result
return None
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
@@ -147,7 +150,8 @@ class ClaudeStreamParser:
Returns: Returns:
事件类型字符串 事件类型字符串
""" """
return event.get("type") event_type = event.get("type")
return str(event_type) if event_type is not None else None
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]: def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
""" """
@@ -164,7 +168,8 @@ class ClaudeStreamParser:
delta = event.get("delta", {}) delta = event.get("delta", {})
if delta.get("type") == self.DELTA_TEXT: if delta.get("type") == self.DELTA_TEXT:
return delta.get("text") text = delta.get("text")
return str(text) if text is not None else None
return None return None
@@ -219,7 +224,8 @@ class ClaudeStreamParser:
return None return None
message = event.get("message", {}) message = event.get("message", {})
return message.get("id") msg_id = message.get("id")
return str(msg_id) if msg_id is not None else None
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]: def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
""" """
@@ -235,7 +241,8 @@ class ClaudeStreamParser:
return None return None
delta = event.get("delta", {}) delta = event.get("delta", {})
return delta.get("stop_reason") reason = delta.get("stop_reason")
return str(reason) if reason is not None else None
__all__ = ["ClaudeStreamParser"] __all__ = ["ClaudeStreamParser"]

View File

@@ -70,7 +70,7 @@ class ClaudeToGeminiConverter:
return [{"text": content}] return [{"text": content}]
if isinstance(content, list): if isinstance(content, list):
parts = [] parts: List[Dict[str, Any]] = []
for block in content: for block in content:
if isinstance(block, str): if isinstance(block, str):
parts.append({"text": block}) parts.append({"text": block})
@@ -249,6 +249,8 @@ class GeminiToClaudeConverter:
"RECITATION": "content_filtered", "RECITATION": "content_filtered",
"OTHER": "stop_sequence", "OTHER": "stop_sequence",
} }
if finish_reason is None:
return "end_turn"
return mapping.get(finish_reason, "end_turn") return mapping.get(finish_reason, "end_turn")
def _create_empty_response(self) -> Dict[str, Any]: def _create_empty_response(self) -> Dict[str, Any]:
@@ -365,7 +367,7 @@ class OpenAIToGeminiConverter:
return [{"text": content}] return [{"text": content}]
if isinstance(content, list): if isinstance(content, list):
parts = [] parts: List[Dict[str, Any]] = []
for item in content: for item in content:
if isinstance(item, str): if isinstance(item, str):
parts.append({"text": item}) parts.append({"text": item})
@@ -524,7 +526,7 @@ class GeminiToOpenAIConverter:
"total_tokens": prompt_tokens + completion_tokens, "total_tokens": prompt_tokens + completion_tokens,
} }
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]: def _convert_finish_reason(self, finish_reason: Optional[str]) -> str:
"""转换停止原因""" """转换停止原因"""
mapping = { mapping = {
"STOP": "stop", "STOP": "stop",
@@ -533,6 +535,8 @@ class GeminiToOpenAIConverter:
"RECITATION": "content_filter", "RECITATION": "content_filter",
"OTHER": "stop", "OTHER": "stop",
} }
if finish_reason is None:
return "stop"
return mapping.get(finish_reason, "stop") return mapping.get(finish_reason, "stop")

View File

@@ -10,7 +10,7 @@ Gemini API 的流式响应格式与 Claude/OpenAI 不同:
""" """
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
class GeminiStreamParser: class GeminiStreamParser:
@@ -32,18 +32,18 @@ class GeminiStreamParser:
FINISH_REASON_RECITATION = "RECITATION" FINISH_REASON_RECITATION = "RECITATION"
FINISH_REASON_OTHER = "OTHER" FINISH_REASON_OTHER = "OTHER"
def __init__(self): def __init__(self) -> None:
self._buffer = "" self._buffer = ""
self._in_array = False self._in_array = False
self._brace_depth = 0 self._brace_depth = 0
def reset(self): def reset(self) -> None:
"""重置解析器状态""" """重置解析器状态"""
self._buffer = "" self._buffer = ""
self._in_array = False self._in_array = False
self._brace_depth = 0 self._brace_depth = 0
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]: def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]:
""" """
解析流式数据块 解析流式数据块
@@ -111,7 +111,10 @@ class GeminiStreamParser:
return None return None
try: try:
return json.loads(line.strip().rstrip(",")) result = json.loads(line.strip().rstrip(","))
if isinstance(result, dict):
return result
return None
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
@@ -216,7 +219,8 @@ class GeminiStreamParser:
""" """
candidates = event.get("candidates", []) candidates = event.get("candidates", [])
if candidates: if candidates:
return candidates[0].get("finishReason") reason = candidates[0].get("finishReason")
return str(reason) if reason is not None else None
return None return None
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]: def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
@@ -285,7 +289,8 @@ class GeminiStreamParser:
Returns: Returns:
模型版本,如果没有返回 None 模型版本,如果没有返回 None
""" """
return event.get("modelVersion") version = event.get("modelVersion")
return str(version) if version is not None else None
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]: def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
""" """
@@ -301,7 +306,10 @@ class GeminiStreamParser:
if not candidates: if not candidates:
return None return None
return candidates[0].get("safetyRatings") ratings = candidates[0].get("safetyRatings")
if isinstance(ratings, list):
return ratings
return None
__all__ = ["GeminiStreamParser"] __all__ = ["GeminiStreamParser"]

View File

@@ -78,7 +78,10 @@ class OpenAIStreamParser:
return None return None
try: try:
return json.loads(line) result = json.loads(line)
if isinstance(result, dict):
return result
return None
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
@@ -116,7 +119,8 @@ class OpenAIStreamParser:
""" """
choices = chunk.get("choices", []) choices = chunk.get("choices", [])
if choices: if choices:
return choices[0].get("finish_reason") reason = choices[0].get("finish_reason")
return str(reason) if reason is not None else None
return None return None
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]: def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
@@ -156,7 +160,10 @@ class OpenAIStreamParser:
return None return None
delta = choices[0].get("delta", {}) delta = choices[0].get("delta", {})
return delta.get("tool_calls") tool_calls = delta.get("tool_calls")
if isinstance(tool_calls, list):
return tool_calls
return None
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]: def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
""" """
@@ -175,7 +182,8 @@ class OpenAIStreamParser:
return None return None
delta = choices[0].get("delta", {}) delta = choices[0].get("delta", {})
return delta.get("role") role = delta.get("role")
return str(role) if role is not None else None
__all__ = ["OpenAIStreamParser"] __all__ = ["OpenAIStreamParser"]

View File

@@ -61,15 +61,18 @@ async def get_model_supported_capabilities(
获取指定模型支持的能力列表 获取指定模型支持的能力列表
Args: Args:
model_name: 模型名称(如 claude-sonnet-4-20250514) model_name: 模型名称(如 claude-sonnet-4-20250514,必须是 GlobalModel.name)
Returns: Returns:
模型支持的能力列表,以及每个能力的详细定义 模型支持的能力列表,以及每个能力的详细定义
""" """
from src.services.model.mapping_resolver import get_model_mapping_resolver from src.models.database import GlobalModel
mapping_resolver = get_model_mapping_resolver() global_model = (
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None) db.query(GlobalModel)
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first()
)
if not global_model: if not global_model:
return { return {

View File

@@ -7,7 +7,7 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import and_, func from sqlalchemy import and_, func
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session
from src.api.base.adapter import ApiAdapter, ApiMode from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
@@ -713,7 +713,7 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from src.models.database import Model, ModelMapping, ProviderEndpoint from src.models.database import Model, ProviderEndpoint
db = context.db db = context.db
@@ -765,53 +765,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
} }
) )
# 查询该 Provider 所有 Model 对应的 GlobalModel 的别名/映射
provider_model_global_ids = {
m.global_model_id for m in provider.models if m.global_model_id
}
if provider_model_global_ids:
# 查询全局别名 + Provider 特定映射
alias_mappings = (
db.query(ModelMapping)
.options(joinedload(ModelMapping.target_global_model))
.filter(
ModelMapping.target_global_model_id.in_(provider_model_global_ids),
ModelMapping.is_active == True,
(ModelMapping.provider_id == provider.id)
| (ModelMapping.provider_id == None),
)
.all()
)
for alias_obj in alias_mappings:
# 为这个别名找到该 Provider 的 Model 实现
model = next(
(
m
for m in provider.models
if m.global_model_id == alias_obj.target_global_model_id
),
None,
)
if model:
models_data.append(
{
"id": alias_obj.id,
"name": alias_obj.source_model,
"display_name": (
alias_obj.target_global_model.display_name
if alias_obj.target_global_model
else alias_obj.source_model
),
"input_price_per_1m": model.input_price_per_1m,
"output_price_per_1m": model.output_price_per_1m,
"cache_creation_price_per_1m": model.cache_creation_price_per_1m,
"cache_read_price_per_1m": model.cache_read_price_per_1m,
"supports_vision": model.supports_vision,
"supports_function_calling": model.supports_function_calling,
"supports_streaming": model.supports_streaming,
}
)
result.append( result.append(
{ {
"id": provider.id, "id": provider.id,

View File

@@ -14,7 +14,6 @@ class CacheTTL:
# Provider/Model 缓存 - 配置变更不频繁 # Provider/Model 缓存 - 配置变更不频繁
PROVIDER = 300 # 5分钟 PROVIDER = 300 # 5分钟
MODEL = 300 # 5分钟 MODEL = 300 # 5分钟
MODEL_MAPPING = 300 # 5分钟
# 缓存亲和性 - 对应 provider_api_key.cache_ttl_minutes 默认值 # 缓存亲和性 - 对应 provider_api_key.cache_ttl_minutes 默认值
CACHE_AFFINITY = 300 # 5分钟 CACHE_AFFINITY = 300 # 5分钟
@@ -33,9 +32,6 @@ class CacheSize:
# 默认 LRU 缓存大小 # 默认 LRU 缓存大小
DEFAULT = 1000 DEFAULT = 1000
# ModelMapping 缓存(可能有较多别名)
MODEL_MAPPING = 2000
# ============================================================================== # ==============================================================================
# 并发和限流常量 # 并发和限流常量

View File

@@ -115,7 +115,7 @@ class SyncLRUCache:
"""删除缓存值(通过索引)""" """删除缓存值(通过索引)"""
self.delete(key) self.delete(key)
def keys(self): def keys(self) -> list:
"""返回所有未过期的 key""" """返回所有未过期的 key"""
with self._lock: with self._lock:
now = time.time() now = time.time()

View File

@@ -67,7 +67,7 @@ FILE_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:
logger.remove() logger.remove()
def _log_filter(record): def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
return "watchfiles" not in record["name"] return "watchfiles" not in record["name"]
@@ -76,7 +76,7 @@ if IS_DOCKER:
sys.stdout, sys.stdout,
format=CONSOLE_FORMAT_PROD, format=CONSOLE_FORMAT_PROD,
level=LOG_LEVEL, level=LOG_LEVEL,
filter=_log_filter, filter=_log_filter, # type: ignore[arg-type]
colorize=False, colorize=False,
) )
else: else:
@@ -84,7 +84,7 @@ else:
sys.stdout, sys.stdout,
format=CONSOLE_FORMAT_DEV, format=CONSOLE_FORMAT_DEV,
level=LOG_LEVEL, level=LOG_LEVEL,
filter=_log_filter, filter=_log_filter, # type: ignore[arg-type]
colorize=True, colorize=True,
) )
@@ -97,7 +97,7 @@ if not DISABLE_FILE_LOG:
log_dir / "app.log", log_dir / "app.log",
format=FILE_FORMAT, format=FILE_FORMAT,
level="DEBUG", level="DEBUG",
filter=_log_filter, filter=_log_filter, # type: ignore[arg-type]
rotation="00:00", rotation="00:00",
retention="30 days", retention="30 days",
compression="gz", compression="gz",
@@ -110,7 +110,7 @@ if not DISABLE_FILE_LOG:
log_dir / "error.log", log_dir / "error.log",
format=FILE_FORMAT, format=FILE_FORMAT,
level="ERROR", level="ERROR",
filter=_log_filter, filter=_log_filter, # type: ignore[arg-type]
rotation="00:00", rotation="00:00",
retention="30 days", retention="30 days",
compression="gz", compression="gz",

View File

@@ -6,7 +6,7 @@
import time import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, Optional from typing import Any, Dict, Optional
class ProviderHealthTracker: class ProviderHealthTracker:
@@ -32,7 +32,7 @@ class ProviderHealthTracker:
# 存储优先级调整 # 存储优先级调整
self.priority_adjustments: Dict[str, int] = {} self.priority_adjustments: Dict[str, int] = {}
def record_success(self, provider_name: str): def record_success(self, provider_name: str) -> None:
"""记录成功的请求""" """记录成功的请求"""
current_time = time.time() current_time = time.time()
@@ -47,7 +47,7 @@ class ProviderHealthTracker:
if self.priority_adjustments.get(provider_name, 0) < 0: if self.priority_adjustments.get(provider_name, 0) < 0:
self.priority_adjustments[provider_name] += 1 self.priority_adjustments[provider_name] += 1
def record_failure(self, provider_name: str): def record_failure(self, provider_name: str) -> None:
"""记录失败的请求""" """记录失败的请求"""
current_time = time.time() current_time = time.time()
@@ -93,7 +93,7 @@ class ProviderHealthTracker:
"status": self._get_status_label(failure_rate, recent_failures), "status": self._get_status_label(failure_rate, recent_failures),
} }
def _cleanup_old_records(self, provider_name: str, current_time: float): def _cleanup_old_records(self, provider_name: str, current_time: float) -> None:
"""清理超出时间窗口的记录""" """清理超出时间窗口的记录"""
# 清理失败记录 # 清理失败记录
self.failures[provider_name] = [ self.failures[provider_name] = [
@@ -130,7 +130,7 @@ class ProviderHealthTracker:
adjustment = self.get_priority_adjustment(provider_name) adjustment = self.get_priority_adjustment(provider_name)
return adjustment > -3 return adjustment > -3
def reset_provider_health(self, provider_name: str): def reset_provider_health(self, provider_name: str) -> None:
"""重置提供商的健康状态(管理员手动操作)""" """重置提供商的健康状态(管理员手动操作)"""
self.failures[provider_name] = [] self.failures[provider_name] = []
self.successes[provider_name] = [] self.successes[provider_name] = []
@@ -146,7 +146,7 @@ class SimpleProviderSelector:
def __init__(self, health_tracker: ProviderHealthTracker): def __init__(self, health_tracker: ProviderHealthTracker):
self.health_tracker = health_tracker self.health_tracker = health_tracker
def select_provider(self, providers: list, specified_provider: Optional[str] = None): def select_provider(self, providers: list, specified_provider: Optional[str] = None) -> Any:
""" """
选择提供商 选择提供商

View File

@@ -5,12 +5,12 @@
import asyncio import asyncio
from functools import wraps from functools import wraps
from typing import Any, Callable, TypeVar from typing import Any, Callable, Coroutine, TypeVar
T = TypeVar("T") T = TypeVar("T")
async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T: async def run_in_executor(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
""" """
在线程池中运行同步函数,避免阻塞事件循环 在线程池中运行同步函数,避免阻塞事件循环
@@ -21,7 +21,7 @@ async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
return await loop.run_in_executor(None, lambda: func(*args, **kwargs)) return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]: def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
""" """
装饰器:包装同步数据库函数为异步函数 装饰器:包装同步数据库函数为异步函数
@@ -35,7 +35,7 @@ def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
""" """
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args: Any, **kwargs: Any) -> T:
return await run_in_executor(func, *args, **kwargs) return await run_in_executor(func, *args, **kwargs)
return wrapper return wrapper

View File

@@ -347,7 +347,7 @@ def init_default_models(db: Session):
"""初始化默认模型配置""" """初始化默认模型配置"""
# 注意:作为中转代理服务,不再预设模型配置 # 注意:作为中转代理服务,不再预设模型配置
# 模型配置应该通过 Model 和 ModelMapping 表动态管理 # 模型配置应该通过 GlobalModel 和 Model 表动态管理
# 这个函数保留用于未来可能的默认模型初始化 # 这个函数保留用于未来可能的默认模型初始化
pass pass

View File

@@ -2,4 +2,4 @@
中间件模块 中间件模块
""" """
__all__ = [] __all__: list[str] = []

View File

@@ -334,7 +334,6 @@ class ProviderResponse(BaseModel):
updated_at: datetime updated_at: datetime
models_count: int = 0 models_count: int = 0
active_models_count: int = 0 active_models_count: int = 0
model_mappings_count: int = 0
api_keys_count: int = 0 api_keys_count: int = 0
class Config: class Config:
@@ -346,7 +345,11 @@ class ModelCreate(BaseModel):
"""创建模型请求 - 价格和能力字段可选,为空时使用 GlobalModel 默认值""" """创建模型请求 - 价格和能力字段可选,为空时使用 GlobalModel 默认值"""
provider_model_name: str = Field( provider_model_name: str = Field(
..., min_length=1, max_length=200, description="Provider 侧的模型名称" ..., min_length=1, max_length=200, description="Provider 侧的模型名称"
)
provider_model_aliases: Optional[List[dict]] = Field(
None,
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
) )
global_model_id: str = Field(..., description="关联的 GlobalModel ID必填") global_model_id: str = Field(..., description="关联的 GlobalModel ID必填")
# 按次计费配置 - 可选,为空时使用 GlobalModel 默认值 # 按次计费配置 - 可选,为空时使用 GlobalModel 默认值
@@ -374,6 +377,10 @@ class ModelUpdate(BaseModel):
"""更新模型请求""" """更新模型请求"""
provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200) provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
provider_model_aliases: Optional[List[dict]] = Field(
None,
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
)
global_model_id: Optional[str] = None global_model_id: Optional[str] = None
# 按次计费配置 # 按次计费配置
price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用") price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
@@ -398,6 +405,7 @@ class ModelResponse(BaseModel):
provider_id: str provider_id: str
global_model_id: Optional[str] global_model_id: Optional[str]
provider_model_name: str provider_model_name: str
provider_model_aliases: Optional[List[dict]] = None
# 按次计费配置 # 按次计费配置
price_per_request: Optional[float] = None price_per_request: Optional[float] = None
@@ -465,54 +473,6 @@ class ModelDetailResponse(BaseModel):
from_attributes = True from_attributes = True
# ========== 模型映射 ==========
class ModelMappingCreate(BaseModel):
"""创建模型映射请求(源模型到目标模型的映射)"""
source_model: str = Field(..., min_length=1, max_length=200, description="源模型名或别名")
target_global_model_id: str = Field(..., description="目标 GlobalModel ID")
provider_id: Optional[str] = Field(None, description="Provider ID为空时表示全局别名")
mapping_type: str = Field(
"alias",
description="映射类型alias=按目标模型计费别名mapping=按源模型计费(降级映射)",
)
is_active: bool = Field(True, description="是否启用")
class ModelMappingUpdate(BaseModel):
"""更新模型映射请求"""
source_model: Optional[str] = Field(
None, min_length=1, max_length=200, description="源模型名或别名"
)
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
provider_id: Optional[str] = Field(None, description="Provider ID为空时表示全局别名")
mapping_type: Optional[str] = Field(
None, description="映射类型alias=按目标模型计费别名mapping=按源模型计费(降级映射)"
)
is_active: Optional[bool] = None
class ModelMappingResponse(BaseModel):
"""模型映射响应"""
id: str
source_model: str
target_global_model_id: str
target_global_model_name: Optional[str]
target_global_model_display_name: Optional[str]
provider_id: Optional[str]
provider_name: Optional[str]
scope: str = Field(..., description="global 或 provider")
mapping_type: str = Field(..., description="映射类型alias 或 mapping")
is_active: bool
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ========== 系统设置 ========== # ========== 系统设置 ==========
class SystemSettingsRequest(BaseModel): class SystemSettingsRequest(BaseModel):
"""系统设置请求""" """系统设置请求"""
@@ -558,7 +518,6 @@ class PublicProviderResponse(BaseModel):
# 统计信息 # 统计信息
models_count: int models_count: int
active_models_count: int active_models_count: int
mappings_count: int
endpoints_count: int # 端点总数 endpoints_count: int # 端点总数
active_endpoints_count: int # 活跃端点数 active_endpoints_count: int # 活跃端点数
@@ -587,19 +546,6 @@ class PublicModelResponse(BaseModel):
is_active: bool = True is_active: bool = True
class PublicModelMappingResponse(BaseModel):
"""公开的模型映射信息响应"""
id: str
source_model: str
target_global_model_id: str
target_global_model_name: Optional[str]
target_global_model_display_name: Optional[str]
provider_id: Optional[str] = None
scope: str = Field(..., description="global 或 provider")
is_active: bool
class ProviderStatsResponse(BaseModel): class ProviderStatsResponse(BaseModel):
"""提供商统计信息响应""" """提供商统计信息响应"""
@@ -607,7 +553,6 @@ class ProviderStatsResponse(BaseModel):
active_providers: int active_providers: int
total_models: int total_models: int
active_models: int active_models: int
total_mappings: int
supported_formats: List[str] supported_formats: List[str]

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional, Union
class SSEEventParser: class SSEEventParser:
@@ -8,7 +8,7 @@ class SSEEventParser:
self._reset_buffer() self._reset_buffer()
def _reset_buffer(self) -> None: def _reset_buffer(self) -> None:
self._buffer: Dict[str, Optional[str] | List[str]] = { self._buffer: Dict[str, Union[Optional[str], List[str]]] = {
"event": None, "event": None,
"data": [], "data": [],
"id": None, "id": None,
@@ -17,16 +17,19 @@ class SSEEventParser:
def _finalize_event(self) -> Optional[Dict[str, Optional[str]]]: def _finalize_event(self) -> Optional[Dict[str, Optional[str]]]:
data_lines = self._buffer.get("data", []) data_lines = self._buffer.get("data", [])
if not data_lines: if not isinstance(data_lines, list) or not data_lines:
self._reset_buffer() self._reset_buffer()
return None return None
data_str = "\n".join(data_lines) data_str = "\n".join(data_lines)
event = { event_val = self._buffer.get("event")
"event": self._buffer.get("event"), id_val = self._buffer.get("id")
retry_val = self._buffer.get("retry")
event: Dict[str, Optional[str]] = {
"event": event_val if isinstance(event_val, str) else None,
"data": data_str, "data": data_str,
"id": self._buffer.get("id"), "id": id_val if isinstance(id_val, str) else None,
"retry": self._buffer.get("retry"), "retry": retry_val if isinstance(retry_val, str) else None,
} }
self._reset_buffer() self._reset_buffer()