refactor: 重构活跃请求查询逻辑到 UsageService

- 在 UsageService 新增 get_active_requests 方法,统一处理活跃请求查询
- 支持自动清理超时的 pending 请求(默认 5 分钟)
- admin 和 user 接口均复用该方法,减少重复代码
- 支持按 ID 列表查询或查询所有活跃请求
This commit is contained in:
fawney19
2025-12-11 10:04:15 +08:00
parent 6016f08d1c
commit 913a87d7f3
3 changed files with 88 additions and 77 deletions

View File

@@ -651,42 +651,17 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
self.ids = ids
async def handle(self, context): # type: ignore[override]
# NOTE(review): this span looks like a rendered diff hunk without +/- markers,
# so the previous inline query and the new UsageService call appear interleaved
# (e.g. `db = context.db` shows up twice) — confirm against the actual file.
db = context.db
from src.services.usage import UsageService
db = context.db
id_list = None
if self.ids:
# Look up the status of the specified request IDs
# (`self.ids` is split on commas; blank entries are dropped).
id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
if not id_list:
return {"requests": []}
records = (
db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
.filter(Usage.id.in_(id_list))
.all()
)
else:
# Query all active requests (pending or streaming), newest first, at most 50
records = (
db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
.filter(Usage.status.in_(["pending", "streaming"]))
.order_by(Usage.created_at.desc())
.limit(50)
.all()
)
return {
"requests": [
{
"id": r.id,
"status": r.status,
"input_tokens": r.input_tokens,
"output_tokens": r.output_tokens,
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
"response_time_ms": r.response_time_ms,
}
for r in records
]
}
# Delegate the lookup to UsageService.get_active_requests (no user restriction here).
requests = UsageService.get_active_requests(db=db, ids=id_list)
return {"requests": requests}
@dataclass

View File

@@ -662,60 +662,18 @@ class GetActiveRequestsAdapter(AuthenticatedApiAdapter):
ids: Optional[str] = None
async def handle(self, context): # type: ignore[override]
# NOTE(review): this span looks like a rendered diff hunk without +/- markers,
# so the previous inline query and the new UsageService call appear interleaved
# — confirm against the actual file.
from src.services.usage import UsageService
db = context.db
user = context.user
id_list = None
if self.ids:
# Look up the status of the specified request IDs (only the caller's own requests)
# (`self.ids` is split on commas; blank entries are dropped).
id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
if not id_list:
return {"requests": []}
records = (
db.query(
Usage.id,
Usage.status,
Usage.input_tokens,
Usage.output_tokens,
Usage.total_cost_usd,
Usage.response_time_ms,
)
.filter(Usage.id.in_(id_list), Usage.user_id == user.id)
.all()
)
else:
# Query all of the caller's active requests (pending or streaming), newest first, at most 50
records = (
db.query(
Usage.id,
Usage.status,
Usage.input_tokens,
Usage.output_tokens,
Usage.total_cost_usd,
Usage.response_time_ms,
)
.filter(
Usage.user_id == user.id,
Usage.status.in_(["pending", "streaming"]),
)
.order_by(Usage.created_at.desc())
.limit(50)
.all()
)
return {
"requests": [
{
"id": r.id,
"status": r.status,
"input_tokens": r.input_tokens,
"output_tokens": r.output_tokens,
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
"response_time_ms": r.response_time_ms,
}
for r in records
]
}
# Delegate the lookup to UsageService.get_active_requests, restricted to the current user.
requests = UsageService.get_active_requests(db=db, ids=id_list, user_id=user.id)
return {"requests": requests}
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):

View File

@@ -1304,3 +1304,81 @@ class UsageService:
)
.count()
)
@classmethod
def get_active_requests(
    cls,
    db: Session,
    ids: Optional[List[str]] = None,
    user_id: Optional[str] = None,
    timeout_seconds: int = 300,
) -> List[Dict[str, Any]]:
    """
    Return the status of active requests and expire stale pending ones.

    When ``ids`` is given (and non-empty), exactly those requests are looked
    up regardless of status. Otherwise the 50 most recently created requests
    whose status is ``pending`` or ``streaming`` are returned.

    Any returned request that is still ``pending`` and was created more than
    ``timeout_seconds`` ago is marked ``failed`` in the database (and the
    change is committed) before the result is built.

    Args:
        db: Database session used for the query and the cleanup update.
        ids: Optional list of request IDs to look up.
        user_id: Optional user ID; when set, only that user's requests are
            considered (used by the non-admin endpoint).
        timeout_seconds: Age in seconds after which a pending request is
            considered timed out. Defaults to 300 (5 minutes).

    Returns:
        A list of dicts with the keys ``id``, ``status``, ``input_tokens``,
        ``output_tokens``, ``cost`` and ``response_time_ms``.
    """
    now = datetime.now(timezone.utc)

    query = db.query(
        Usage.id,
        Usage.status,
        Usage.input_tokens,
        Usage.output_tokens,
        Usage.total_cost_usd,
        Usage.response_time_ms,
        Usage.created_at,
    )

    # The ownership restriction is identical for both query modes.
    if user_id:
        query = query.filter(Usage.user_id == user_id)

    if ids:
        query = query.filter(Usage.id.in_(ids))
    else:
        # No explicit IDs: list the most recent active requests.
        query = query.filter(Usage.status.in_(["pending", "streaming"]))
        query = query.order_by(Usage.created_at.desc()).limit(50)

    records = query.all()

    # Collect pending requests that have exceeded the timeout.
    timeout_ids = []
    for r in records:
        if r.status != "pending" or not r.created_at:
            continue
        created_at = r.created_at
        # A naive timestamp is assumed to be UTC so it can be compared to `now`.
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=timezone.utc)
        if (now - created_at).total_seconds() > timeout_seconds:
            timeout_ids.append(r.id)

    if timeout_ids:
        # Re-check the status inside the UPDATE itself: a request may have
        # progressed (e.g. to "streaming" or a terminal state) between the
        # SELECT above and this statement, and must not be overwritten.
        db.query(Usage).filter(
            Usage.id.in_(timeout_ids),
            Usage.status == "pending",
        ).update(
            {"status": "failed", "error_message": "请求超时(服务器可能已重启)"},
            synchronize_session=False,
        )
        db.commit()

    # Set lookup keeps the per-record membership test O(1).
    timed_out = frozenset(timeout_ids)
    return [
        {
            "id": r.id,
            "status": "failed" if r.id in timed_out else r.status,
            "input_tokens": r.input_tokens,
            "output_tokens": r.output_tokens,
            "cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
            "response_time_ms": r.response_time_ms,
        }
        for r in records
    ]