Files
Aether/src/api/admin/usage/routes.py

794 lines
32 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""管理员使用情况统计路由。"""
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.database import get_db
from src.models.database import (
ApiKey,
Provider,
ProviderAPIKey,
ProviderEndpoint,
RequestCandidate,
Usage,
User,
)
from src.services.usage.service import UsageService
# Router for every admin usage endpoint in this module, mounted under /api/admin/usage.
router = APIRouter(
    prefix="/api/admin/usage",
    tags=["Admin - Usage"],
)
# Shared pipeline; each route below hands its adapter to ``pipeline.run``.
pipeline = ApiRequestPipeline()

# ==================== RESTful Routes ====================
@router.get("/aggregation/stats")
async def get_usage_aggregation(
    request: Request,
    group_by: str = Query(..., description="Aggregation dimension: model, user, provider, or api_format"),
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """
    Return usage totals aggregated along a single dimension.

    ``group_by`` selects the dimension:

    - ``model``: one row per model
    - ``user``: one row per user
    - ``provider``: one row per provider
    - ``api_format``: one row per API format

    Raises:
        HTTPException: 400 when ``group_by`` is not one of the values above.
    """
    # Dispatch table: aggregation dimension -> adapter class handling it.
    adapter_classes = {
        "model": AdminUsageByModelAdapter,
        "user": AdminUsageByUserAdapter,
        "provider": AdminUsageByProviderAdapter,
        "api_format": AdminUsageByApiFormatAdapter,
    }
    adapter_cls = adapter_classes.get(group_by)
    if adapter_cls is None:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid group_by value: {group_by}. Must be one of: model, user, provider, api_format"
        )

    adapter = adapter_cls(start_date=start_date, end_date=end_date, limit=limit)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats")
async def get_usage_stats(
    request: Request,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    db: Session = Depends(get_db),
):
    """Return overall usage statistics, optionally bounded by ``start_date``/``end_date``."""
    stats_adapter = AdminUsageStatsAdapter(start_date=start_date, end_date=end_date)
    return await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=stats_adapter.mode,
    )
@router.get("/records")
async def get_usage_records(
    request: Request,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    user_id: Optional[str] = None,
    username: Optional[str] = None,
    model: Optional[str] = None,
    provider: Optional[str] = None,
    status: Optional[str] = None,  # stream, standard, error, pending, streaming, completed, failed, active
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
):
    """
    List individual usage records with optional filtering and offset pagination.

    The filters are forwarded unchanged to ``AdminUsageRecordsAdapter``, which
    interprets ``status`` and performs the substring matches.
    """
    records_adapter = AdminUsageRecordsAdapter(
        start_date=start_date,
        end_date=end_date,
        user_id=user_id,
        username=username,
        model=model,
        provider=provider,
        status=status,
        limit=limit,
        offset=offset,
    )
    return await pipeline.run(
        adapter=records_adapter,
        http_request=request,
        db=db,
        mode=records_adapter.mode,
    )
@router.get("/active")
async def get_active_requests(
    request: Request,
    ids: Optional[str] = Query(None, description="逗号分隔的请求 ID 列表,用于查询特定请求的状态"),
    db: Session = Depends(get_db),
):
    """
    Lightweight status endpoint for the frontend to poll in-flight requests.

    - With ``ids`` (comma-separated request IDs): return the latest status of
      just those requests.
    - Without ``ids``: return every request still in the pending/streaming state.
    """
    active_adapter = AdminActiveRequestsAdapter(ids=ids)
    return await pipeline.run(
        adapter=active_adapter,
        http_request=request,
        db=db,
        mode=active_adapter.mode,
    )
# NOTE: This catch-all path parameter route must stay AFTER every other route
# in this module, otherwise it would swallow /stats, /records, /active, etc.
@router.get("/{usage_id}")
async def get_usage_detail(
    usage_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Return the full detail of a single usage record.

    The payload includes the stored request/response headers and bodies.
    """
    detail_adapter = AdminUsageDetailAdapter(usage_id=usage_id)
    return await pipeline.run(
        adapter=detail_adapter,
        http_request=request,
        db=db,
        mode=detail_adapter.mode,
    )
