mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
refactor: 完善 handler 基类类型注解和流式状态更新
- 为 BaseMessageHandler 和 MessageTelemetry 添加完整类型注解 - 新增 _update_usage_to_streaming 方法,异步更新 Usage 状态为 streaming - 优化 chat/cli handler 的类型提示,提升代码可维护性 - 修复类型检查警告,确保 mypy 通过
This commit is contained in:
@@ -50,7 +50,9 @@ class MessageTelemetry:
|
|||||||
负责记录 Usage/Audit,避免处理器里重复代码。
|
负责记录 Usage/Audit,避免处理器里重复代码。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
|
def __init__(
|
||||||
|
self, db: Session, user: Any, api_key: Any, request_id: str, client_ip: str
|
||||||
|
) -> None:
|
||||||
self.db = db
|
self.db = db
|
||||||
self.user = user
|
self.user = user
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -187,7 +189,7 @@ class MessageTelemetry:
|
|||||||
response_body: Optional[Dict[str, Any]] = None,
|
response_body: Optional[Dict[str, Any]] = None,
|
||||||
# 模型映射信息
|
# 模型映射信息
|
||||||
target_model: Optional[str] = None,
|
target_model: Optional[str] = None,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
记录失败请求
|
记录失败请求
|
||||||
|
|
||||||
@@ -283,15 +285,15 @@ class BaseMessageHandler:
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
db: Session,
|
db: Session,
|
||||||
user,
|
user: Any,
|
||||||
api_key,
|
api_key: Any,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
client_ip: str,
|
client_ip: str,
|
||||||
user_agent: str,
|
user_agent: str,
|
||||||
start_time: float,
|
start_time: float,
|
||||||
allowed_api_formats: Optional[list[str]] = None,
|
allowed_api_formats: Optional[list[str]] = None,
|
||||||
adapter_detector: Optional[AdapterDetectorType] = None,
|
adapter_detector: Optional[AdapterDetectorType] = None,
|
||||||
):
|
) -> None:
|
||||||
self.db = db
|
self.db = db
|
||||||
self.user = user
|
self.user = user
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -304,7 +306,7 @@ class BaseMessageHandler:
|
|||||||
self.adapter_detector = adapter_detector
|
self.adapter_detector = adapter_detector
|
||||||
|
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
self.orchestrator = FallbackOrchestrator(db, redis_client)
|
self.orchestrator = FallbackOrchestrator(db, redis_client) # type: ignore[arg-type]
|
||||||
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
|
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
|
||||||
|
|
||||||
def elapsed_ms(self) -> int:
|
def elapsed_ms(self) -> int:
|
||||||
@@ -347,7 +349,8 @@ class BaseMessageHandler:
|
|||||||
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
|
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
|
||||||
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
|
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
|
||||||
if provider_type:
|
if provider_type:
|
||||||
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
|
result = resolve_api_format(provider_type, default=APIFormat.OPENAI)
|
||||||
|
return result or APIFormat.OPENAI
|
||||||
return self.primary_api_format
|
return self.primary_api_format
|
||||||
|
|
||||||
def build_provider_payload(
|
def build_provider_payload(
|
||||||
@@ -361,3 +364,34 @@ class BaseMessageHandler:
|
|||||||
if mapped_model:
|
if mapped_model:
|
||||||
payload["model"] = mapped_model
|
payload["model"] = mapped_model
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
def _update_usage_to_streaming(self, request_id: Optional[str] = None) -> None:
|
||||||
|
"""更新 Usage 状态为 streaming(流式传输开始时调用)
|
||||||
|
|
||||||
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: 请求 ID,如果不传则使用 self.request_id
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from src.database.database import get_db
|
||||||
|
|
||||||
|
target_request_id = request_id or self.request_id
|
||||||
|
|
||||||
|
async def _do_update() -> None:
|
||||||
|
try:
|
||||||
|
db_gen = get_db()
|
||||||
|
db = next(db_gen)
|
||||||
|
try:
|
||||||
|
UsageService.update_usage_status(
|
||||||
|
db=db,
|
||||||
|
request_id=target_request_id,
|
||||||
|
status="streaming",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
|
||||||
|
|
||||||
|
# 创建后台任务,不阻塞当前流
|
||||||
|
asyncio.create_task(_do_update())
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
# ==================== 抽象方法 ====================
|
# ==================== 抽象方法 ====================
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _convert_request(self, request):
|
async def _convert_request(self, request: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
将请求转换为目标格式
|
将请求转换为目标格式
|
||||||
|
|
||||||
@@ -261,7 +261,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
mapping = await mapper.get_mapping(source_model, provider_id)
|
mapping = await mapper.get_mapping(source_model, provider_id)
|
||||||
|
|
||||||
if mapping and mapping.model:
|
if mapping and mapping.model:
|
||||||
mapped_name = mapping.model.provider_model_name
|
mapped_name = str(mapping.model.provider_model_name)
|
||||||
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
||||||
return mapped_name
|
return mapped_name
|
||||||
|
|
||||||
@@ -271,10 +271,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
async def process_stream(
|
async def process_stream(
|
||||||
self,
|
self,
|
||||||
request,
|
request: Any,
|
||||||
http_request: Request,
|
http_request: Request,
|
||||||
original_headers: Dict,
|
original_headers: Dict[str, Any],
|
||||||
original_request_body: Dict,
|
original_request_body: Dict[str, Any],
|
||||||
query_params: Optional[Dict[str, str]] = None,
|
query_params: Optional[Dict[str, str]] = None,
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""处理流式响应"""
|
"""处理流式响应"""
|
||||||
@@ -315,7 +315,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
provider: Provider,
|
provider: Provider,
|
||||||
endpoint: ProviderEndpoint,
|
endpoint: ProviderEndpoint,
|
||||||
key: ProviderAPIKey,
|
key: ProviderAPIKey,
|
||||||
):
|
) -> AsyncGenerator[bytes, None]:
|
||||||
return await self._execute_stream_request(
|
return await self._execute_stream_request(
|
||||||
ctx,
|
ctx,
|
||||||
provider,
|
provider,
|
||||||
@@ -401,16 +401,16 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
ctx["cached_tokens"] = 0
|
ctx["cached_tokens"] = 0
|
||||||
ctx["cache_creation_tokens"] = 0
|
ctx["cache_creation_tokens"] = 0
|
||||||
|
|
||||||
ctx["provider_name"] = provider.name
|
ctx["provider_name"] = str(provider.name)
|
||||||
ctx["provider_id"] = provider.id
|
ctx["provider_id"] = str(provider.id)
|
||||||
ctx["endpoint_id"] = endpoint.id
|
ctx["endpoint_id"] = str(endpoint.id)
|
||||||
ctx["key_id"] = key.id
|
ctx["key_id"] = str(key.id)
|
||||||
ctx["provider_api_format"] = endpoint.api_format # Provider 的响应格式
|
ctx["provider_api_format"] = str(endpoint.api_format) if endpoint.api_format else ""
|
||||||
|
|
||||||
# 获取模型映射
|
# 获取模型映射
|
||||||
mapped_model = await self._get_mapped_model(
|
mapped_model = await self._get_mapped_model(
|
||||||
source_model=ctx["model"],
|
source_model=ctx["model"],
|
||||||
provider_id=provider.id,
|
provider_id=str(provider.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 应用模型映射到请求体
|
# 应用模型映射到请求体
|
||||||
@@ -514,14 +514,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
"""创建响应流生成器"""
|
"""创建响应流生成器"""
|
||||||
try:
|
try:
|
||||||
sse_parser = SSEEventParser()
|
sse_parser = SSEEventParser()
|
||||||
|
streaming_status_updated = False
|
||||||
|
|
||||||
async for line in stream_response.aiter_lines():
|
async for line in stream_response.aiter_lines():
|
||||||
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
|
if not streaming_status_updated:
|
||||||
|
self._update_usage_to_streaming()
|
||||||
|
streaming_status_updated = True
|
||||||
|
|
||||||
normalized_line = line.rstrip("\r")
|
normalized_line = line.rstrip("\r")
|
||||||
events = sse_parser.feed_line(normalized_line)
|
events = sse_parser.feed_line(normalized_line)
|
||||||
|
|
||||||
if normalized_line == "":
|
if normalized_line == "":
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
|
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -530,11 +536,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
yield (line + "\n").encode("utf-8")
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
|
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||||
|
|
||||||
# 处理剩余事件
|
# 处理剩余事件
|
||||||
for event in sse_parser.flush():
|
for event in sse_parser.flush():
|
||||||
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
|
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
raise
|
raise
|
||||||
@@ -618,7 +624,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
f"error_type={parsed.error_type}, "
|
f"error_type={parsed.error_type}, "
|
||||||
f"message={parsed.error_message}")
|
f"message={parsed.error_message}")
|
||||||
raise EmbeddedErrorException(
|
raise EmbeddedErrorException(
|
||||||
provider_name=provider.name,
|
provider_name=str(provider.name),
|
||||||
error_code=(
|
error_code=(
|
||||||
int(parsed.error_type)
|
int(parsed.error_type)
|
||||||
if parsed.error_type and parsed.error_type.isdigit()
|
if parsed.error_type and parsed.error_type.isdigit()
|
||||||
@@ -652,6 +658,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
try:
|
try:
|
||||||
sse_parser = SSEEventParser()
|
sse_parser = SSEEventParser()
|
||||||
|
|
||||||
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
|
if prefetched_lines:
|
||||||
|
self._update_usage_to_streaming()
|
||||||
|
|
||||||
# 先输出预读的数据
|
# 先输出预读的数据
|
||||||
for line in prefetched_lines:
|
for line in prefetched_lines:
|
||||||
normalized_line = line.rstrip("\r")
|
normalized_line = line.rstrip("\r")
|
||||||
@@ -659,7 +669,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
if normalized_line == "":
|
if normalized_line == "":
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
|
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -667,7 +677,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
yield (line + "\n").encode("utf-8")
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
|
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||||
|
|
||||||
# 继续输出剩余的流数据(使用同一个迭代器)
|
# 继续输出剩余的流数据(使用同一个迭代器)
|
||||||
async for line in line_iterator:
|
async for line in line_iterator:
|
||||||
@@ -676,7 +686,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
if normalized_line == "":
|
if normalized_line == "":
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
|
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -684,11 +694,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
yield (line + "\n").encode("utf-8")
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
|
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||||
|
|
||||||
# 处理剩余事件
|
# 处理剩余事件
|
||||||
for event in sse_parser.flush():
|
for event in sse_parser.flush():
|
||||||
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
|
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
raise
|
raise
|
||||||
@@ -853,8 +863,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
# 根据状态码决定记录成功还是失败
|
# 根据状态码决定记录成功还是失败
|
||||||
# 499 = 客户端断开连接,503 = 服务不可用(如流中断)
|
# 499 = 客户端断开连接,503 = 服务不可用(如流中断)
|
||||||
status_code = ctx.get("status_code")
|
status_code: int = ctx.get("status_code") or 200
|
||||||
if status_code and status_code >= 400:
|
if status_code >= 400:
|
||||||
# 记录失败的 Usage,但使用已收到的预估 token 信息(来自 message_start)
|
# 记录失败的 Usage,但使用已收到的预估 token 信息(来自 message_start)
|
||||||
# 这样即使请求中断,也能记录预估成本
|
# 这样即使请求中断,也能记录预估成本
|
||||||
await bg_telemetry.record_failure(
|
await bg_telemetry.record_failure(
|
||||||
@@ -955,6 +965,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
error_message=f"记录统计信息失败: {str(e)[:200]}",
|
error_message=f"记录统计信息失败: {str(e)[:200]}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
|
||||||
|
|
||||||
async def _update_usage_status_on_error(
|
async def _update_usage_status_on_error(
|
||||||
self,
|
self,
|
||||||
response_time_ms: int,
|
response_time_ms: int,
|
||||||
@@ -991,11 +1003,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
|
usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
|
||||||
if usage:
|
if usage:
|
||||||
usage.status = status
|
setattr(usage, "status", status)
|
||||||
usage.status_code = status_code
|
setattr(usage, "status_code", status_code)
|
||||||
usage.response_time_ms = response_time_ms
|
setattr(usage, "response_time_ms", response_time_ms)
|
||||||
if error_message:
|
if error_message:
|
||||||
usage.error_message = error_message
|
setattr(usage, "error_message", error_message)
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
|
logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1040,10 +1052,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
async def process_sync(
|
async def process_sync(
|
||||||
self,
|
self,
|
||||||
request,
|
request: Any,
|
||||||
http_request: Request,
|
http_request: Request,
|
||||||
original_headers: Dict,
|
original_headers: Dict[str, Any],
|
||||||
original_request_body: Dict,
|
original_request_body: Dict[str, Any],
|
||||||
query_params: Optional[Dict[str, str]] = None,
|
query_params: Optional[Dict[str, str]] = None,
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""处理非流式响应"""
|
"""处理非流式响应"""
|
||||||
@@ -1055,31 +1067,31 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
api_format = self.allowed_api_formats[0]
|
api_format = self.allowed_api_formats[0]
|
||||||
|
|
||||||
# 用于跟踪的变量
|
# 用于跟踪的变量
|
||||||
provider_name = None
|
provider_name: Optional[str] = None
|
||||||
response_json = None
|
response_json: Optional[Dict[str, Any]] = None
|
||||||
status_code = 200
|
status_code = 200
|
||||||
response_headers = {}
|
response_headers: Dict[str, str] = {}
|
||||||
provider_request_headers = {}
|
provider_request_headers: Dict[str, str] = {}
|
||||||
provider_request_body = None
|
provider_request_body: Optional[Dict[str, Any]] = None
|
||||||
provider_id = None # Provider ID(用于失败记录)
|
provider_id: Optional[str] = None # Provider ID(用于失败记录)
|
||||||
endpoint_id = None # Endpoint ID(用于失败记录)
|
endpoint_id: Optional[str] = None # Endpoint ID(用于失败记录)
|
||||||
key_id = None # Key ID(用于失败记录)
|
key_id: Optional[str] = None # Key ID(用于失败记录)
|
||||||
mapped_model_result = None # 映射后的目标模型名(用于 Usage 记录)
|
mapped_model_result: Optional[str] = None # 映射后的目标模型名(用于 Usage 记录)
|
||||||
|
|
||||||
async def sync_request_func(
|
async def sync_request_func(
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
endpoint: ProviderEndpoint,
|
endpoint: ProviderEndpoint,
|
||||||
key: ProviderAPIKey,
|
key: ProviderAPIKey,
|
||||||
):
|
) -> Dict[str, Any]:
|
||||||
nonlocal provider_name, response_json, status_code, response_headers
|
nonlocal provider_name, response_json, status_code, response_headers
|
||||||
nonlocal provider_request_headers, provider_request_body, mapped_model_result
|
nonlocal provider_request_headers, provider_request_body, mapped_model_result
|
||||||
|
|
||||||
provider_name = provider.name
|
provider_name = str(provider.name)
|
||||||
|
|
||||||
# 获取模型映射
|
# 获取模型映射
|
||||||
mapped_model = await self._get_mapped_model(
|
mapped_model = await self._get_mapped_model(
|
||||||
source_model=model,
|
source_model=model,
|
||||||
provider_id=provider.id,
|
provider_id=str(provider.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 应用模型映射
|
# 应用模型映射
|
||||||
@@ -1131,7 +1143,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
elif resp.status_code == 429:
|
elif resp.status_code == 429:
|
||||||
raise ProviderRateLimitException(
|
raise ProviderRateLimitException(
|
||||||
f"提供商速率限制: {provider.name}",
|
f"提供商速率限制: {provider.name}",
|
||||||
provider_name=provider.name,
|
provider_name=str(provider.name),
|
||||||
response_headers=response_headers,
|
response_headers=response_headers,
|
||||||
)
|
)
|
||||||
elif resp.status_code >= 500:
|
elif resp.status_code >= 500:
|
||||||
@@ -1142,7 +1154,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
response_json = resp.json()
|
response_json = resp.json()
|
||||||
return response_json
|
return response_json if isinstance(response_json, dict) else {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析能力需求
|
# 解析能力需求
|
||||||
@@ -1170,6 +1182,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
provider_name = actual_provider_name
|
provider_name = actual_provider_name
|
||||||
response_time_ms = self.elapsed_ms()
|
response_time_ms = self.elapsed_ms()
|
||||||
|
|
||||||
|
# 确保 response_json 不为 None
|
||||||
|
if response_json is None:
|
||||||
|
response_json = {}
|
||||||
|
|
||||||
# 规范化响应
|
# 规范化响应
|
||||||
response_json = self._normalize_response(response_json)
|
response_json = self._normalize_response(response_json)
|
||||||
|
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||||||
|
|
||||||
if mapping and mapping.model:
|
if mapping and mapping.model:
|
||||||
mapped_name = mapping.model.provider_model_name
|
mapped_name = str(mapping.model.provider_model_name)
|
||||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||||
return mapped_name
|
return mapped_name
|
||||||
|
|
||||||
@@ -351,7 +351,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
provider: Provider,
|
provider: Provider,
|
||||||
endpoint: ProviderEndpoint,
|
endpoint: ProviderEndpoint,
|
||||||
key: ProviderAPIKey,
|
key: ProviderAPIKey,
|
||||||
):
|
) -> AsyncGenerator[bytes, None]:
|
||||||
return await self._execute_stream_request(
|
return await self._execute_stream_request(
|
||||||
ctx,
|
ctx,
|
||||||
provider,
|
provider,
|
||||||
@@ -451,19 +451,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
ctx.response_metadata = {} # 重置 Provider 响应元数据
|
ctx.response_metadata = {} # 重置 Provider 响应元数据
|
||||||
|
|
||||||
# 记录 Provider 信息
|
# 记录 Provider 信息
|
||||||
ctx.provider_name = provider.name
|
ctx.provider_name = str(provider.name)
|
||||||
ctx.provider_id = provider.id
|
ctx.provider_id = str(provider.id)
|
||||||
ctx.endpoint_id = endpoint.id
|
ctx.endpoint_id = str(endpoint.id)
|
||||||
ctx.key_id = key.id
|
ctx.key_id = str(key.id)
|
||||||
|
|
||||||
# 记录格式转换信息
|
# 记录格式转换信息
|
||||||
ctx.provider_api_format = endpoint.api_format if endpoint.api_format else ""
|
ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
|
||||||
ctx.client_api_format = ctx.api_format # 已在 process_stream 中设置
|
ctx.client_api_format = ctx.api_format # 已在 process_stream 中设置
|
||||||
|
|
||||||
# 获取模型映射(别名/映射 → 实际模型名)
|
# 获取模型映射(别名/映射 → 实际模型名)
|
||||||
mapped_model = await self._get_mapped_model(
|
mapped_model = await self._get_mapped_model(
|
||||||
source_model=ctx.model,
|
source_model=ctx.model,
|
||||||
provider_id=provider.id,
|
provider_id=str(provider.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
|
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
|
||||||
@@ -575,11 +575,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
try:
|
try:
|
||||||
sse_parser = SSEEventParser()
|
sse_parser = SSEEventParser()
|
||||||
last_data_time = time.time()
|
last_data_time = time.time()
|
||||||
|
streaming_status_updated = False
|
||||||
|
|
||||||
# 检查是否需要格式转换
|
# 检查是否需要格式转换
|
||||||
needs_conversion = self._needs_format_conversion(ctx)
|
needs_conversion = self._needs_format_conversion(ctx)
|
||||||
|
|
||||||
async for line in stream_response.aiter_lines():
|
async for line in stream_response.aiter_lines():
|
||||||
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
|
if not streaming_status_updated:
|
||||||
|
self._update_usage_to_streaming(ctx.request_id)
|
||||||
|
streaming_status_updated = True
|
||||||
|
|
||||||
normalized_line = line.rstrip("\r")
|
normalized_line = line.rstrip("\r")
|
||||||
events = sse_parser.feed_line(normalized_line)
|
events = sse_parser.feed_line(normalized_line)
|
||||||
|
|
||||||
@@ -588,7 +594,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data", ""),
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
@@ -622,7 +628,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data", ""),
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
if ctx.data_count > 0:
|
if ctx.data_count > 0:
|
||||||
@@ -633,7 +639,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data", ""),
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否收到数据
|
# 检查是否收到数据
|
||||||
@@ -781,7 +787,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
f"error_type={parsed.error_type}, "
|
f"error_type={parsed.error_type}, "
|
||||||
f"message={parsed.error_message}")
|
f"message={parsed.error_message}")
|
||||||
raise EmbeddedErrorException(
|
raise EmbeddedErrorException(
|
||||||
provider_name=provider.name,
|
provider_name=str(provider.name),
|
||||||
error_code=(
|
error_code=(
|
||||||
int(parsed.error_type)
|
int(parsed.error_type)
|
||||||
if parsed.error_type and parsed.error_type.isdigit()
|
if parsed.error_type and parsed.error_type.isdigit()
|
||||||
@@ -819,6 +825,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 检查是否需要格式转换
|
# 检查是否需要格式转换
|
||||||
needs_conversion = self._needs_format_conversion(ctx)
|
needs_conversion = self._needs_format_conversion(ctx)
|
||||||
|
|
||||||
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
|
if prefetched_lines:
|
||||||
|
self._update_usage_to_streaming(ctx.request_id)
|
||||||
|
|
||||||
# 先处理预读的数据
|
# 先处理预读的数据
|
||||||
for line in prefetched_lines:
|
for line in prefetched_lines:
|
||||||
normalized_line = line.rstrip("\r")
|
normalized_line = line.rstrip("\r")
|
||||||
@@ -829,7 +839,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data", ""),
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
@@ -848,7 +858,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data", ""),
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
if ctx.data_count > 0:
|
if ctx.data_count > 0:
|
||||||
@@ -864,7 +874,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data", ""),
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
@@ -898,7 +908,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data", ""),
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
if ctx.data_count > 0:
|
if ctx.data_count > 0:
|
||||||
@@ -910,7 +920,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data", ""),
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否收到数据
|
# 检查是否收到数据
|
||||||
@@ -1270,6 +1280,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
target_model=ctx.mapped_model,
|
target_model=ctx.mapped_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
|
||||||
|
|
||||||
async def process_sync(
|
async def process_sync(
|
||||||
self,
|
self,
|
||||||
original_request_body: Dict[str, Any],
|
original_request_body: Dict[str, Any],
|
||||||
@@ -1309,15 +1321,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
provider: Provider,
|
provider: Provider,
|
||||||
endpoint: ProviderEndpoint,
|
endpoint: ProviderEndpoint,
|
||||||
key: ProviderAPIKey,
|
key: ProviderAPIKey,
|
||||||
):
|
) -> Dict[str, Any]:
|
||||||
nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
|
nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
|
||||||
provider_name = provider.name
|
provider_name = str(provider.name)
|
||||||
provider_api_format = endpoint.api_format if endpoint.api_format else ""
|
provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
|
||||||
|
|
||||||
# 获取模型映射(别名/映射 → 实际模型名)
|
# 获取模型映射(别名/映射 → 实际模型名)
|
||||||
mapped_model = await self._get_mapped_model(
|
mapped_model = await self._get_mapped_model(
|
||||||
source_model=model,
|
source_model=model,
|
||||||
provider_id=provider.id,
|
provider_id=str(provider.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
|
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
|
||||||
@@ -1373,7 +1385,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
elif resp.status_code == 429:
|
elif resp.status_code == 429:
|
||||||
raise ProviderRateLimitException(
|
raise ProviderRateLimitException(
|
||||||
f"提供商速率限制: {provider.name}",
|
f"提供商速率限制: {provider.name}",
|
||||||
provider_name=provider.name,
|
provider_name=str(provider.name),
|
||||||
response_headers=response_headers,
|
response_headers=response_headers,
|
||||||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||||||
)
|
)
|
||||||
@@ -1409,7 +1421,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 提取 Provider 响应元数据(子类可覆盖)
|
# 提取 Provider 响应元数据(子类可覆盖)
|
||||||
response_metadata_result = self._extract_response_metadata(response_json)
|
response_metadata_result = self._extract_response_metadata(response_json)
|
||||||
|
|
||||||
return response_json
|
return response_json if isinstance(response_json, dict) else {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析能力需求
|
# 解析能力需求
|
||||||
@@ -1437,6 +1449,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
provider_name = actual_provider_name
|
provider_name = actual_provider_name
|
||||||
response_time_ms = int((time.time() - sync_start_time) * 1000)
|
response_time_ms = int((time.time() - sync_start_time) * 1000)
|
||||||
|
|
||||||
|
# 确保 response_json 不为 None
|
||||||
|
if response_json is None:
|
||||||
|
response_json = {}
|
||||||
|
|
||||||
# 检查是否需要格式转换
|
# 检查是否需要格式转换
|
||||||
if (
|
if (
|
||||||
provider_api_format
|
provider_api_format
|
||||||
|
|||||||
Reference in New Issue
Block a user