Files
Aether/src/api/admin/usage/routes.py

984 lines
38 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""管理员使用情况统计路由。"""
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.database import get_db
from src.models.database import (
ApiKey,
Provider,
ProviderAPIKey,
ProviderEndpoint,
RequestCandidate,
Usage,
User,
)
from src.services.usage.service import UsageService
# Router for the admin usage endpoints; every path declared on it is
# prefixed with /api/admin/usage and tagged "Admin - Usage".
router = APIRouter(prefix="/api/admin/usage", tags=["Admin - Usage"])
# Single pipeline instance shared by the route handlers in this module; each
# handler passes it an adapter via ``pipeline.run(...)``.
pipeline = ApiRequestPipeline()
# ==================== RESTful Routes ====================
@router.get("/aggregation/stats")
async def get_usage_aggregation(
    request: Request,
    group_by: str = Query(..., description="Aggregation dimension: model, user, provider, or api_format"),
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """
    Return usage aggregated along one dimension.

    Supported ``group_by`` values:

    - ``model``: aggregate by model
    - ``user``: aggregate by user
    - ``provider``: aggregate by provider
    - ``api_format``: aggregate by API format

    Raises:
        HTTPException: 400 when ``group_by`` is not one of the values above.
    """
    # Dispatch table: aggregation dimension -> adapter class implementing it.
    adapter_classes = {
        "model": AdminUsageByModelAdapter,
        "user": AdminUsageByUserAdapter,
        "provider": AdminUsageByProviderAdapter,
        "api_format": AdminUsageByApiFormatAdapter,
    }
    adapter_cls = adapter_classes.get(group_by)
    if adapter_cls is None:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid group_by value: {group_by}. Must be one of: model, user, provider, api_format",
        )

    adapter = adapter_cls(start_date=start_date, end_date=end_date, limit=limit)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats")
async def get_usage_stats(
    request: Request,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    db: Session = Depends(get_db),
):
    """
    Return overall usage statistics, optionally bounded by a date range.

    The query parameters are wrapped in an ``AdminUsageStatsAdapter`` which is
    then executed through the shared request pipeline.
    """
    stats_adapter = AdminUsageStatsAdapter(start_date=start_date, end_date=end_date)
    response = await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=stats_adapter.mode,
    )
    return response
@router.get("/records")
async def get_usage_records(
    request: Request,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    user_id: Optional[str] = None,
    username: Optional[str] = None,
    model: Optional[str] = None,
    provider: Optional[str] = None,
    status: Optional[str] = None,  # e.g. stream, standard, error
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
):
    """
    Return a paginated list of usage records.

    All filters are optional and are forwarded unchanged to
    ``AdminUsageRecordsAdapter``; ``limit``/``offset`` control pagination.
    """
    # Collect the filter/pagination arguments once and forward them as-is.
    adapter_kwargs = {
        "start_date": start_date,
        "end_date": end_date,
        "user_id": user_id,
        "username": username,
        "model": model,
        "provider": provider,
        "status": status,
        "limit": limit,
        "offset": offset,
    }
    records_adapter = AdminUsageRecordsAdapter(**adapter_kwargs)
    return await pipeline.run(
        adapter=records_adapter,
        http_request=request,
        db=db,
        mode=records_adapter.mode,
    )
@router.get("/active")
async def get_active_requests(
    request: Request,
    ids: Optional[str] = Query(None, description="逗号分隔的请求 ID 列表,用于查询特定请求的状态"),
    db: Session = Depends(get_db),
):
    """
    Get the status of active requests (lightweight endpoint for frontend polling).

    - If ``ids`` is provided, only the latest status of those requests is returned.
    - If ``ids`` is omitted, all requests in pending/streaming status are returned.
    """
    active_adapter = AdminActiveRequestsAdapter(ids=ids)
    response = await pipeline.run(
        adapter=active_adapter,
        http_request=request,
        db=db,
        mode=active_adapter.mode,
    )
    return response
# NOTE: This route must be defined AFTER all other routes to avoid matching
# routes like /stats, /records, /active, etc.
@router.get("/{usage_id}")
async def get_usage_detail(
    usage_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Return the full detail of a single usage record.

    The response includes the request/response headers and bodies.
    """
    detail_adapter = AdminUsageDetailAdapter(usage_id=usage_id)
    response = await pipeline.run(
        adapter=detail_adapter,
        http_request=request,
        db=db,
        mode=detail_adapter.mode,
    )
    return response
