refactor: 优化活跃请求状态查询逻辑

- 重命名 get_active_requests 为 get_active_requests_status
- 支持从端点配置读取超时时间
- 新增 content_length_limit 错误类型
This commit is contained in:
fawney19
2025-12-11 10:45:06 +08:00
parent 875f3d5f54
commit 323a514f77
4 changed files with 21 additions and 8 deletions

View File

@@ -660,7 +660,7 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
if not id_list:
return {"requests": []}
requests = UsageService.get_active_requests(db=db, ids=id_list)
requests = UsageService.get_active_requests_status(db=db, ids=id_list)
return {"requests": requests}

View File

@@ -672,7 +672,7 @@ class GetActiveRequestsAdapter(AuthenticatedApiAdapter):
if not id_list:
return {"requests": []}
requests = UsageService.get_active_requests(db=db, ids=id_list, user_id=user.id)
requests = UsageService.get_active_requests_status(db=db, ids=id_list, user_id=user.id)
return {"requests": requests}

View File

@@ -76,6 +76,7 @@ class ErrorClassifier:
"content_policy_violation", # 内容违规
"invalid_api_key", # 无效的 API Key不同于认证失败
"context_length_exceeded", # 上下文长度超限
"content_length_limit", # 请求内容长度超限 (Claude API)
"max_tokens", # token 数超限
"invalid_prompt", # 无效的提示词
"content too long", # 内容过长

View File

@@ -1306,28 +1306,35 @@ class UsageService:
)
@classmethod
def get_active_requests(
def get_active_requests_status(
cls,
db: Session,
ids: Optional[List[str]] = None,
user_id: Optional[str] = None,
timeout_seconds: int = 300,
default_timeout_seconds: int = 300,
) -> List[Dict[str, Any]]:
"""
获取活跃请求状态,并自动清理超时的 pending 请求
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending 请求
与 get_active_requests 不同,此方法:
1. 返回轻量级的状态字典而非完整 Usage 对象
2. 自动检测并清理超时的 pending 请求
3. 支持按 ID 列表查询特定请求
Args:
db: 数据库会话
ids: 指定要查询的请求 ID 列表(可选)
user_id: 限制只查询该用户的请求(可选,用于普通用户接口)
timeout_seconds: pending 状态超时时间(秒),默认 5 分钟
default_timeout_seconds: 默认超时时间(秒),当端点未配置时使用
Returns:
请求状态列表
"""
from src.models.database import ProviderEndpoint
now = datetime.now(timezone.utc)
# 构建基础查询
# 构建基础查询,包含端点的 timeout 配置
query = db.query(
Usage.id,
Usage.status,
@@ -1336,7 +1343,9 @@ class UsageService:
Usage.total_cost_usd,
Usage.response_time_ms,
Usage.created_at,
)
Usage.provider_endpoint_id,
ProviderEndpoint.timeout.label("endpoint_timeout"),
).outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
if ids:
query = query.filter(Usage.id.in_(ids))
@@ -1355,6 +1364,9 @@ class UsageService:
timeout_ids = []
for r in records:
if r.status == "pending" and r.created_at:
# 使用端点配置的超时时间,若无则使用默认值
timeout_seconds = r.endpoint_timeout or default_timeout_seconds
# 处理时区:如果 created_at 没有时区信息,假定为 UTC
created_at = r.created_at
if created_at.tzinfo is None: