diff --git a/src/api/admin/usage/routes.py b/src/api/admin/usage/routes.py
index c095554..893c119 100644
--- a/src/api/admin/usage/routes.py
+++ b/src/api/admin/usage/routes.py
@@ -651,42 +651,17 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
         self.ids = ids
 
     async def handle(self, context):  # type: ignore[override]
-        db = context.db
+        from src.services.usage import UsageService
 
+        db = context.db
+        id_list = None
         if self.ids:
-            # 查询指定 ID 的请求状态
             id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
             if not id_list:
                 return {"requests": []}
 
-            records = (
-                db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
-                .filter(Usage.id.in_(id_list))
-                .all()
-            )
-        else:
-            # 查询所有活跃请求(pending 或 streaming)
-            records = (
-                db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
-                .filter(Usage.status.in_(["pending", "streaming"]))
-                .order_by(Usage.created_at.desc())
-                .limit(50)
-                .all()
-            )
-
-        return {
-            "requests": [
-                {
-                    "id": r.id,
-                    "status": r.status,
-                    "input_tokens": r.input_tokens,
-                    "output_tokens": r.output_tokens,
-                    "cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
-                    "response_time_ms": r.response_time_ms,
-                }
-                for r in records
-            ]
-        }
+        requests = UsageService.get_active_requests(db=db, ids=id_list)
+        return {"requests": requests}
 
 
 @dataclass
diff --git a/src/api/user_me/routes.py b/src/api/user_me/routes.py
index a52560b..9062135 100644
--- a/src/api/user_me/routes.py
+++ b/src/api/user_me/routes.py
@@ -662,60 +662,18 @@ class GetActiveRequestsAdapter(AuthenticatedApiAdapter):
     ids: Optional[str] = None
 
     async def handle(self, context):  # type: ignore[override]
+        from src.services.usage import UsageService
+
         db = context.db
         user = context.user
-
+        id_list = None
         if self.ids:
-            # 查询指定 ID 的请求状态(只能查询自己的)
             id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
             if not id_list:
                 return {"requests": []}
 
-            records = (
-                db.query(
-                    Usage.id,
-                    Usage.status,
-                    Usage.input_tokens,
-                    Usage.output_tokens,
-                    Usage.total_cost_usd,
-                    Usage.response_time_ms,
-                )
-                .filter(Usage.id.in_(id_list), Usage.user_id == user.id)
-                .all()
-            )
-        else:
-            # 查询所有活跃请求(pending 或 streaming)
-            records = (
-                db.query(
-                    Usage.id,
-                    Usage.status,
-                    Usage.input_tokens,
-                    Usage.output_tokens,
-                    Usage.total_cost_usd,
-                    Usage.response_time_ms,
-                )
-                .filter(
-                    Usage.user_id == user.id,
-                    Usage.status.in_(["pending", "streaming"]),
-                )
-                .order_by(Usage.created_at.desc())
-                .limit(50)
-                .all()
-            )
-
-        return {
-            "requests": [
-                {
-                    "id": r.id,
-                    "status": r.status,
-                    "input_tokens": r.input_tokens,
-                    "output_tokens": r.output_tokens,
-                    "cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
-                    "response_time_ms": r.response_time_ms,
-                }
-                for r in records
-            ]
-        }
+        requests = UsageService.get_active_requests(db=db, ids=id_list, user_id=user.id)
+        return {"requests": requests}
 
 
 class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
diff --git a/src/services/usage/service.py b/src/services/usage/service.py
index 706eb8d..d8ccb27 100644
--- a/src/services/usage/service.py
+++ b/src/services/usage/service.py
@@ -1304,3 +1304,98 @@ class UsageService:
             )
             .count()
         )
+
+    @classmethod
+    def get_active_requests(
+        cls,
+        db: Session,
+        ids: Optional[List[str]] = None,
+        user_id: Optional[str] = None,
+        timeout_seconds: int = 300,
+    ) -> List[Dict[str, Any]]:
+        """
+        Return the status of active requests and fail stale pending ones.
+
+        When ``ids`` is given, those rows are looked up whatever their status;
+        otherwise the 50 most recent ``pending``/``streaming`` rows are used.
+        Every returned row that is ``pending`` more than ``timeout_seconds``
+        after its ``created_at`` is reported as ``failed`` and, if still
+        ``pending`` at update time, marked ``failed`` in the database.
+
+        Args:
+            db: Database session; committed when stale rows are updated.
+            ids: Request IDs to look up (optional).
+            user_id: Restrict the lookup to this user's requests (optional,
+                used by the non-admin endpoint).
+            timeout_seconds: Maximum age of a ``pending`` request in seconds,
+                5 minutes by default.
+
+        Returns:
+            One dict per row with the keys ``id``, ``status``, ``input_tokens``,
+            ``output_tokens``, ``cost`` and ``response_time_ms``.
+        """
+        now = datetime.now(timezone.utc)
+
+        query = db.query(
+            Usage.id,
+            Usage.status,
+            Usage.input_tokens,
+            Usage.output_tokens,
+            Usage.total_cost_usd,
+            Usage.response_time_ms,
+            Usage.created_at,
+        )
+
+        # Ownership scoping applies to both lookup modes. Compare against None
+        # so that a falsy user_id can never silently drop the filter.
+        if user_id is not None:
+            query = query.filter(Usage.user_id == user_id)
+
+        if ids:
+            query = query.filter(Usage.id.in_(ids))
+        else:
+            # No explicit IDs: list the most recent in-flight requests.
+            query = (
+                query.filter(Usage.status.in_(["pending", "streaming"]))
+                .order_by(Usage.created_at.desc())
+                .limit(50)
+            )
+
+        records = query.all()
+
+        # Collect timed-out IDs in a set so the membership test below is cheap.
+        timeout_ids = set()
+        for r in records:
+            if r.status != "pending" or not r.created_at:
+                continue
+            # A created_at without tzinfo is assumed to be UTC.
+            created_at = r.created_at
+            if created_at.tzinfo is None:
+                created_at = created_at.replace(tzinfo=timezone.utc)
+            if (now - created_at).total_seconds() > timeout_seconds:
+                timeout_ids.add(r.id)
+
+        if timeout_ids:
+            # Re-check the status inside the UPDATE: a request may have left
+            # "pending" since the SELECT above, and such a row must not be
+            # overwritten with "failed".
+            db.query(Usage).filter(
+                Usage.id.in_(list(timeout_ids)),
+                Usage.status == "pending",
+            ).update(
+                {"status": "failed", "error_message": "请求超时(服务器可能已重启)"},
+                synchronize_session=False,
+            )
+            db.commit()
+
+        return [
+            {
+                "id": r.id,
+                "status": "failed" if r.id in timeout_ids else r.status,
+                "input_tokens": r.input_tokens,
+                "output_tokens": r.output_tokens,
+                "cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
+                "response_time_ms": r.response_time_ms,
+            }
+            for r in records
+        ]