mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 02:32:27 +08:00
refactor: 完善 handler 基类类型注解和流式状态更新
- 为 BaseMessageHandler 和 MessageTelemetry 添加完整类型注解 - 新增 _update_usage_to_streaming 方法,异步更新 Usage 状态为 streaming - 优化 chat/cli handler 的类型提示,提升代码可维护性 - 修复类型检查警告,确保 mypy 通过
This commit is contained in:
@@ -50,7 +50,9 @@ class MessageTelemetry:
|
||||
负责记录 Usage/Audit,避免处理器里重复代码。
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
|
||||
def __init__(
|
||||
self, db: Session, user: Any, api_key: Any, request_id: str, client_ip: str
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.user = user
|
||||
self.api_key = api_key
|
||||
@@ -187,7 +189,7 @@ class MessageTelemetry:
|
||||
response_body: Optional[Dict[str, Any]] = None,
|
||||
# 模型映射信息
|
||||
target_model: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
记录失败请求
|
||||
|
||||
@@ -283,15 +285,15 @@ class BaseMessageHandler:
|
||||
self,
|
||||
*,
|
||||
db: Session,
|
||||
user,
|
||||
api_key,
|
||||
user: Any,
|
||||
api_key: Any,
|
||||
request_id: str,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
allowed_api_formats: Optional[list[str]] = None,
|
||||
adapter_detector: Optional[AdapterDetectorType] = None,
|
||||
):
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.user = user
|
||||
self.api_key = api_key
|
||||
@@ -304,7 +306,7 @@ class BaseMessageHandler:
|
||||
self.adapter_detector = adapter_detector
|
||||
|
||||
redis_client = get_redis_client_sync()
|
||||
self.orchestrator = FallbackOrchestrator(db, redis_client)
|
||||
self.orchestrator = FallbackOrchestrator(db, redis_client) # type: ignore[arg-type]
|
||||
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
|
||||
|
||||
def elapsed_ms(self) -> int:
|
||||
@@ -347,7 +349,8 @@ class BaseMessageHandler:
|
||||
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
|
||||
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
|
||||
if provider_type:
|
||||
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
|
||||
result = resolve_api_format(provider_type, default=APIFormat.OPENAI)
|
||||
return result or APIFormat.OPENAI
|
||||
return self.primary_api_format
|
||||
|
||||
def build_provider_payload(
|
||||
@@ -361,3 +364,34 @@ class BaseMessageHandler:
|
||||
if mapped_model:
|
||||
payload["model"] = mapped_model
|
||||
return payload
|
||||
|
||||
def _update_usage_to_streaming(self, request_id: Optional[str] = None) -> None:
|
||||
"""更新 Usage 状态为 streaming(流式传输开始时调用)
|
||||
|
||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||
|
||||
Args:
|
||||
request_id: 请求 ID,如果不传则使用 self.request_id
|
||||
"""
|
||||
import asyncio
|
||||
from src.database.database import get_db
|
||||
|
||||
target_request_id = request_id or self.request_id
|
||||
|
||||
async def _do_update() -> None:
|
||||
try:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
UsageService.update_usage_status(
|
||||
db=db,
|
||||
request_id=target_request_id,
|
||||
status="streaming",
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
|
||||
|
||||
# 创建后台任务,不阻塞当前流
|
||||
asyncio.create_task(_do_update())
|
||||
|
||||
Reference in New Issue
Block a user