"""管理员使用情况统计路由。""" from dataclasses import dataclass from datetime import datetime from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query, Request from sqlalchemy import func from sqlalchemy.orm import Session from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.pipeline import ApiRequestPipeline from src.database import get_db from src.models.database import ( ApiKey, Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate, Usage, User, ) from src.services.usage.service import UsageService router = APIRouter(prefix="/api/admin/usage", tags=["Admin - Usage"]) pipeline = ApiRequestPipeline() # ==================== RESTful Routes ==================== @router.get("/aggregation/stats") async def get_usage_aggregation( request: Request, group_by: str = Query(..., description="Aggregation dimension: model, user, provider, or api_format"), start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, limit: int = Query(20, ge=1, le=100), db: Session = Depends(get_db), ): """ Get usage aggregation by specified dimension. - group_by=model: Aggregate by model - group_by=user: Aggregate by user - group_by=provider: Aggregate by provider - group_by=api_format: Aggregate by API format """ if group_by == "model": adapter = AdminUsageByModelAdapter(start_date=start_date, end_date=end_date, limit=limit) elif group_by == "user": adapter = AdminUsageByUserAdapter(start_date=start_date, end_date=end_date, limit=limit) elif group_by == "provider": adapter = AdminUsageByProviderAdapter(start_date=start_date, end_date=end_date, limit=limit) elif group_by == "api_format": adapter = AdminUsageByApiFormatAdapter(start_date=start_date, end_date=end_date, limit=limit) else: raise HTTPException( status_code=400, detail=f"Invalid group_by value: {group_by}. Must be one of: model, user, provider, api_format" ) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/stats") async def get_usage_stats( request: Request, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, db: Session = Depends(get_db), ): adapter = AdminUsageStatsAdapter(start_date=start_date, end_date=end_date) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/records") async def get_usage_records( request: Request, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, user_id: Optional[str] = None, username: Optional[str] = None, model: Optional[str] = None, provider: Optional[str] = None, status: Optional[str] = None, # stream, standard, error limit: int = Query(100, ge=1, le=500), offset: int = Query(0, ge=0), db: Session = Depends(get_db), ): adapter = AdminUsageRecordsAdapter( start_date=start_date, end_date=end_date, user_id=user_id, username=username, model=model, provider=provider, status=status, limit=limit, offset=offset, ) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/active") async def get_active_requests( request: Request, ids: Optional[str] = Query(None, description="逗号分隔的请求 ID 列表,用于查询特定请求的状态"), db: Session = Depends(get_db), ): """ 获取活跃请求的状态(轻量级接口,用于前端轮询) - 如果提供 ids 参数,只返回这些 ID 对应请求的最新状态 - 如果不提供 ids,返回所有 pending/streaming 状态的请求 """ adapter = AdminActiveRequestsAdapter(ids=ids) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) # NOTE: This route must be defined AFTER all other routes to avoid matching # routes like /stats, /records, /active, etc. @router.get("/{usage_id}") async def get_usage_detail( usage_id: str, request: Request, db: Session = Depends(get_db), ): """ Get detailed information of a specific usage record. Includes request/response headers and body. """ adapter = AdminUsageDetailAdapter(usage_id=usage_id) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) class AdminUsageStatsAdapter(AdminApiAdapter): def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime]): self.start_date = start_date self.end_date = end_date async def handle(self, context): # type: ignore[override] db = context.db query = db.query(Usage) if self.start_date: query = query.filter(Usage.created_at >= self.start_date) if self.end_date: query = query.filter(Usage.created_at <= self.end_date) total_stats = query.with_entities( func.count(Usage.id).label("total_requests"), func.sum(Usage.total_tokens).label("total_tokens"), func.sum(Usage.total_cost_usd).label("total_cost"), func.sum(Usage.actual_total_cost_usd).label("total_actual_cost"), func.avg(Usage.response_time_ms).label("avg_response_time_ms"), ).first() # 缓存统计 cache_stats = query.with_entities( func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"), func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"), func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"), func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"), ).first() # 错误统计 error_count = query.filter( (Usage.status_code >= 400) | (Usage.error_message.isnot(None)) ).count() activity_heatmap = UsageService.get_daily_activity( db=db, window_days=365, include_actual_cost=True, ) context.add_audit_metadata( action="usage_stats", start_date=self.start_date.isoformat() if self.start_date else None, end_date=self.end_date.isoformat() if self.end_date else None, ) total_requests = total_stats.total_requests if total_stats else 0 avg_response_time_ms = float(total_stats.avg_response_time_ms or 0) if total_stats else 0 avg_response_time = avg_response_time_ms / 1000.0 return { "total_requests": total_requests, "total_tokens": int(total_stats.total_tokens or 0), "total_cost": float(total_stats.total_cost or 0), "total_actual_cost": float(total_stats.total_actual_cost or 0), "avg_response_time": round(avg_response_time, 2), "error_count": error_count, "error_rate": ( round((error_count / total_requests) * 100, 2) if total_requests > 0 else 0 ), "cache_stats": { "cache_creation_tokens": ( int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0 ), "cache_read_tokens": int(cache_stats.cache_read_tokens or 0) if cache_stats else 0, "cache_creation_cost": ( float(cache_stats.cache_creation_cost or 0) if cache_stats else 0 ), "cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0, }, "activity_heatmap": activity_heatmap, } class AdminUsageByModelAdapter(AdminApiAdapter): def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int): self.start_date = start_date self.end_date = end_date self.limit = limit async def handle(self, context): # type: ignore[override] db = context.db query = db.query( Usage.model, func.count(Usage.id).label("request_count"), func.sum(Usage.total_tokens).label("total_tokens"), func.sum(Usage.total_cost_usd).label("total_cost"), func.sum(Usage.actual_total_cost_usd).label("actual_cost"), ) # 过滤掉 pending/streaming 状态的请求(尚未完成的请求不应计入统计) query = query.filter(Usage.status.notin_(["pending", "streaming"])) # 过滤掉 unknown/pending provider(请求未到达任何提供商) query = query.filter(Usage.provider.notin_(["unknown", "pending"])) if self.start_date: query = query.filter(Usage.created_at >= self.start_date) if self.end_date: query = query.filter(Usage.created_at <= self.end_date) query = query.group_by(Usage.model).order_by(func.count(Usage.id).desc()).limit(self.limit) stats = query.all() context.add_audit_metadata( action="usage_by_model", start_date=self.start_date.isoformat() if self.start_date else None, end_date=self.end_date.isoformat() if self.end_date else None, limit=self.limit, result_count=len(stats), ) return [ { "model": model, "request_count": count, "total_tokens": int(tokens or 0), "total_cost": float(cost or 0), "actual_cost": float(actual_cost or 0), } for model, count, tokens, cost, actual_cost in stats ] class AdminUsageByUserAdapter(AdminApiAdapter): def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int): self.start_date = start_date self.end_date = end_date self.limit = limit async def handle(self, context): # type: ignore[override] db = context.db query = ( db.query( User.id, User.email, User.username, func.count(Usage.id).label("request_count"), func.sum(Usage.total_tokens).label("total_tokens"), func.sum(Usage.total_cost_usd).label("total_cost"), ) .join(Usage, Usage.user_id == User.id) .group_by(User.id, User.email, User.username) ) if self.start_date: query = query.filter(Usage.created_at >= self.start_date) if self.end_date: query = query.filter(Usage.created_at <= self.end_date) query = query.order_by(func.count(Usage.id).desc()).limit(self.limit) stats = query.all() context.add_audit_metadata( action="usage_by_user", start_date=self.start_date.isoformat() if self.start_date else None, end_date=self.end_date.isoformat() if self.end_date else None, limit=self.limit, result_count=len(stats), ) return [ { "user_id": user_id, "email": email, "username": username, "request_count": count, "total_tokens": int(tokens or 0), "total_cost": float(cost or 0), } for user_id, email, username, count, tokens, cost in stats ] class AdminUsageByProviderAdapter(AdminApiAdapter): def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int): self.start_date = start_date self.end_date = end_date self.limit = limit async def handle(self, context): # type: ignore[override] db = context.db # 从 request_candidates 表统计每个 Provider 的尝试次数和成功率 # 这样可以正确统计 Fallback 场景(一个请求可能尝试多个 Provider) from sqlalchemy import case, Integer attempt_query = db.query( RequestCandidate.provider_id, func.count(RequestCandidate.id).label("attempt_count"), func.sum( case((RequestCandidate.status == "success", 1), else_=0) ).label("success_count"), func.sum( case((RequestCandidate.status == "failed", 1), else_=0) ).label("failed_count"), func.avg(RequestCandidate.latency_ms).label("avg_latency_ms"), ).filter( RequestCandidate.provider_id.isnot(None), # 只统计实际执行的尝试(排除 available/skipped 状态) RequestCandidate.status.in_(["success", "failed"]), ) if self.start_date: attempt_query = attempt_query.filter(RequestCandidate.created_at >= self.start_date) if self.end_date: attempt_query = attempt_query.filter(RequestCandidate.created_at <= self.end_date) attempt_stats = ( attempt_query.group_by(RequestCandidate.provider_id) .order_by(func.count(RequestCandidate.id).desc()) .limit(self.limit) .all() ) # 从 Usage 表获取 token 和费用统计(基于成功的请求) usage_query = db.query( Usage.provider_id, func.count(Usage.id).label("request_count"), func.sum(Usage.total_tokens).label("total_tokens"), func.sum(Usage.total_cost_usd).label("total_cost"), func.sum(Usage.actual_total_cost_usd).label("actual_cost"), func.avg(Usage.response_time_ms).label("avg_response_time_ms"), ).filter( Usage.provider_id.isnot(None), # 过滤掉 pending/streaming 状态的请求 Usage.status.notin_(["pending", "streaming"]), ) if self.start_date: usage_query = usage_query.filter(Usage.created_at >= self.start_date) if self.end_date: usage_query = usage_query.filter(Usage.created_at <= self.end_date) usage_stats = usage_query.group_by(Usage.provider_id).all() usage_map = {str(u.provider_id): u for u in usage_stats} # 获取所有相关的 Provider ID provider_ids = set() for stat in attempt_stats: if stat.provider_id: provider_ids.add(stat.provider_id) for stat in usage_stats: if stat.provider_id: provider_ids.add(stat.provider_id) # 获取 Provider 名称映射 provider_map = {} if provider_ids: providers_data = ( db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all() ) provider_map = {str(p.id): p.name for p in providers_data} context.add_audit_metadata( action="usage_by_provider", start_date=self.start_date.isoformat() if self.start_date else None, end_date=self.end_date.isoformat() if self.end_date else None, limit=self.limit, result_count=len(attempt_stats), ) result = [] for stat in attempt_stats: provider_id_str = str(stat.provider_id) if stat.provider_id else None attempt_count = stat.attempt_count or 0 success_count = int(stat.success_count or 0) failed_count = int(stat.failed_count or 0) success_rate = (success_count / attempt_count * 100) if attempt_count > 0 else 0 # 从 usage_map 获取 token 和费用信息 usage_stat = usage_map.get(provider_id_str) result.append({ "provider_id": provider_id_str, "provider": provider_map.get(provider_id_str, "Unknown"), "request_count": attempt_count, # 尝试次数 "total_tokens": int(usage_stat.total_tokens or 0) if usage_stat else 0, "total_cost": float(usage_stat.total_cost or 0) if usage_stat else 0, "actual_cost": float(usage_stat.actual_cost or 0) if usage_stat else 0, "avg_response_time_ms": float(stat.avg_latency_ms or 0), "success_rate": round(success_rate, 2), "error_count": failed_count, }) return result class AdminUsageByApiFormatAdapter(AdminApiAdapter): def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int): self.start_date = start_date self.end_date = end_date self.limit = limit async def handle(self, context): # type: ignore[override] db = context.db query = db.query( Usage.api_format, func.count(Usage.id).label("request_count"), func.sum(Usage.total_tokens).label("total_tokens"), func.sum(Usage.total_cost_usd).label("total_cost"), func.sum(Usage.actual_total_cost_usd).label("actual_cost"), func.avg(Usage.response_time_ms).label("avg_response_time_ms"), ) # 过滤掉 pending/streaming 状态的请求 query = query.filter(Usage.status.notin_(["pending", "streaming"])) # 过滤掉 unknown/pending provider query = query.filter(Usage.provider.notin_(["unknown", "pending"])) # 只统计有 api_format 的记录 query = query.filter(Usage.api_format.isnot(None)) if self.start_date: query = query.filter(Usage.created_at >= self.start_date) if self.end_date: query = query.filter(Usage.created_at <= self.end_date) query = ( query.group_by(Usage.api_format) .order_by(func.count(Usage.id).desc()) .limit(self.limit) ) stats = query.all() context.add_audit_metadata( action="usage_by_api_format", start_date=self.start_date.isoformat() if self.start_date else None, end_date=self.end_date.isoformat() if self.end_date else None, limit=self.limit, result_count=len(stats), ) return [ { "api_format": api_format or "unknown", "request_count": count, "total_tokens": int(tokens or 0), "total_cost": float(cost or 0), "actual_cost": float(actual_cost or 0), "avg_response_time_ms": float(avg_response_time or 0), } for api_format, count, tokens, cost, actual_cost, avg_response_time in stats ] class AdminUsageRecordsAdapter(AdminApiAdapter): def __init__( self, start_date: Optional[datetime], end_date: Optional[datetime], user_id: Optional[str], username: Optional[str], model: Optional[str], provider: Optional[str], status: Optional[str], limit: int, offset: int, ): self.start_date = start_date self.end_date = end_date self.user_id = user_id self.username = username self.model = model self.provider = provider self.status = status self.limit = limit self.offset = offset async def handle(self, context): # type: ignore[override] db = context.db query = ( db.query(Usage, User, ProviderEndpoint, ProviderAPIKey) .outerjoin(User, Usage.user_id == User.id) .outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id) .outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id) ) if self.user_id: query = query.filter(Usage.user_id == self.user_id) if self.username: # 支持用户名模糊搜索 query = query.filter(User.username.ilike(f"%{self.username}%")) if self.model: # 支持模型名模糊搜索 query = query.filter(Usage.model.ilike(f"%{self.model}%")) if self.provider: # 支持提供商名称搜索(通过 Provider 表) query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True) query = query.filter(Provider.name.ilike(f"%{self.provider}%")) if self.status: # 状态筛选 # 旧的筛选值(基于 is_stream 和 status_code):stream, standard, error # 新的筛选值(基于 status 字段):pending, streaming, completed, failed, active if self.status == "stream": query = query.filter(Usage.is_stream == True) # noqa: E712 elif self.status == "standard": query = query.filter(Usage.is_stream == False) # noqa: E712 elif self.status == "error": query = query.filter( (Usage.status_code >= 400) | (Usage.error_message.isnot(None)) ) elif self.status in ("pending", "streaming", "completed"): # 新的状态筛选:直接按 status 字段过滤 query = query.filter(Usage.status == self.status) elif self.status == "failed": # 失败请求需要同时考虑新旧两种判断方式: # 1. 新方式:status = "failed" # 2. 旧方式:status_code >= 400 或 error_message 不为空 query = query.filter( (Usage.status == "failed") | (Usage.status_code >= 400) | (Usage.error_message.isnot(None)) ) elif self.status == "active": # 活跃请求:pending 或 streaming 状态 query = query.filter(Usage.status.in_(["pending", "streaming"])) if self.start_date: query = query.filter(Usage.created_at >= self.start_date) if self.end_date: query = query.filter(Usage.created_at <= self.end_date) total = query.count() records = ( query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all() ) request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id] fallback_map = {} if request_ids: # 只统计实际执行的候选(success 或 failed),不包括 skipped/pending/available executed_counts = ( db.query(RequestCandidate.request_id, func.count(RequestCandidate.id)) .filter( RequestCandidate.request_id.in_(request_ids), RequestCandidate.status.in_(["success", "failed"]), ) .group_by(RequestCandidate.request_id) .all() ) # 如果实际执行的候选数 > 1,说明发生了 Provider 切换 fallback_map = {req_id: count > 1 for req_id, count in executed_counts} context.add_audit_metadata( action="usage_records", start_date=self.start_date.isoformat() if self.start_date else None, end_date=self.end_date.isoformat() if self.end_date else None, user_id=self.user_id, username=self.username, model=self.model, provider=self.provider, status=self.status, limit=self.limit, offset=self.offset, total=total, ) # 构建 provider_id -> Provider 名称的映射,避免 N+1 查询 provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id] provider_map = {} if provider_ids: providers_data = ( db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all() ) provider_map = {str(p.id): p.name for p in providers_data} data = [] for usage, user, endpoint, api_key in records: actual_cost = ( float(usage.actual_total_cost_usd) if usage.actual_total_cost_usd is not None else 0.0 ) rate_multiplier = ( float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0 ) # 提供商名称优先级:关联的 Provider 表 > usage.provider 字段 provider_name = usage.provider if usage.provider_id and str(usage.provider_id) in provider_map: provider_name = provider_map[str(usage.provider_id)] data.append( { "id": usage.id, "user_id": user.id if user else None, "user_email": user.email if user else "已删除用户", "username": user.username if user else "已删除用户", "provider": provider_name, "model": usage.model, "target_model": usage.target_model, # 映射后的目标模型名 "input_tokens": usage.input_tokens, "output_tokens": usage.output_tokens, "cache_creation_input_tokens": usage.cache_creation_input_tokens, "cache_read_input_tokens": usage.cache_read_input_tokens, "total_tokens": usage.total_tokens, "cost": float(usage.total_cost_usd), "actual_cost": actual_cost, "rate_multiplier": rate_multiplier, "response_time_ms": usage.response_time_ms, "created_at": usage.created_at.isoformat(), "is_stream": usage.is_stream, "input_price_per_1m": usage.input_price_per_1m, "output_price_per_1m": usage.output_price_per_1m, "cache_creation_price_per_1m": usage.cache_creation_price_per_1m, "cache_read_price_per_1m": usage.cache_read_price_per_1m, "status_code": usage.status_code, "error_message": usage.error_message, "status": usage.status, # 请求状态: pending, streaming, completed, failed "has_fallback": fallback_map.get(usage.request_id, False), "api_format": usage.api_format or (endpoint.api_format if endpoint and endpoint.api_format else None), "api_key_name": api_key.name if api_key else None, "request_metadata": usage.request_metadata, # Provider 响应元数据 } ) return { "records": data, "total": total, "limit": self.limit, "offset": self.offset, } class AdminActiveRequestsAdapter(AdminApiAdapter): """轻量级活跃请求状态查询适配器""" def __init__(self, ids: Optional[str]): self.ids = ids async def handle(self, context): # type: ignore[override] from src.services.usage import UsageService db = context.db id_list = None if self.ids: id_list = [id.strip() for id in self.ids.split(",") if id.strip()] if not id_list: return {"requests": []} requests = UsageService.get_active_requests_status(db=db, ids=id_list) return {"requests": requests} @dataclass class AdminUsageDetailAdapter(AdminApiAdapter): """Get detailed usage record with request/response body""" usage_id: str async def handle(self, context): # type: ignore[override] db = context.db usage_record = db.query(Usage).filter(Usage.id == self.usage_id).first() if not usage_record: raise HTTPException(status_code=404, detail="Usage record not found") user = db.query(User).filter(User.id == usage_record.user_id).first() api_key = db.query(ApiKey).filter(ApiKey.id == usage_record.api_key_id).first() # 获取阶梯计费信息 tiered_pricing_info = await self._get_tiered_pricing_info(db, usage_record) context.add_audit_metadata( action="usage_detail", usage_id=self.usage_id, ) return { "id": usage_record.id, "request_id": usage_record.request_id, "user": { "id": user.id if user else None, "username": user.username if user else "Unknown", "email": user.email if user else None, }, "api_key": { "id": api_key.id if api_key else None, "name": api_key.name if api_key else None, "display": api_key.get_display_key() if api_key else None, }, "provider": usage_record.provider, "api_format": usage_record.api_format, "model": usage_record.model, "target_model": usage_record.target_model, "tokens": { "input": usage_record.input_tokens, "output": usage_record.output_tokens, "total": usage_record.total_tokens, }, "cost": { "input": usage_record.input_cost_usd, "output": usage_record.output_cost_usd, "total": usage_record.total_cost_usd, }, "cache_creation_input_tokens": usage_record.cache_creation_input_tokens, "cache_read_input_tokens": usage_record.cache_read_input_tokens, "cache_creation_cost": getattr(usage_record, "cache_creation_cost_usd", 0.0), "cache_read_cost": getattr(usage_record, "cache_read_cost_usd", 0.0), "request_cost": getattr(usage_record, "request_cost_usd", 0.0), "input_price_per_1m": usage_record.input_price_per_1m, "output_price_per_1m": usage_record.output_price_per_1m, "cache_creation_price_per_1m": usage_record.cache_creation_price_per_1m, "cache_read_price_per_1m": usage_record.cache_read_price_per_1m, "price_per_request": usage_record.price_per_request, "request_type": usage_record.request_type, "is_stream": usage_record.is_stream, "status_code": usage_record.status_code, "error_message": usage_record.error_message, "response_time_ms": usage_record.response_time_ms, "created_at": usage_record.created_at.isoformat() if usage_record.created_at else None, "request_headers": usage_record.request_headers, "request_body": usage_record.get_request_body(), "provider_request_headers": usage_record.provider_request_headers, "response_headers": usage_record.response_headers, "response_body": usage_record.get_response_body(), "metadata": usage_record.request_metadata, "tiered_pricing": tiered_pricing_info, } async def _get_tiered_pricing_info(self, db, usage_record) -> dict | None: """获取阶梯计费信息""" from src.services.model.cost import ModelCostService # 计算总输入上下文(用于阶梯判定):输入 + 缓存创建 + 缓存读取 input_tokens = usage_record.input_tokens or 0 cache_creation_tokens = usage_record.cache_creation_input_tokens or 0 cache_read_tokens = usage_record.cache_read_input_tokens or 0 total_input_context = input_tokens + cache_creation_tokens + cache_read_tokens # 尝试获取模型的阶梯配置(带来源信息) cost_service = ModelCostService(db) pricing_result = await cost_service.get_tiered_pricing_with_source_async( usage_record.provider, usage_record.model ) if not pricing_result: return None tiered_pricing = pricing_result.get("pricing") pricing_source = pricing_result.get("source") # 'provider' 或 'global' if not tiered_pricing or not tiered_pricing.get("tiers"): return None tiers = tiered_pricing.get("tiers", []) if not tiers: return None # 找到命中的阶梯 tier_index = None matched_tier = None for i, tier in enumerate(tiers): up_to = tier.get("up_to") if up_to is None or total_input_context <= up_to: tier_index = i matched_tier = tier break # 如果都没匹配,使用最后一个阶梯 if tier_index is None and tiers: tier_index = len(tiers) - 1 matched_tier = tiers[-1] return { "total_input_context": total_input_context, "tier_index": tier_index, "tier_count": len(tiers), "current_tier": matched_tier, "tiers": tiers, "source": pricing_source, # 定价来源: 'provider' 或 'global' } # ==================== 缓存亲和性分析 ==================== @router.get("/cache-affinity/ttl-analysis") async def analyze_cache_affinity_ttl( request: Request, user_id: Optional[str] = Query(None, description="指定用户 ID"), api_key_id: Optional[str] = Query(None, description="指定 API Key ID"), hours: int = Query(168, ge=1, le=720, description="分析最近多少小时的数据"), db: Session = Depends(get_db), ): """ 分析用户请求间隔分布,推荐合适的缓存亲和性 TTL。 通过分析同一用户连续请求之间的时间间隔,判断用户的使用模式: - 高频用户(间隔短):5 分钟 TTL 足够 - 中频用户:15-30 分钟 TTL - 低频用户(间隔长):需要 60 分钟 TTL """ adapter = CacheAffinityTTLAnalysisAdapter( user_id=user_id, api_key_id=api_key_id, hours=hours, ) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/cache-affinity/hit-analysis") async def analyze_cache_hit( request: Request, user_id: Optional[str] = Query(None, description="指定用户 ID"), api_key_id: Optional[str] = Query(None, description="指定 API Key ID"), hours: int = Query(168, ge=1, le=720, description="分析最近多少小时的数据"), db: Session = Depends(get_db), ): """ 分析缓存命中情况。 返回缓存命中率、节省的费用等统计信息。 """ adapter = CacheHitAnalysisAdapter( user_id=user_id, api_key_id=api_key_id, hours=hours, ) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) class CacheAffinityTTLAnalysisAdapter(AdminApiAdapter): """缓存亲和性 TTL 分析适配器""" def __init__( self, user_id: Optional[str], api_key_id: Optional[str], hours: int, ): self.user_id = user_id self.api_key_id = api_key_id self.hours = hours async def handle(self, context): # type: ignore[override] db = context.db result = UsageService.analyze_cache_affinity_ttl( db=db, user_id=self.user_id, api_key_id=self.api_key_id, hours=self.hours, ) context.add_audit_metadata( action="cache_affinity_ttl_analysis", user_id=self.user_id, api_key_id=self.api_key_id, hours=self.hours, total_users_analyzed=result.get("total_users_analyzed", 0), ) return result class CacheHitAnalysisAdapter(AdminApiAdapter): """缓存命中分析适配器""" def __init__( self, user_id: Optional[str], api_key_id: Optional[str], hours: int, ): self.user_id = user_id self.api_key_id = api_key_id self.hours = hours async def handle(self, context): # type: ignore[override] db = context.db result = UsageService.get_cache_hit_analysis( db=db, user_id=self.user_id, api_key_id=self.api_key_id, hours=self.hours, ) context.add_audit_metadata( action="cache_hit_analysis", user_id=self.user_id, api_key_id=self.api_key_id, hours=self.hours, ) return result @router.get("/cache-affinity/interval-timeline") async def get_interval_timeline( request: Request, hours: int = Query(168, ge=1, le=720, description="分析最近多少小时的数据"), limit: int = Query(1000, ge=100, le=5000, description="最大返回数据点数量"), user_id: Optional[str] = Query(None, description="指定用户 ID"), include_user_info: bool = Query(False, description="是否包含用户信息(用于管理员多用户视图)"), db: Session = Depends(get_db), ): """ 获取请求间隔时间线数据,用于散点图展示。 返回每个请求的时间点和与上一个请求的间隔(分钟), 可用于可视化用户请求模式。 当 include_user_info=true 且未指定 user_id 时,返回数据会包含: - points 中每个点包含 user_id 字段 - users 字段包含 user_id -> username 的映射 """ adapter = IntervalTimelineAdapter( hours=hours, limit=limit, user_id=user_id, include_user_info=include_user_info, ) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) class IntervalTimelineAdapter(AdminApiAdapter): """请求间隔时间线适配器""" def __init__( self, hours: int, limit: int, user_id: Optional[str] = None, include_user_info: bool = False, ): self.hours = hours self.limit = limit self.user_id = user_id self.include_user_info = include_user_info async def handle(self, context): # type: ignore[override] db = context.db result = UsageService.get_interval_timeline( db=db, hours=self.hours, limit=self.limit, user_id=self.user_id, include_user_info=self.include_user_info, ) context.add_audit_metadata( action="interval_timeline", hours=self.hours, limit=self.limit, user_id=self.user_id, include_user_info=self.include_user_info, total_points=result.get("total_points", 0), ) return result