class AdminUsageStatsAdapter(AdminApiAdapter):
    """Overall usage statistics: totals, error rate, cache usage and activity heatmap."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime]):
        # Optional inclusive bounds applied to Usage.created_at.
        self.start_date = start_date
        self.end_date = end_date

    async def handle(self, context):  # type: ignore[override]
        """
        Compute aggregate statistics over Usage rows in the optional date window.

        Totals, the error count and cache figures come from ONE aggregate query
        (previously three round trips over the same rows). The activity heatmap
        always covers the last 365 days, independent of ``start_date``/``end_date``.
        """
        from sqlalchemy import case

        db = context.db
        query = db.query(Usage)
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)

        # A row is an error if it has an HTTP error status or a stored error message.
        # CASE treats a NULL condition like false, matching the old WHERE-based count.
        is_error = (Usage.status_code >= 400) | (Usage.error_message.isnot(None))

        stats = query.with_entities(
            func.count(Usage.id).label("total_requests"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("total_actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
            func.sum(case((is_error, 1), else_=0)).label("error_count"),
            # Cache statistics
            func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
            func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
            func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
            func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
        ).first()

        def _value(name: str):
            # SUM/AVG over zero rows yield NULL; treat missing values as 0.
            return (getattr(stats, name) if stats is not None else None) or 0

        activity_heatmap = UsageService.get_daily_activity(
            db=db,
            window_days=365,
            include_actual_cost=True,
        )

        context.add_audit_metadata(
            action="usage_stats",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
        )

        total_requests = int(_value("total_requests"))
        error_count = int(_value("error_count"))
        # response_time_ms is stored in milliseconds; the API reports seconds.
        avg_response_time = float(_value("avg_response_time_ms")) / 1000.0

        return {
            "total_requests": total_requests,
            "total_tokens": int(_value("total_tokens")),
            "total_cost": float(_value("total_cost")),
            "total_actual_cost": float(_value("total_actual_cost")),
            "avg_response_time": round(avg_response_time, 2),
            "error_count": error_count,
            "error_rate": (
                round((error_count / total_requests) * 100, 2) if total_requests > 0 else 0
            ),
            "cache_stats": {
                "cache_creation_tokens": int(_value("cache_creation_tokens")),
                "cache_read_tokens": int(_value("cache_read_tokens")),
                "cache_creation_cost": float(_value("cache_creation_cost")),
                "cache_read_cost": float(_value("cache_read_cost")),
            },
            "activity_heatmap": activity_heatmap,
        }
class AdminUsageByModelAdapter(AdminApiAdapter):
    """Usage totals grouped by requested model, most-used models first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date, self.end_date, self.limit = start_date, end_date, limit

    async def handle(self, context):  # type: ignore[override]
        """
        Aggregate request count, tokens, cost and actual cost per ``Usage.model``.

        Returns:
            Up to ``limit`` dicts ordered by ``request_count`` descending.
        """
        criteria = [
            # Unfinished (pending/streaming) requests must not be counted.
            Usage.status.notin_(["pending", "streaming"]),
            # Provider "unknown"/"pending" means the request never reached a provider.
            Usage.provider.notin_(["unknown", "pending"]),
        ]
        if self.start_date:
            criteria.append(Usage.created_at >= self.start_date)
        if self.end_date:
            criteria.append(Usage.created_at <= self.end_date)

        stats = (
            context.db.query(
                Usage.model,
                func.count(Usage.id).label("request_count"),
                func.sum(Usage.total_tokens).label("total_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
                func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
            )
            .filter(*criteria)
            .group_by(Usage.model)
            .order_by(func.count(Usage.id).desc())
            .limit(self.limit)
            .all()
        )

        context.add_audit_metadata(
            action="usage_by_model",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(stats),
        )

        return [
            {
                "model": row.model,
                "request_count": row.request_count,
                "total_tokens": int(row.total_tokens or 0),
                "total_cost": float(row.total_cost or 0),
                "actual_cost": float(row.actual_cost or 0),
            }
            for row in stats
        ]
class AdminUsageByUserAdapter(AdminApiAdapter):
    """Usage totals grouped by user, busiest users first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """
        Aggregate request count, tokens, cost and actual cost per user.

        Only users with at least one matching Usage row appear (inner join).
        Unfinished requests (``pending``/``streaming``) are excluded, the same rule
        the model, provider and API-format aggregations apply.

        Returns:
            Up to ``limit`` dicts ordered by ``request_count`` descending.
        """
        db = context.db
        query = (
            db.query(
                User.id,
                User.email,
                User.username,
                func.count(Usage.id).label("request_count"),
                func.sum(Usage.total_tokens).label("total_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
                func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
            )
            .join(Usage, Usage.user_id == User.id)
            # Unfinished requests should not be counted in statistics.
            .filter(Usage.status.notin_(["pending", "streaming"]))
        )
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)

        stats = (
            query.group_by(User.id, User.email, User.username)
            .order_by(func.count(Usage.id).desc())
            .limit(self.limit)
            .all()
        )

        context.add_audit_metadata(
            action="usage_by_user",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(stats),
        )

        return [
            {
                "user_id": user_id,
                "email": email,
                "username": username,
                "request_count": count,
                "total_tokens": int(tokens or 0),
                "total_cost": float(cost or 0),
                # New key, consistent with the other aggregations.
                "actual_cost": float(actual_cost or 0),
            }
            for user_id, email, username, count, tokens, cost, actual_cost in stats
        ]
class AdminUsageByProviderAdapter(AdminApiAdapter):
    """Per-provider attempt counts and success rate, plus token/cost totals."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """
        Aggregate provider performance over the optional date window.

        Attempt counts and success rate come from ``RequestCandidate`` rows, so
        fallback scenarios (one request trying several providers) are counted per
        provider. Tokens and costs come from finished ``Usage`` rows.

        Returns:
            Up to ``limit`` dicts, one per provider, ordered by attempt count descending.
        """
        db = context.db
        attempt_stats = self._attempt_stats(db)
        usage_map = self._usage_totals_by_provider(db)
        # Only providers that appear in the result need a display name.
        provider_map = self._provider_names(
            db, {stat.provider_id for stat in attempt_stats if stat.provider_id}
        )

        context.add_audit_metadata(
            action="usage_by_provider",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(attempt_stats),
        )

        result = []
        for stat in attempt_stats:
            provider_id_str = str(stat.provider_id) if stat.provider_id else None
            attempt_count = stat.attempt_count or 0
            success_count = int(stat.success_count or 0)
            failed_count = int(stat.failed_count or 0)
            success_rate = (success_count / attempt_count * 100) if attempt_count > 0 else 0

            usage_stat = usage_map.get(provider_id_str)
            result.append({
                "provider_id": provider_id_str,
                "provider": provider_map.get(provider_id_str, "Unknown"),
                "request_count": attempt_count,  # number of attempts, not requests
                "total_tokens": int(usage_stat.total_tokens or 0) if usage_stat else 0,
                "total_cost": float(usage_stat.total_cost or 0) if usage_stat else 0,
                "actual_cost": float(usage_stat.actual_cost or 0) if usage_stat else 0,
                "avg_response_time_ms": float(stat.avg_latency_ms or 0),
                "success_rate": round(success_rate, 2),
                "error_count": failed_count,
            })
        return result

    def _attempt_stats(self, db):
        """Top ``limit`` providers by executed RequestCandidate attempts in the window."""
        from sqlalchemy import case

        query = db.query(
            RequestCandidate.provider_id,
            func.count(RequestCandidate.id).label("attempt_count"),
            func.sum(
                case((RequestCandidate.status == "success", 1), else_=0)
            ).label("success_count"),
            func.sum(
                case((RequestCandidate.status == "failed", 1), else_=0)
            ).label("failed_count"),
            func.avg(RequestCandidate.latency_ms).label("avg_latency_ms"),
        ).filter(
            RequestCandidate.provider_id.isnot(None),
            # Only attempts that actually ran (excludes available/skipped candidates).
            RequestCandidate.status.in_(["success", "failed"]),
        )
        if self.start_date:
            query = query.filter(RequestCandidate.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(RequestCandidate.created_at <= self.end_date)
        return (
            query.group_by(RequestCandidate.provider_id)
            .order_by(func.count(RequestCandidate.id).desc())
            .limit(self.limit)
            .all()
        )

    def _usage_totals_by_provider(self, db) -> dict:
        """Token and cost totals of finished Usage rows, keyed by ``str(provider_id)``."""
        query = db.query(
            Usage.provider_id,
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
        ).filter(
            Usage.provider_id.isnot(None),
            # Exclude unfinished (pending/streaming) requests.
            Usage.status.notin_(["pending", "streaming"]),
        )
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)
        return {str(row.provider_id): row for row in query.group_by(Usage.provider_id).all()}

    @staticmethod
    def _provider_names(db, provider_ids) -> dict:
        """Map ``str(Provider.id)`` -> ``Provider.name`` for the given ids (no query if empty)."""
        if not provider_ids:
            return {}
        rows = db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
        return {str(p.id): p.name for p in rows}
class AdminUsageByApiFormatAdapter(AdminApiAdapter):
    """Usage totals grouped by API format, most-used formats first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date, self.end_date, self.limit = start_date, end_date, limit

    async def handle(self, context):  # type: ignore[override]
        """
        Aggregate request count, tokens, costs and mean response time per ``Usage.api_format``.

        Returns:
            Up to ``limit`` dicts ordered by ``request_count`` descending; an empty
            format string is reported as ``"unknown"``.
        """
        criteria = [
            # Unfinished (pending/streaming) requests must not be counted.
            Usage.status.notin_(["pending", "streaming"]),
            # Provider "unknown"/"pending" means the request never reached a provider.
            Usage.provider.notin_(["unknown", "pending"]),
            # Only rows that actually recorded an api_format.
            Usage.api_format.isnot(None),
        ]
        if self.start_date:
            criteria.append(Usage.created_at >= self.start_date)
        if self.end_date:
            criteria.append(Usage.created_at <= self.end_date)

        stats = (
            context.db.query(
                Usage.api_format,
                func.count(Usage.id).label("request_count"),
                func.sum(Usage.total_tokens).label("total_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
                func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
                func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
            )
            .filter(*criteria)
            .group_by(Usage.api_format)
            .order_by(func.count(Usage.id).desc())
            .limit(self.limit)
            .all()
        )

        context.add_audit_metadata(
            action="usage_by_api_format",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(stats),
        )

        return [
            {
                "api_format": row.api_format or "unknown",
                "request_count": row.request_count,
                "total_tokens": int(row.total_tokens or 0),
                "total_cost": float(row.total_cost or 0),
                "actual_cost": float(row.actual_cost or 0),
                "avg_response_time_ms": float(row.avg_response_time_ms or 0),
            }
            for row in stats
        ]
class AdminUsageRecordsAdapter(AdminApiAdapter):
    """Paginated, filterable listing of individual usage records (newest first)."""

    def __init__(
        self,
        start_date: Optional[datetime],
        end_date: Optional[datetime],
        user_id: Optional[str],
        username: Optional[str],
        model: Optional[str],
        provider: Optional[str],
        status: Optional[str],
        limit: int,
        offset: int,
    ):
        self.start_date = start_date
        self.end_date = end_date
        self.user_id = user_id
        self.username = username
        self.model = model
        self.provider = provider
        self.status = status
        self.limit = limit
        self.offset = offset

    async def handle(self, context):  # type: ignore[override]
        """
        Return one page of matching usage records plus the total match count.

        Returns:
            ``{"records": [...], "total": int, "limit": int, "offset": int}``.
        """
        db = context.db
        query = (
            db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
            .outerjoin(User, Usage.user_id == User.id)
            .outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
            .outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
        )
        query = self._apply_filters(query)

        total = query.count()
        records = (
            query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
        )

        fallback_map = self._build_fallback_map(db, records)

        context.add_audit_metadata(
            action="usage_records",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            user_id=self.user_id,
            username=self.username,
            model=self.model,
            provider=self.provider,
            status=self.status,
            limit=self.limit,
            offset=self.offset,
            total=total,
        )

        provider_map = self._build_provider_name_map(db, records)
        data = [
            self._serialize_record(usage, user, endpoint, api_key, provider_map, fallback_map)
            for usage, user, endpoint, api_key in records
        ]
        return {
            "records": data,
            "total": total,
            "limit": self.limit,
            "offset": self.offset,
        }

    def _apply_filters(self, query):
        """Apply the user/model/provider/status/date filters of this request to ``query``."""
        if self.user_id:
            query = query.filter(Usage.user_id == self.user_id)
        if self.username:
            # Case-insensitive substring match on username.
            query = query.filter(User.username.ilike(f"%{self.username}%"))
        if self.model:
            # Case-insensitive substring match on model name.
            query = query.filter(Usage.model.ilike(f"%{self.model}%"))
        if self.provider:
            # Match on the provider's name through the Provider table.
            query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
            query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
        if self.status:
            query = self._apply_status_filter(query)
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)
        return query

    def _apply_status_filter(self, query):
        """
        Filter by ``self.status``.

        Legacy values (from is_stream / status_code): ``stream``, ``standard``, ``error``.
        Lifecycle values (the ``status`` column): ``pending``, ``streaming``,
        ``completed``, ``failed``, and ``active`` meaning pending or streaming.
        Unrecognised values apply no filter.
        """
        if self.status == "stream":
            return query.filter(Usage.is_stream == True)  # noqa: E712
        if self.status == "standard":
            return query.filter(Usage.is_stream == False)  # noqa: E712
        if self.status == "error":
            return query.filter(
                (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
            )
        if self.status in ("pending", "streaming", "completed", "failed"):
            return query.filter(Usage.status == self.status)
        if self.status == "active":
            return query.filter(Usage.status.in_(["pending", "streaming"]))
        return query

    @staticmethod
    def _build_fallback_map(db, records) -> dict:
        """Map request_id -> True when more than one candidate actually executed (a provider fallback)."""
        request_ids = {usage.request_id for usage, _, _, _ in records if usage.request_id}
        if not request_ids:
            return {}
        # Only executed candidates (success/failed); skipped/pending/available are ignored.
        executed_counts = (
            db.query(RequestCandidate.request_id, func.count(RequestCandidate.id))
            .filter(
                RequestCandidate.request_id.in_(request_ids),
                RequestCandidate.status.in_(["success", "failed"]),
            )
            .group_by(RequestCandidate.request_id)
            .all()
        )
        return {req_id: count > 1 for req_id, count in executed_counts}

    @staticmethod
    def _build_provider_name_map(db, records) -> dict:
        """Map ``str(provider_id)`` -> provider name in one query (avoids N+1 lookups)."""
        provider_ids = {usage.provider_id for usage, _, _, _ in records if usage.provider_id}
        if not provider_ids:
            return {}
        providers_data = (
            db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
        )
        return {str(p.id): p.name for p in providers_data}

    @staticmethod
    def _serialize_record(usage, user, endpoint, api_key, provider_map, fallback_map) -> dict:
        """Build the API dict for one joined (Usage, User, ProviderEndpoint, ProviderAPIKey) row."""
        # Cost columns may be NULL (e.g. unfinished requests); report 0.0 rather than crash.
        cost = float(usage.total_cost_usd) if usage.total_cost_usd is not None else 0.0
        actual_cost = (
            float(usage.actual_total_cost_usd)
            if usage.actual_total_cost_usd is not None
            else 0.0
        )
        rate_multiplier = (
            float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0
        )
        # Provider name priority: linked Provider row > denormalised usage.provider field.
        provider_name = usage.provider
        if usage.provider_id and str(usage.provider_id) in provider_map:
            provider_name = provider_map[str(usage.provider_id)]

        return {
            "id": usage.id,
            "user_id": user.id if user else None,
            "user_email": user.email if user else "已删除用户",
            "username": user.username if user else "已删除用户",
            "provider": provider_name,
            "model": usage.model,
            "target_model": usage.target_model,  # target model name after mapping
            "input_tokens": usage.input_tokens,
            "output_tokens": usage.output_tokens,
            "cache_creation_input_tokens": usage.cache_creation_input_tokens,
            "cache_read_input_tokens": usage.cache_read_input_tokens,
            "total_tokens": usage.total_tokens,
            "cost": cost,
            "actual_cost": actual_cost,
            "rate_multiplier": rate_multiplier,
            "response_time_ms": usage.response_time_ms,
            "created_at": usage.created_at.isoformat() if usage.created_at else None,
            "is_stream": usage.is_stream,
            "input_price_per_1m": usage.input_price_per_1m,
            "output_price_per_1m": usage.output_price_per_1m,
            "cache_creation_price_per_1m": usage.cache_creation_price_per_1m,
            "cache_read_price_per_1m": usage.cache_read_price_per_1m,
            "status_code": usage.status_code,
            "error_message": usage.error_message,
            "status": usage.status,  # request lifecycle: pending, streaming, completed, failed
            "has_fallback": fallback_map.get(usage.request_id, False),
            "api_format": usage.api_format
            or (endpoint.api_format if endpoint and endpoint.api_format else None),
            "api_key_name": api_key.name if api_key else None,
            "request_metadata": usage.request_metadata,  # provider response metadata
        }
class AdminActiveRequestsAdapter(AdminApiAdapter):
    """Lightweight status lookup for in-flight requests, intended for frontend polling."""

    def __init__(self, ids: Optional[str]):
        # Raw comma-separated request IDs from the query string, or None.
        self.ids = ids

    async def handle(self, context):  # type: ignore[override]
        """
        Return ``{"requests": [...]}`` from ``UsageService.get_active_requests_status``.

        When ``ids`` is non-empty it is split on commas and blank entries are
        dropped; if nothing remains, an empty list is returned without querying.
        When ``ids`` is empty or omitted, ``ids=None`` is passed to the service.
        """
        from src.services.usage import UsageService

        db = context.db
        id_list = None
        if self.ids:
            # ``part`` rather than ``id`` to avoid shadowing the builtin.
            id_list = [part.strip() for part in self.ids.split(",") if part.strip()]
            if not id_list:
                return {"requests": []}

        requests = UsageService.get_active_requests_status(db=db, ids=id_list)
        return {"requests": requests}
@dataclass
class AdminUsageDetailAdapter(AdminApiAdapter):
    """Full detail of one usage record, including request/response headers and bodies."""

    usage_id: str  # primary key of the Usage row to load

    async def handle(self, context):  # type: ignore[override]
        """
        Load one usage record with its user, API key and tiered-pricing info.

        Raises:
            HTTPException: 404 when no Usage row has ``id == usage_id``.
        """
        db = context.db
        usage_record = db.query(Usage).filter(Usage.id == self.usage_id).first()
        if not usage_record:
            raise HTTPException(status_code=404, detail="Usage record not found")

        # Skip the lookup entirely when the foreign key is unset; it could not match anyway.
        user = (
            db.query(User).filter(User.id == usage_record.user_id).first()
            if usage_record.user_id is not None
            else None
        )
        api_key = (
            db.query(ApiKey).filter(ApiKey.id == usage_record.api_key_id).first()
            if usage_record.api_key_id is not None
            else None
        )

        # Tiered-pricing breakdown (None when the model has no tiers configured).
        tiered_pricing_info = await self._get_tiered_pricing_info(db, usage_record)

        context.add_audit_metadata(
            action="usage_detail",
            usage_id=self.usage_id,
        )

        return {
            "id": usage_record.id,
            "request_id": usage_record.request_id,
            "user": {
                "id": user.id if user else None,
                "username": user.username if user else "Unknown",
                "email": user.email if user else None,
            },
            "api_key": {
                "id": api_key.id if api_key else None,
                "name": api_key.name if api_key else None,
                "display": api_key.get_display_key() if api_key else None,
            },
            "provider": usage_record.provider,
            "api_format": usage_record.api_format,
            "model": usage_record.model,
            "target_model": usage_record.target_model,
            "tokens": {
                "input": usage_record.input_tokens,
                "output": usage_record.output_tokens,
                "total": usage_record.total_tokens,
            },
            "cost": {
                "input": usage_record.input_cost_usd,
                "output": usage_record.output_cost_usd,
                "total": usage_record.total_cost_usd,
            },
            "cache_creation_input_tokens": usage_record.cache_creation_input_tokens,
            "cache_read_input_tokens": usage_record.cache_read_input_tokens,
            "cache_creation_cost": getattr(usage_record, "cache_creation_cost_usd", 0.0),
            "cache_read_cost": getattr(usage_record, "cache_read_cost_usd", 0.0),
            "request_cost": getattr(usage_record, "request_cost_usd", 0.0),
            "input_price_per_1m": usage_record.input_price_per_1m,
            "output_price_per_1m": usage_record.output_price_per_1m,
            "cache_creation_price_per_1m": usage_record.cache_creation_price_per_1m,
            "cache_read_price_per_1m": usage_record.cache_read_price_per_1m,
            "price_per_request": usage_record.price_per_request,
            "request_type": usage_record.request_type,
            "is_stream": usage_record.is_stream,
            "status_code": usage_record.status_code,
            "error_message": usage_record.error_message,
            "response_time_ms": usage_record.response_time_ms,
            "created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
            "request_headers": usage_record.request_headers,
            "request_body": usage_record.get_request_body(),
            "provider_request_headers": usage_record.provider_request_headers,
            "response_headers": usage_record.response_headers,
            "response_body": usage_record.get_response_body(),
            "metadata": usage_record.request_metadata,
            "tiered_pricing": tiered_pricing_info,
        }

    async def _get_tiered_pricing_info(self, db, usage_record) -> dict | None:
        """
        Work out which pricing tier this record falls into.

        The tier is chosen by total input context = input + cache-creation +
        cache-read tokens. Returns None when no tiered pricing is configured for
        the record's provider/model.
        """
        from src.services.model.cost import ModelCostService

        input_tokens = usage_record.input_tokens or 0
        cache_creation_tokens = usage_record.cache_creation_input_tokens or 0
        cache_read_tokens = usage_record.cache_read_input_tokens or 0
        total_input_context = input_tokens + cache_creation_tokens + cache_read_tokens

        cost_service = ModelCostService(db)
        pricing_result = await cost_service.get_tiered_pricing_with_source_async(
            usage_record.provider, usage_record.model
        )
        if not pricing_result:
            return None

        tiered_pricing = pricing_result.get("pricing")
        pricing_source = pricing_result.get("source")  # 'provider' or 'global'
        tiers = tiered_pricing.get("tiers") if tiered_pricing else None
        if not tiers:
            return None

        # First tier whose ``up_to`` covers the context (``up_to=None`` is open-ended);
        # fall back to the last tier when the context exceeds every bound.
        tier_index = next(
            (
                i
                for i, tier in enumerate(tiers)
                if tier.get("up_to") is None or total_input_context <= tier.get("up_to")
            ),
            len(tiers) - 1,
        )

        return {
            "total_input_context": total_input_context,
            "tier_index": tier_index,
            "tier_count": len(tiers),
            "current_tier": tiers[tier_index],
            "tiers": tiers,
            "source": pricing_source,  # pricing source: 'provider' or 'global'
        }