refactor: 完善 handler 基类类型注解和流式状态更新

- 为 BaseMessageHandler 和 MessageTelemetry 添加完整类型注解
- 新增 _update_usage_to_streaming 方法,异步更新 Usage 状态为 streaming
- 优化 chat/cli handler 的类型提示,提升代码可维护性
- 修复类型检查警告,确保 mypy 通过
This commit is contained in:
fawney19
2025-12-11 10:05:06 +08:00
parent 913a87d7f3
commit 0474f63403
3 changed files with 140 additions and 74 deletions

View File

@@ -50,7 +50,9 @@ class MessageTelemetry:
负责记录 Usage/Audit避免处理器里重复代码。
"""
def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
def __init__(
self, db: Session, user: Any, api_key: Any, request_id: str, client_ip: str
) -> None:
self.db = db
self.user = user
self.api_key = api_key
@@ -187,7 +189,7 @@ class MessageTelemetry:
response_body: Optional[Dict[str, Any]] = None,
# 模型映射信息
target_model: Optional[str] = None,
):
) -> None:
"""
记录失败请求
@@ -283,15 +285,15 @@ class BaseMessageHandler:
self,
*,
db: Session,
user,
api_key,
user: Any,
api_key: Any,
request_id: str,
client_ip: str,
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list[str]] = None,
adapter_detector: Optional[AdapterDetectorType] = None,
):
) -> None:
self.db = db
self.user = user
self.api_key = api_key
@@ -304,7 +306,7 @@ class BaseMessageHandler:
self.adapter_detector = adapter_detector
redis_client = get_redis_client_sync()
self.orchestrator = FallbackOrchestrator(db, redis_client)
self.orchestrator = FallbackOrchestrator(db, redis_client) # type: ignore[arg-type]
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
def elapsed_ms(self) -> int:
@@ -347,7 +349,8 @@ class BaseMessageHandler:
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
if provider_type:
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
result = resolve_api_format(provider_type, default=APIFormat.OPENAI)
return result or APIFormat.OPENAI
return self.primary_api_format
def build_provider_payload(
@@ -361,3 +364,34 @@ class BaseMessageHandler:
if mapped_model:
payload["model"] = mapped_model
return payload
def _update_usage_to_streaming(self, request_id: Optional[str] = None) -> None:
"""更新 Usage 状态为 streaming流式传输开始时调用
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
Args:
request_id: 请求 ID如果不传则使用 self.request_id
"""
import asyncio
from src.database.database import get_db
target_request_id = request_id or self.request_id
async def _do_update() -> None:
try:
db_gen = get_db()
db = next(db_gen)
try:
UsageService.update_usage_status(
db=db,
request_id=target_request_id,
status="streaming",
)
finally:
db.close()
except Exception as e:
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
# 创建后台任务,不阻塞当前流
asyncio.create_task(_do_update())