mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor: 重构活跃请求查询逻辑到 UsageService
- 在 UsageService 新增 get_active_requests 方法,统一处理活跃请求查询
- 支持自动清理超时的 pending 请求(默认 5 分钟)
- admin 和 user 接口均复用该方法,减少重复代码
- 支持按 ID 列表查询或查询所有活跃请求
This commit is contained in:
@@ -651,42 +651,17 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
|
||||
self.ids = ids
|
||||
|
||||
async def handle(self, context):  # type: ignore[override]
    """Return the status of active requests for the admin interface.

    ``self.ids`` is an optional comma-separated string of request IDs.
    When it is set, only those requests are looked up; otherwise the
    service returns the currently active requests. No ``user_id`` is
    passed to the service, so the result is not restricted to one user.

    Returns:
        ``{"requests": [...]}`` with the list produced by
        ``UsageService.get_active_requests``.
    """
    # Function-scope import, kept as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from src.services.usage import UsageService

    db = context.db

    id_list = None
    if self.ids:
        # Split the comma-separated IDs, dropping blanks such as "a,,b" or " ".
        stripped = (part.strip() for part in self.ids.split(","))
        id_list = [request_id for request_id in stripped if request_id]
        if not id_list:
            # IDs were supplied but none were usable: answer with an empty
            # list instead of falling back to "all active requests".
            return {"requests": []}

    requests = UsageService.get_active_requests(db=db, ids=id_list)
    return {"requests": requests}
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -662,60 +662,18 @@ class GetActiveRequestsAdapter(AuthenticatedApiAdapter):
|
||||
ids: Optional[str] = None
|
||||
|
||||
async def handle(self, context):  # type: ignore[override]
    """Return the status of the current user's active requests.

    ``self.ids`` is an optional comma-separated string of request IDs.
    When it is set, only those requests are looked up; otherwise the
    service returns the currently active requests. ``user.id`` is always
    passed to the service, so a caller can only see their own requests.

    Returns:
        ``{"requests": [...]}`` with the list produced by
        ``UsageService.get_active_requests``.
    """
    # Function-scope import, kept as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from src.services.usage import UsageService

    db = context.db
    user = context.user

    id_list = None
    if self.ids:
        # Split the comma-separated IDs, dropping blanks such as "a,,b" or " ".
        stripped = (part.strip() for part in self.ids.split(","))
        id_list = [request_id for request_id in stripped if request_id]
        if not id_list:
            # IDs were supplied but none were usable: answer with an empty
            # list instead of falling back to "all active requests".
            return {"requests": []}

    requests = UsageService.get_active_requests(db=db, ids=id_list, user_id=user.id)
    return {"requests": requests}
|
||||
|
||||
|
||||
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
@@ -1304,3 +1304,81 @@ class UsageService:
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
@classmethod
def get_active_requests(
    cls,
    db: Session,
    ids: Optional[List[str]] = None,
    user_id: Optional[str] = None,
    timeout_seconds: int = 300,
) -> List[Dict[str, Any]]:
    """Return request statuses and mark timed-out pending requests as failed.

    Args:
        db: Database session.
        ids: Request IDs to look up. When empty or ``None``, the most
            recent (up to 50) requests whose status is ``pending`` or
            ``streaming`` are returned instead.
        user_id: When given, restrict the query to this user's requests
            (used by the non-admin endpoint).
        timeout_seconds: Age in seconds after which a ``pending`` request
            is considered timed out. Defaults to 300 (5 minutes).

    Returns:
        One dict per matching request with the keys ``id``, ``status``,
        ``input_tokens``, ``output_tokens``, ``cost`` and
        ``response_time_ms``. Requests timed out by this call are
        reported with status ``"failed"``.

    Side effects:
        Timed-out pending rows are updated to ``failed`` and the session
        is committed.
    """
    now = datetime.now(timezone.utc)

    query = db.query(
        Usage.id,
        Usage.status,
        Usage.input_tokens,
        Usage.output_tokens,
        Usage.total_cost_usd,
        Usage.response_time_ms,
        Usage.created_at,
    )

    if ids:
        query = query.filter(Usage.id.in_(ids))
    else:
        # No explicit IDs: report everything that is still in flight.
        query = query.filter(Usage.status.in_(["pending", "streaming"]))

    if user_id:
        query = query.filter(Usage.user_id == user_id)

    if not ids:
        # Ordering and the row cap go last, after every filter is in place.
        query = query.order_by(Usage.created_at.desc()).limit(50)

    records = query.all()

    # Collect pending requests that are older than the timeout.
    timeout_ids: List[str] = []
    for record in records:
        if record.status != "pending" or not record.created_at:
            continue
        created_at = record.created_at
        if created_at.tzinfo is None:
            # Naive timestamps are assumed to be UTC.
            created_at = created_at.replace(tzinfo=timezone.utc)
        if (now - created_at).total_seconds() > timeout_seconds:
            timeout_ids.append(record.id)

    if timeout_ids:
        # Re-check the status in the UPDATE itself: a request may have left
        # "pending" since the SELECT above, and a row that is now streaming
        # or finished must not be overwritten with "failed".
        db.query(Usage).filter(
            Usage.id.in_(timeout_ids),
            Usage.status == "pending",
        ).update(
            {"status": "failed", "error_message": "请求超时(服务器可能已重启)"},
            synchronize_session=False,
        )
        db.commit()

    # Set lookup instead of scanning the list once per record.
    timed_out = set(timeout_ids)

    return [
        {
            "id": record.id,
            "status": "failed" if record.id in timed_out else record.status,
            "input_tokens": record.input_tokens,
            "output_tokens": record.output_tokens,
            "cost": float(record.total_cost_usd) if record.total_cost_usd else 0,
            "response_time_ms": record.response_time_ms,
        }
        for record in records
    ]
|
||||
|
||||
Reference in New Issue
Block a user