mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor: 重构活跃请求查询逻辑到 UsageService
- 在 UsageService 新增 get_active_requests 方法,统一处理活跃请求查询 - 支持自动清理超时的 pending 请求(默认 5 分钟) - admin 和 user 接口均复用该方法,减少重复代码 - 支持按 ID 列表查询或查询所有活跃请求
This commit is contained in:
@@ -651,42 +651,17 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
|
|||||||
self.ids = ids
|
self.ids = ids
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
db = context.db
|
from src.services.usage import UsageService
|
||||||
|
|
||||||
|
db = context.db
|
||||||
|
id_list = None
|
||||||
if self.ids:
|
if self.ids:
|
||||||
# 查询指定 ID 的请求状态
|
|
||||||
id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
|
id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
|
||||||
if not id_list:
|
if not id_list:
|
||||||
return {"requests": []}
|
return {"requests": []}
|
||||||
|
|
||||||
records = (
|
requests = UsageService.get_active_requests(db=db, ids=id_list)
|
||||||
db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
|
return {"requests": requests}
|
||||||
.filter(Usage.id.in_(id_list))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 查询所有活跃请求(pending 或 streaming)
|
|
||||||
records = (
|
|
||||||
db.query(Usage.id, Usage.status, Usage.input_tokens, Usage.output_tokens, Usage.total_cost_usd, Usage.response_time_ms)
|
|
||||||
.filter(Usage.status.in_(["pending", "streaming"]))
|
|
||||||
.order_by(Usage.created_at.desc())
|
|
||||||
.limit(50)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"requests": [
|
|
||||||
{
|
|
||||||
"id": r.id,
|
|
||||||
"status": r.status,
|
|
||||||
"input_tokens": r.input_tokens,
|
|
||||||
"output_tokens": r.output_tokens,
|
|
||||||
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
|
|
||||||
"response_time_ms": r.response_time_ms,
|
|
||||||
}
|
|
||||||
for r in records
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -662,60 +662,18 @@ class GetActiveRequestsAdapter(AuthenticatedApiAdapter):
|
|||||||
ids: Optional[str] = None
|
ids: Optional[str] = None
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
from src.services.usage import UsageService
|
||||||
|
|
||||||
db = context.db
|
db = context.db
|
||||||
user = context.user
|
user = context.user
|
||||||
|
id_list = None
|
||||||
if self.ids:
|
if self.ids:
|
||||||
# 查询指定 ID 的请求状态(只能查询自己的)
|
|
||||||
id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
|
id_list = [id.strip() for id in self.ids.split(",") if id.strip()]
|
||||||
if not id_list:
|
if not id_list:
|
||||||
return {"requests": []}
|
return {"requests": []}
|
||||||
|
|
||||||
records = (
|
requests = UsageService.get_active_requests(db=db, ids=id_list, user_id=user.id)
|
||||||
db.query(
|
return {"requests": requests}
|
||||||
Usage.id,
|
|
||||||
Usage.status,
|
|
||||||
Usage.input_tokens,
|
|
||||||
Usage.output_tokens,
|
|
||||||
Usage.total_cost_usd,
|
|
||||||
Usage.response_time_ms,
|
|
||||||
)
|
|
||||||
.filter(Usage.id.in_(id_list), Usage.user_id == user.id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 查询所有活跃请求(pending 或 streaming)
|
|
||||||
records = (
|
|
||||||
db.query(
|
|
||||||
Usage.id,
|
|
||||||
Usage.status,
|
|
||||||
Usage.input_tokens,
|
|
||||||
Usage.output_tokens,
|
|
||||||
Usage.total_cost_usd,
|
|
||||||
Usage.response_time_ms,
|
|
||||||
)
|
|
||||||
.filter(
|
|
||||||
Usage.user_id == user.id,
|
|
||||||
Usage.status.in_(["pending", "streaming"]),
|
|
||||||
)
|
|
||||||
.order_by(Usage.created_at.desc())
|
|
||||||
.limit(50)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"requests": [
|
|
||||||
{
|
|
||||||
"id": r.id,
|
|
||||||
"status": r.status,
|
|
||||||
"input_tokens": r.input_tokens,
|
|
||||||
"output_tokens": r.output_tokens,
|
|
||||||
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
|
|
||||||
"response_time_ms": r.response_time_ms,
|
|
||||||
}
|
|
||||||
for r in records
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||||
|
|||||||
@@ -1304,3 +1304,81 @@ class UsageService:
|
|||||||
)
|
)
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_active_requests(
|
||||||
|
cls,
|
||||||
|
db: Session,
|
||||||
|
ids: Optional[List[str]] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
timeout_seconds: int = 300,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取活跃请求状态,并自动清理超时的 pending 请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
ids: 指定要查询的请求 ID 列表(可选)
|
||||||
|
user_id: 限制只查询该用户的请求(可选,用于普通用户接口)
|
||||||
|
timeout_seconds: pending 状态超时时间(秒),默认 5 分钟
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
请求状态列表
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# 构建基础查询
|
||||||
|
query = db.query(
|
||||||
|
Usage.id,
|
||||||
|
Usage.status,
|
||||||
|
Usage.input_tokens,
|
||||||
|
Usage.output_tokens,
|
||||||
|
Usage.total_cost_usd,
|
||||||
|
Usage.response_time_ms,
|
||||||
|
Usage.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ids:
|
||||||
|
query = query.filter(Usage.id.in_(ids))
|
||||||
|
if user_id:
|
||||||
|
query = query.filter(Usage.user_id == user_id)
|
||||||
|
else:
|
||||||
|
# 查询所有活跃请求
|
||||||
|
query = query.filter(Usage.status.in_(["pending", "streaming"]))
|
||||||
|
if user_id:
|
||||||
|
query = query.filter(Usage.user_id == user_id)
|
||||||
|
query = query.order_by(Usage.created_at.desc()).limit(50)
|
||||||
|
|
||||||
|
records = query.all()
|
||||||
|
|
||||||
|
# 检查超时的 pending 请求
|
||||||
|
timeout_ids = []
|
||||||
|
for r in records:
|
||||||
|
if r.status == "pending" and r.created_at:
|
||||||
|
# 处理时区:如果 created_at 没有时区信息,假定为 UTC
|
||||||
|
created_at = r.created_at
|
||||||
|
if created_at.tzinfo is None:
|
||||||
|
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||||
|
elapsed = (now - created_at).total_seconds()
|
||||||
|
if elapsed > timeout_seconds:
|
||||||
|
timeout_ids.append(r.id)
|
||||||
|
|
||||||
|
# 批量更新超时的请求
|
||||||
|
if timeout_ids:
|
||||||
|
db.query(Usage).filter(Usage.id.in_(timeout_ids)).update(
|
||||||
|
{"status": "failed", "error_message": "请求超时(服务器可能已重启)"},
|
||||||
|
synchronize_session=False,
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": r.id,
|
||||||
|
"status": "failed" if r.id in timeout_ids else r.status,
|
||||||
|
"input_tokens": r.input_tokens,
|
||||||
|
"output_tokens": r.output_tokens,
|
||||||
|
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
|
||||||
|
"response_time_ms": r.response_time_ms,
|
||||||
|
}
|
||||||
|
for r in records
|
||||||
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user