class AdminUsageStatsAdapter(AdminApiAdapter):
    """Compute overall usage totals, cache totals and error rate.

    Args:
        start_date: Inclusive lower bound on ``Usage.created_at``; ``None``
            disables the bound.
        end_date: Inclusive upper bound on ``Usage.created_at``; ``None``
            disables the bound.
    """

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime]):
        self.start_date = start_date
        self.end_date = end_date

    async def handle(self, context):  # type: ignore[override]
        """Run the aggregate queries and return the statistics payload.

        Returns:
            A dict with request/token/cost totals, the average response time
            (milliseconds divided by 1000, rounded to 2 decimals), the error
            count and percentage, cache statistics and the activity heatmap.
        """
        db = context.db
        query = db.query(Usage)
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)

        total_stats = query.with_entities(
            func.count(Usage.id).label("total_requests"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("total_actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
        ).first()

        # Cache statistics
        cache_stats = query.with_entities(
            func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
            func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
            func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
            func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
        ).first()

        # Error statistics: a record counts as an error when it has an HTTP
        # status >= 400 or a non-null error message.
        error_count = query.filter(
            (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
        ).count()

        # NOTE: the heatmap is requested with a fixed 365-day window and is not
        # passed start_date/end_date.
        activity_heatmap = UsageService.get_daily_activity(
            db=db,
            window_days=365,
            include_actual_cost=True,
        )

        context.add_audit_metadata(
            action="usage_stats",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
        )

        # Apply the same "row may be missing" guard to every aggregate field;
        # SUM/AVG also yield NULL when no rows match, hence the ``or 0``.
        total_requests = (total_stats.total_requests or 0) if total_stats else 0
        total_tokens = int(total_stats.total_tokens or 0) if total_stats else 0
        total_cost = float(total_stats.total_cost or 0) if total_stats else 0
        total_actual_cost = float(total_stats.total_actual_cost or 0) if total_stats else 0
        avg_response_time_ms = float(total_stats.avg_response_time_ms or 0) if total_stats else 0
        avg_response_time = avg_response_time_ms / 1000.0

        return {
            "total_requests": total_requests,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "total_actual_cost": total_actual_cost,
            "avg_response_time": round(avg_response_time, 2),
            "error_count": error_count,
            "error_rate": (
                round((error_count / total_requests) * 100, 2) if total_requests > 0 else 0
            ),
            "cache_stats": {
                "cache_creation_tokens": (
                    int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
                ),
                "cache_read_tokens": int(cache_stats.cache_read_tokens or 0) if cache_stats else 0,
                "cache_creation_cost": (
                    float(cache_stats.cache_creation_cost or 0) if cache_stats else 0
                ),
                "cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
            },
            "activity_heatmap": activity_heatmap,
        }
class AdminUsageByModelAdapter(AdminApiAdapter):
    """Aggregate usage per model, most-requested models first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """Return one dict per model with request count, tokens and costs."""
        conditions = [
            # Skip pending/streaming requests: unfinished requests must not be
            # counted in the statistics.
            Usage.status.notin_(["pending", "streaming"]),
            # Skip unknown/pending providers (the request never reached a provider).
            Usage.provider.notin_(["unknown", "pending"]),
        ]
        if self.start_date:
            conditions.append(Usage.created_at >= self.start_date)
        if self.end_date:
            conditions.append(Usage.created_at <= self.end_date)

        rows = (
            context.db.query(
                Usage.model,
                func.count(Usage.id).label("request_count"),
                func.sum(Usage.total_tokens).label("total_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
                func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
            )
            .filter(*conditions)
            .group_by(Usage.model)
            .order_by(func.count(Usage.id).desc())
            .limit(self.limit)
            .all()
        )

        context.add_audit_metadata(
            action="usage_by_model",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(rows),
        )

        payload = []
        for model_name, request_count, token_sum, cost_sum, actual_cost_sum in rows:
            payload.append(
                {
                    "model": model_name,
                    "request_count": request_count,
                    "total_tokens": int(token_sum or 0),
                    "total_cost": float(cost_sum or 0),
                    "actual_cost": float(actual_cost_sum or 0),
                }
            )
        return payload
class AdminUsageByUserAdapter(AdminApiAdapter):
    """Aggregate usage per user, most active users first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """Return one dict per user with request count, tokens and cost."""
        conditions = []
        if self.start_date:
            conditions.append(Usage.created_at >= self.start_date)
        if self.end_date:
            conditions.append(Usage.created_at <= self.end_date)

        # Inner join: only users that have at least one usage row are listed.
        per_user = (
            context.db.query(
                User.id,
                User.email,
                User.username,
                func.count(Usage.id).label("request_count"),
                func.sum(Usage.total_tokens).label("total_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
            )
            .join(Usage, Usage.user_id == User.id)
            .group_by(User.id, User.email, User.username)
        )
        if conditions:
            per_user = per_user.filter(*conditions)
        rows = per_user.order_by(func.count(Usage.id).desc()).limit(self.limit).all()

        context.add_audit_metadata(
            action="usage_by_user",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(rows),
        )

        payload = []
        for uid, mail, name, request_count, token_sum, cost_sum in rows:
            payload.append(
                {
                    "user_id": uid,
                    "email": mail,
                    "username": name,
                    "request_count": request_count,
                    "total_tokens": int(token_sum or 0),
                    "total_cost": float(cost_sum or 0),
                }
            )
        return payload
class AdminUsageByProviderAdapter(AdminApiAdapter):
    """Aggregate attempts, success rate, tokens and cost per provider.

    Attempt counts and latency come from ``RequestCandidate``; token and cost
    totals come from ``Usage``. Only providers present in the attempt
    statistics appear in the result.
    """

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """Return one dict per provider, ordered by attempt count descending."""
        from sqlalchemy import case

        db = context.db

        # Count attempts and successes per provider from request_candidates so
        # that fallback scenarios (one request trying several providers) are
        # counted correctly.
        attempt_query = db.query(
            RequestCandidate.provider_id,
            func.count(RequestCandidate.id).label("attempt_count"),
            func.sum(
                case((RequestCandidate.status == "success", 1), else_=0)
            ).label("success_count"),
            func.sum(
                case((RequestCandidate.status == "failed", 1), else_=0)
            ).label("failed_count"),
            func.avg(RequestCandidate.latency_ms).label("avg_latency_ms"),
        ).filter(
            RequestCandidate.provider_id.isnot(None),
            # Only attempts that were actually executed (excludes
            # available/skipped candidates).
            RequestCandidate.status.in_(["success", "failed"]),
        )
        if self.start_date:
            attempt_query = attempt_query.filter(RequestCandidate.created_at >= self.start_date)
        if self.end_date:
            attempt_query = attempt_query.filter(RequestCandidate.created_at <= self.end_date)
        attempt_stats = (
            attempt_query.group_by(RequestCandidate.provider_id)
            .order_by(func.count(RequestCandidate.id).desc())
            .limit(self.limit)
            .all()
        )

        # Token and cost totals come from the Usage table.
        usage_query = db.query(
            Usage.provider_id,
            func.count(Usage.id).label("request_count"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
        ).filter(
            Usage.provider_id.isnot(None),
            # Skip pending/streaming requests.
            Usage.status.notin_(["pending", "streaming"]),
        )
        if self.start_date:
            usage_query = usage_query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            usage_query = usage_query.filter(Usage.created_at <= self.end_date)
        usage_stats = usage_query.group_by(Usage.provider_id).all()
        usage_map = {str(u.provider_id): u for u in usage_stats}

        # Names are only needed for providers that end up in the result, i.e.
        # those present in attempt_stats.
        provider_ids = {stat.provider_id for stat in attempt_stats if stat.provider_id}
        provider_map = {}
        if provider_ids:
            providers_data = (
                db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
            )
            provider_map = {str(p.id): p.name for p in providers_data}

        context.add_audit_metadata(
            action="usage_by_provider",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(attempt_stats),
        )

        result = []
        for stat in attempt_stats:
            provider_id_str = str(stat.provider_id) if stat.provider_id else None
            attempt_count = stat.attempt_count or 0
            success_count = int(stat.success_count or 0)
            failed_count = int(stat.failed_count or 0)
            success_rate = (success_count / attempt_count * 100) if attempt_count > 0 else 0

            # Token and cost figures for this provider, if any usage exists.
            usage_stat = usage_map.get(provider_id_str)

            result.append({
                "provider_id": provider_id_str,
                "provider": provider_map.get(provider_id_str, "Unknown"),
                "request_count": attempt_count,  # number of attempts
                "total_tokens": int(usage_stat.total_tokens or 0) if usage_stat else 0,
                "total_cost": float(usage_stat.total_cost or 0) if usage_stat else 0,
                "actual_cost": float(usage_stat.actual_cost or 0) if usage_stat else 0,
                "avg_response_time_ms": float(stat.avg_latency_ms or 0),
                "success_rate": round(success_rate, 2),
                "error_count": failed_count,
            })
        return result
class AdminUsageByApiFormatAdapter(AdminApiAdapter):
    """Aggregate usage per API format, most-used formats first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """Return one dict per API format with counts, tokens, costs and latency."""
        conditions = [
            # Skip pending/streaming requests.
            Usage.status.notin_(["pending", "streaming"]),
            # Skip unknown/pending providers.
            Usage.provider.notin_(["unknown", "pending"]),
            # Only records that carry an api_format are counted.
            Usage.api_format.isnot(None),
        ]
        if self.start_date:
            conditions.append(Usage.created_at >= self.start_date)
        if self.end_date:
            conditions.append(Usage.created_at <= self.end_date)

        rows = (
            context.db.query(
                Usage.api_format,
                func.count(Usage.id).label("request_count"),
                func.sum(Usage.total_tokens).label("total_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
                func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
                func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
            )
            .filter(*conditions)
            .group_by(Usage.api_format)
            .order_by(func.count(Usage.id).desc())
            .limit(self.limit)
            .all()
        )

        context.add_audit_metadata(
            action="usage_by_api_format",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(rows),
        )

        payload = []
        for fmt, request_count, token_sum, cost_sum, actual_cost_sum, avg_ms in rows:
            payload.append(
                {
                    "api_format": fmt or "unknown",
                    "request_count": request_count,
                    "total_tokens": int(token_sum or 0),
                    "total_cost": float(cost_sum or 0),
                    "actual_cost": float(actual_cost_sum or 0),
                    "avg_response_time_ms": float(avg_ms or 0),
                }
            )
        return payload
class AdminUsageRecordsAdapter(AdminApiAdapter):
    """Paginated, filterable listing of usage records.

    Args:
        start_date: Inclusive lower bound on ``Usage.created_at``.
        end_date: Inclusive upper bound on ``Usage.created_at``.
        user_id: Exact match on ``Usage.user_id``.
        username: Case-insensitive substring match on ``User.username``.
        model: Case-insensitive substring match on ``Usage.model``.
        provider: Case-insensitive substring match on ``Provider.name``.
        status: One of ``stream``, ``standard``, ``error``, ``pending``,
            ``streaming``, ``completed``, ``failed`` or ``active``; any other
            value applies no status filter.
        limit: Page size.
        offset: Number of records to skip.
    """

    def __init__(
        self,
        start_date: Optional[datetime],
        end_date: Optional[datetime],
        user_id: Optional[str],
        username: Optional[str],
        model: Optional[str],
        provider: Optional[str],
        status: Optional[str],
        limit: int,
        offset: int,
    ):
        self.start_date = start_date
        self.end_date = end_date
        self.user_id = user_id
        self.username = username
        self.model = model
        self.provider = provider
        self.status = status
        self.limit = limit
        self.offset = offset

    def _apply_status_filter(self, query):
        """Return ``query`` narrowed by ``self.status`` (unchanged if unset/unknown)."""
        status = self.status
        if not status:
            return query

        # Legacy values (based on is_stream and status_code): stream, standard, error.
        # Newer values (based on the status column): pending, streaming,
        # completed, failed, active.
        if status == "stream":
            return query.filter(Usage.is_stream == True)  # noqa: E712
        if status == "standard":
            return query.filter(Usage.is_stream == False)  # noqa: E712
        if status == "error":
            return query.filter(
                (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
            )
        if status in ("pending", "streaming", "completed"):
            # Filter directly on the status column.
            return query.filter(Usage.status == status)
        if status == "failed":
            # A failed request is recognised in either of two ways:
            # 1. status column equals "failed"
            # 2. status_code >= 400 or error_message is set
            return query.filter(
                (Usage.status == "failed")
                | (Usage.status_code >= 400)
                | (Usage.error_message.isnot(None))
            )
        if status == "active":
            # Active requests: pending or streaming.
            return query.filter(Usage.status.in_(["pending", "streaming"]))
        return query

    async def handle(self, context):  # type: ignore[override]
        """Query, paginate and serialise the matching usage records.

        Returns:
            ``{"records": [...], "total": int, "limit": int, "offset": int}``
            where ``total`` is the number of matches before pagination.
        """
        db = context.db
        query = (
            db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
            .outerjoin(User, Usage.user_id == User.id)
            .outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
            .outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
        )
        if self.user_id:
            query = query.filter(Usage.user_id == self.user_id)
        if self.username:
            # Fuzzy username search.
            query = query.filter(User.username.ilike(f"%{self.username}%"))
        if self.model:
            # Fuzzy model name search.
            query = query.filter(Usage.model.ilike(f"%{self.model}%"))
        if self.provider:
            # Search by provider name through the Provider table.
            query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
            query = query.filter(Provider.name.ilike(f"%{self.provider}%"))

        query = self._apply_status_filter(query)

        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)

        total = query.count()
        records = (
            query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
        )

        request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
        fallback_map = {}
        if request_ids:
            # Count only candidates that were actually executed (success or
            # failed), not skipped/pending/available ones.
            executed_counts = (
                db.query(RequestCandidate.request_id, func.count(RequestCandidate.id))
                .filter(
                    RequestCandidate.request_id.in_(request_ids),
                    RequestCandidate.status.in_(["success", "failed"]),
                )
                .group_by(RequestCandidate.request_id)
                .all()
            )
            # More than one executed candidate means the request switched providers.
            fallback_map = {req_id: count > 1 for req_id, count in executed_counts}

        context.add_audit_metadata(
            action="usage_records",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            user_id=self.user_id,
            username=self.username,
            model=self.model,
            provider=self.provider,
            status=self.status,
            limit=self.limit,
            offset=self.offset,
            total=total,
        )

        # Build a provider_id -> provider name map in one query to avoid N+1
        # lookups; a set drops ids repeated across records.
        provider_ids = {usage.provider_id for usage, _, _, _ in records if usage.provider_id}
        provider_map = {}
        if provider_ids:
            providers_data = (
                db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
            )
            provider_map = {str(p.id): p.name for p in providers_data}

        data = []
        for usage, user, endpoint, api_key in records:
            # Null-safe conversions: a missing value must not fail the whole page.
            cost = float(usage.total_cost_usd) if usage.total_cost_usd is not None else 0.0
            actual_cost = (
                float(usage.actual_total_cost_usd)
                if usage.actual_total_cost_usd is not None
                else 0.0
            )
            rate_multiplier = (
                float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0
            )

            # Provider name priority: linked Provider row > usage.provider column.
            provider_name = usage.provider
            if usage.provider_id and str(usage.provider_id) in provider_map:
                provider_name = provider_map[str(usage.provider_id)]

            data.append(
                {
                    "id": usage.id,
                    "user_id": user.id if user else None,
                    "user_email": user.email if user else "已删除用户",
                    "username": user.username if user else "已删除用户",
                    "provider": provider_name,
                    "model": usage.model,
                    "target_model": usage.target_model,  # model name after mapping
                    "input_tokens": usage.input_tokens,
                    "output_tokens": usage.output_tokens,
                    "cache_creation_input_tokens": usage.cache_creation_input_tokens,
                    "cache_read_input_tokens": usage.cache_read_input_tokens,
                    "total_tokens": usage.total_tokens,
                    "cost": cost,
                    "actual_cost": actual_cost,
                    "rate_multiplier": rate_multiplier,
                    "response_time_ms": usage.response_time_ms,
                    "created_at": usage.created_at.isoformat() if usage.created_at else None,
                    "is_stream": usage.is_stream,
                    "input_price_per_1m": usage.input_price_per_1m,
                    "output_price_per_1m": usage.output_price_per_1m,
                    "cache_creation_price_per_1m": usage.cache_creation_price_per_1m,
                    "cache_read_price_per_1m": usage.cache_read_price_per_1m,
                    "status_code": usage.status_code,
                    "error_message": usage.error_message,
                    "status": usage.status,  # pending, streaming, completed, failed
                    "has_fallback": fallback_map.get(usage.request_id, False),
                    "api_format": usage.api_format
                    or (endpoint.api_format if endpoint and endpoint.api_format else None),
                    "api_key_name": api_key.name if api_key else None,
                    "request_metadata": usage.request_metadata,  # provider response metadata
                }
            )

        return {
            "records": data,
            "total": total,
            "limit": self.limit,
            "offset": self.offset,
        }
class AdminActiveRequestsAdapter(AdminApiAdapter):
    """Lightweight adapter that looks up the status of active requests."""

    def __init__(self, ids: Optional[str]):
        self.ids = ids

    async def handle(self, context):  # type: ignore[override]
        """Return ``{"requests": [...]}`` for the requested ids.

        ``self.ids`` is a comma-separated string. When it is set but contains
        no non-blank entry, an empty list is returned without calling the
        service. When it is unset, the service is called with ``ids=None``.
        """
        from src.services.usage import UsageService

        db = context.db

        if not self.ids:
            wanted_ids = None
        else:
            stripped = (part.strip() for part in self.ids.split(","))
            wanted_ids = [part for part in stripped if part]
            if not wanted_ids:
                return {"requests": []}

        statuses = UsageService.get_active_requests_status(db=db, ids=wanted_ids)
        return {"requests": statuses}
@dataclass
class AdminUsageDetailAdapter(AdminApiAdapter):
    """Return one usage record in full, including request/response bodies."""

    # Primary key of the ``Usage`` row to load.
    usage_id: str

    async def handle(self, context):  # type: ignore[override]
        """Load the record and build the detail payload.

        Raises:
            HTTPException: 404 when no ``Usage`` row has ``self.usage_id``.
        """
        db = context.db
        usage_record = db.query(Usage).filter(Usage.id == self.usage_id).first()
        if not usage_record:
            raise HTTPException(status_code=404, detail="Usage record not found")

        # The user or API key row may be missing; the payload falls back to
        # placeholder values below.
        user = db.query(User).filter(User.id == usage_record.user_id).first()
        api_key = db.query(ApiKey).filter(ApiKey.id == usage_record.api_key_id).first()

        # Tiered pricing information.
        tiered_pricing_info = await self._get_tiered_pricing_info(db, usage_record)

        context.add_audit_metadata(
            action="usage_detail",
            usage_id=self.usage_id,
        )

        return {
            "id": usage_record.id,
            "request_id": usage_record.request_id,
            "user": {
                "id": user.id if user else None,
                "username": user.username if user else "Unknown",
                "email": user.email if user else None,
            },
            "api_key": {
                "id": api_key.id if api_key else None,
                "name": api_key.name if api_key else None,
                "display": api_key.get_display_key() if api_key else None,
            },
            "provider": usage_record.provider,
            "api_format": usage_record.api_format,
            "model": usage_record.model,
            "target_model": usage_record.target_model,
            "tokens": {
                "input": usage_record.input_tokens,
                "output": usage_record.output_tokens,
                "total": usage_record.total_tokens,
            },
            "cost": {
                "input": usage_record.input_cost_usd,
                "output": usage_record.output_cost_usd,
                "total": usage_record.total_cost_usd,
            },
            "cache_creation_input_tokens": usage_record.cache_creation_input_tokens,
            "cache_read_input_tokens": usage_record.cache_read_input_tokens,
            "cache_creation_cost": getattr(usage_record, "cache_creation_cost_usd", 0.0),
            "cache_read_cost": getattr(usage_record, "cache_read_cost_usd", 0.0),
            "request_cost": getattr(usage_record, "request_cost_usd", 0.0),
            "input_price_per_1m": usage_record.input_price_per_1m,
            "output_price_per_1m": usage_record.output_price_per_1m,
            "cache_creation_price_per_1m": usage_record.cache_creation_price_per_1m,
            "cache_read_price_per_1m": usage_record.cache_read_price_per_1m,
            "price_per_request": usage_record.price_per_request,
            "request_type": usage_record.request_type,
            "is_stream": usage_record.is_stream,
            "status_code": usage_record.status_code,
            "error_message": usage_record.error_message,
            "response_time_ms": usage_record.response_time_ms,
            "created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
            "request_headers": usage_record.request_headers,
            "request_body": usage_record.get_request_body(),
            "provider_request_headers": usage_record.provider_request_headers,
            "response_headers": usage_record.response_headers,
            "response_body": usage_record.get_response_body(),
            "metadata": usage_record.request_metadata,
            "tiered_pricing": tiered_pricing_info,
        }

    async def _get_tiered_pricing_info(self, db, usage_record) -> Optional[dict]:
        """Return the tiered-pricing details for ``usage_record``.

        Returns:
            ``None`` when no pricing result or no tiers are available;
            otherwise a dict with the total input context, the index and
            content of the matched tier, the tier count, all tiers, and the
            pricing source.
        """
        from src.services.model.cost import ModelCostService

        # Total input context used for tier selection:
        # input + cache creation + cache read tokens.
        input_tokens = usage_record.input_tokens or 0
        cache_creation_tokens = usage_record.cache_creation_input_tokens or 0
        cache_read_tokens = usage_record.cache_read_input_tokens or 0
        total_input_context = input_tokens + cache_creation_tokens + cache_read_tokens

        # Fetch the model's tier configuration together with its source.
        cost_service = ModelCostService(db)
        pricing_result = await cost_service.get_tiered_pricing_with_source_async(
            usage_record.provider, usage_record.model
        )
        if not pricing_result:
            return None

        tiered_pricing = pricing_result.get("pricing")
        pricing_source = pricing_result.get("source")  # 'provider' or 'global'

        tiers = tiered_pricing.get("tiers") if tiered_pricing else None
        if not tiers:
            return None

        # First tier whose upper bound covers the context wins; a tier without
        # "up_to" is unbounded. If none matches, fall back to the last tier.
        tier_index = len(tiers) - 1
        for i, tier in enumerate(tiers):
            up_to = tier.get("up_to")
            if up_to is None or total_input_context <= up_to:
                tier_index = i
                break
        matched_tier = tiers[tier_index]

        return {
            "total_input_context": total_input_context,
            "tier_index": tier_index,
            "tier_count": len(tiers),
            "current_tier": matched_tier,
            "tiers": tiers,
            "source": pricing_source,  # pricing source: 'provider' or 'global'
        }
# ==================== Cache Affinity Analysis ====================
@router.get("/cache-affinity/ttl-analysis")
async def analyze_cache_affinity_ttl(
    request: Request,
    user_id: Optional[str] = Query(None, description="指定用户 ID"),
    api_key_id: Optional[str] = Query(None, description="指定 API Key ID"),
    hours: int = Query(168, ge=1, le=720, description="分析最近多少小时的数据"),
    db: Session = Depends(get_db),
):
    """
    Analyse the distribution of request intervals and recommend a cache-affinity TTL.

    The usage pattern is judged from the time gaps between consecutive
    requests of the same user:

    - High-frequency users (short intervals): a 5-minute TTL is enough
    - Medium-frequency users: a 15-30 minute TTL
    - Low-frequency users (long intervals): a 60-minute TTL is needed
    """
    ttl_adapter = CacheAffinityTTLAnalysisAdapter(
        user_id=user_id,
        api_key_id=api_key_id,
        hours=hours,
    )
    response = await pipeline.run(
        adapter=ttl_adapter,
        http_request=request,
        db=db,
        mode=ttl_adapter.mode,
    )
    return response
@router.get("/cache-affinity/hit-analysis")
async def analyze_cache_hit(
    request: Request,
    user_id: Optional[str] = Query(None, description="指定用户 ID"),
    api_key_id: Optional[str] = Query(None, description="指定 API Key ID"),
    hours: int = Query(168, ge=1, le=720, description="分析最近多少小时的数据"),
    db: Session = Depends(get_db),
):
    """
    Analyse cache hits.

    Returns statistics such as the cache hit rate and the cost saved.
    """
    hit_adapter = CacheHitAnalysisAdapter(
        user_id=user_id,
        api_key_id=api_key_id,
        hours=hours,
    )
    response = await pipeline.run(
        adapter=hit_adapter,
        http_request=request,
        db=db,
        mode=hit_adapter.mode,
    )
    return response
class CacheAffinityTTLAnalysisAdapter(AdminApiAdapter):
    """Adapter for the cache-affinity TTL analysis."""

    def __init__(
        self,
        user_id: Optional[str],
        api_key_id: Optional[str],
        hours: int,
    ):
        self.user_id = user_id
        self.api_key_id = api_key_id
        self.hours = hours

    async def handle(self, context):  # type: ignore[override]
        """Delegate to ``UsageService.analyze_cache_affinity_ttl`` and audit the call."""
        # The same three parameters go to both the service and the audit log.
        params = {
            "user_id": self.user_id,
            "api_key_id": self.api_key_id,
            "hours": self.hours,
        }
        analysis = UsageService.analyze_cache_affinity_ttl(db=context.db, **params)
        context.add_audit_metadata(
            action="cache_affinity_ttl_analysis",
            **params,
            total_users_analyzed=analysis.get("total_users_analyzed", 0),
        )
        return analysis
class CacheHitAnalysisAdapter(AdminApiAdapter):
    """Adapter for the cache hit analysis."""

    def __init__(
        self,
        user_id: Optional[str],
        api_key_id: Optional[str],
        hours: int,
    ):
        self.user_id = user_id
        self.api_key_id = api_key_id
        self.hours = hours

    async def handle(self, context):  # type: ignore[override]
        """Delegate to ``UsageService.get_cache_hit_analysis`` and audit the call."""
        # The same three parameters go to both the service and the audit log.
        params = {
            "user_id": self.user_id,
            "api_key_id": self.api_key_id,
            "hours": self.hours,
        }
        analysis = UsageService.get_cache_hit_analysis(db=context.db, **params)
        context.add_audit_metadata(action="cache_hit_analysis", **params)
        return analysis
@router.get("/cache-affinity/interval-timeline")
async def get_interval_timeline(
    request: Request,
    hours: int = Query(168, ge=1, le=720, description="分析最近多少小时的数据"),
    limit: int = Query(1000, ge=100, le=5000, description="最大返回数据点数量"),
    user_id: Optional[str] = Query(None, description="指定用户 ID"),
    include_user_info: bool = Query(False, description="是否包含用户信息(用于管理员多用户视图)"),
    db: Session = Depends(get_db),
):
    """
    Get request-interval timeline data for a scatter plot.

    Each returned point carries the request time and the gap (in minutes) to
    the previous request, which can be used to visualise request patterns.

    When ``include_user_info`` is true and no ``user_id`` is given, the
    response additionally contains:

    - a ``user_id`` field on every entry of ``points``
    - a ``users`` field mapping user_id -> username
    """
    timeline_adapter = IntervalTimelineAdapter(
        hours=hours,
        limit=limit,
        user_id=user_id,
        include_user_info=include_user_info,
    )
    response = await pipeline.run(
        adapter=timeline_adapter,
        http_request=request,
        db=db,
        mode=timeline_adapter.mode,
    )
    return response
class IntervalTimelineAdapter(AdminApiAdapter):
    """Adapter for the request-interval timeline."""

    def __init__(
        self,
        hours: int,
        limit: int,
        user_id: Optional[str] = None,
        include_user_info: bool = False,
    ):
        self.hours = hours
        self.limit = limit
        self.user_id = user_id
        self.include_user_info = include_user_info

    async def handle(self, context):  # type: ignore[override]
        """Delegate to ``UsageService.get_interval_timeline`` and audit the call."""
        # The same four parameters go to both the service and the audit log.
        params = {
            "hours": self.hours,
            "limit": self.limit,
            "user_id": self.user_id,
            "include_user_info": self.include_user_info,
        }
        timeline = UsageService.get_interval_timeline(db=context.db, **params)
        context.add_audit_metadata(
            action="interval_timeline",
            **params,
            total_points=timeline.get("total_points", 0),
        )
        return timeline