From abc41c7d3cf47f342c7531b8d66aaf7032e493f6 Mon Sep 17 00:00:00 2001
From: fawney19
Date: Thu, 11 Dec 2025 17:47:59 +0800
Subject: [PATCH] feat: add cache monitoring and usage statistics API endpoints
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/api/admin/usage/routes.py                  | 181 +++++
 src/api/user_me/routes.py                      |  33 ++
 src/services/cache/aware_scheduler.py          |  13 +-
 .../orchestration/error_classifier.py          |  29 +-
 src/services/usage/service.py                  | 458 ++++++++++++++++++
 5 files changed, 707 insertions(+), 7 deletions(-)

diff --git a/src/api/admin/usage/routes.py b/src/api/admin/usage/routes.py
index 7244a9d..80489a6 100644
--- a/src/api/admin/usage/routes.py
+++ b/src/api/admin/usage/routes.py
@@ -800,3 +800,184 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
             "tiers": tiers,
             "source": pricing_source,  # pricing source: 'provider' or 'global'
         }
+
+
+# ==================== Cache affinity analysis ====================
+
+
+@router.get("/cache-affinity/ttl-analysis")
+async def analyze_cache_affinity_ttl(
+    request: Request,
+    user_id: Optional[str] = Query(None, description="Filter by user ID"),
+    api_key_id: Optional[str] = Query(None, description="Filter by API key ID"),
+    hours: int = Query(168, ge=1, le=720, description="Number of recent hours to analyze"),
+    db: Session = Depends(get_db),
+):
+    """
+    Analyze the distribution of request intervals per user and recommend a
+    suitable cache-affinity TTL.
+
+    By looking at the time between consecutive requests from the same user,
+    we classify usage patterns:
+    - High-frequency users (short intervals): a 5-minute TTL suffices
+    - Medium-frequency users: 15-30 minute TTL
+    - Low-frequency users (long intervals): a 60-minute TTL is needed
+    """
+    adapter = CacheAffinityTTLAnalysisAdapter(
+        user_id=user_id,
+        api_key_id=api_key_id,
+        hours=hours,
+    )
+    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
+
+
+@router.get("/cache-affinity/hit-analysis")
+async def analyze_cache_hit(
+    request: Request,
+    user_id: Optional[str] = Query(None, description="Filter by user ID"),
+    api_key_id: Optional[str] = Query(None, description="Filter by API key ID"),
+    hours: int = Query(168, ge=1, le=720, description="Number of recent hours to analyze"),
+    db: Session = Depends(get_db),
+):
+    """
+    Analyze cache hits.
+
+    Returns cache hit rates, estimated cost savings, and related statistics.
+    """
+    adapter = CacheHitAnalysisAdapter(
+        user_id=user_id,
+        api_key_id=api_key_id,
+        hours=hours,
+    )
+    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
+
+
+class CacheAffinityTTLAnalysisAdapter(AdminApiAdapter):
+    """Adapter for cache-affinity TTL analysis"""
+
+    def __init__(
+        self,
+        user_id: Optional[str],
+        api_key_id: Optional[str],
+        hours: int,
+    ):
+        self.user_id = user_id
+        self.api_key_id = api_key_id
+        self.hours = hours
+
+    async def handle(self, context):  # type: ignore[override]
+        db = context.db
+
+        result = UsageService.analyze_cache_affinity_ttl(
+            db=db,
+            user_id=self.user_id,
+            api_key_id=self.api_key_id,
+            hours=self.hours,
+        )
+
+        context.add_audit_metadata(
+            action="cache_affinity_ttl_analysis",
+            user_id=self.user_id,
+            api_key_id=self.api_key_id,
+            hours=self.hours,
+            total_users_analyzed=result.get("total_users_analyzed", 0),
+        )
+
+        return result
+
+
+class CacheHitAnalysisAdapter(AdminApiAdapter):
+    """Adapter for cache hit analysis"""
+
+    def __init__(
+        self,
+        user_id: Optional[str],
+        api_key_id: Optional[str],
+        hours: int,
+    ):
+        self.user_id = user_id
+        self.api_key_id = api_key_id
+        self.hours = hours
+
+    async def handle(self, context):  # type: ignore[override]
+        db = context.db
+
+        result = UsageService.get_cache_hit_analysis(
+            db=db,
+            user_id=self.user_id,
+            api_key_id=self.api_key_id,
+            hours=self.hours,
+        )
+
+        context.add_audit_metadata(
+            action="cache_hit_analysis",
+            user_id=self.user_id,
+            api_key_id=self.api_key_id,
+            hours=self.hours,
+        )
+
+        return result
+
+
+@router.get("/cache-affinity/interval-timeline")
+async def get_interval_timeline(
+    request: Request,
+    hours: int = Query(168, ge=1, le=720, description="Number of recent hours to analyze"),
+    limit: int = Query(1000, ge=100, le=5000, description="Maximum number of data points returned"),
+    user_id: Optional[str] = Query(None, description="Filter by user ID"),
+    include_user_info: bool = Query(False, description="Include user info (for the admin multi-user view)"),
+    db: Session = Depends(get_db),
+):
+    """
+    Return request-interval timeline data for scatter-plot rendering.
+
+    Each point carries the request timestamp and the interval (in minutes)
+    since the previous request, which can be used to visualize request
+    patterns.
+
+    When include_user_info=true and no user_id is given, the response also
+    includes:
+    - a user_id field on every point
+    - a users field mapping user_id -> username
+    """
+    adapter = IntervalTimelineAdapter(
+        hours=hours,
+        limit=limit,
+        user_id=user_id,
+        include_user_info=include_user_info,
+    )
+    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
+
+
+class IntervalTimelineAdapter(AdminApiAdapter):
+    """Adapter for the request-interval timeline"""
+
+    def __init__(
+        self,
+        hours: int,
+        limit: int,
+        user_id: Optional[str] = None,
+        include_user_info: bool = False,
+    ):
+        self.hours = hours
+        self.limit = limit
+        self.user_id = user_id
+        self.include_user_info = include_user_info
+
+    async def handle(self, context):  # type: ignore[override]
+        db = context.db
+
+        result = UsageService.get_interval_timeline(
+            db=db,
+            hours=self.hours,
+            limit=self.limit,
+            user_id=self.user_id,
+            include_user_info=self.include_user_info,
+        )
+
+        context.add_audit_metadata(
+            action="interval_timeline",
+            hours=self.hours,
+            limit=self.limit,
+            user_id=self.user_id,
+            include_user_info=self.include_user_info,
+            total_points=result.get("total_points", 0),
+        )
+
+        return result
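Reviewer note: a minimal sketch for exercising the new TTL-analysis endpoint end to end. The base URL, the `/api/admin/usage` mount prefix, and the bearer-token header are placeholders for whatever this deployment actually uses; the fields read from the response match the shape built by `UsageService.analyze_cache_affinity_ttl` below.

```python
# Sketch: call the new admin TTL-analysis endpoint and summarize the result.
# BASE_URL, the mount prefix, and ADMIN_TOKEN are placeholders.
import httpx

BASE_URL = "http://localhost:8000"
ADMIN_TOKEN = "<admin token>"

resp = httpx.get(
    f"{BASE_URL}/api/admin/usage/cache-affinity/ttl-analysis",
    params={"hours": 168},
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()
report = resp.json()

print("users analyzed:", report["total_users_analyzed"])
print("TTL buckets:", report["ttl_distribution"])  # {"5min": ..., "15min": ..., ...}
for row in report["users"][:5]:
    print(row["username"], "->", row["recommended_ttl_minutes"], "min |",
          row["recommendation_reason"])
```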
diff --git a/src/api/user_me/routes.py b/src/api/user_me/routes.py
index 3ed3bc5..c03aae4 100644
--- a/src/api/user_me/routes.py
+++ b/src/api/user_me/routes.py
@@ -121,6 +121,18 @@ async def get_my_active_requests(
     return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
 
 
+@router.get("/usage/interval-timeline")
+async def get_my_interval_timeline(
+    request: Request,
+    hours: int = Query(168, ge=1, le=720, description="Number of recent hours to analyze"),
+    limit: int = Query(1000, ge=100, le=5000, description="Maximum number of data points returned"),
+    db: Session = Depends(get_db),
+):
+    """Return the current user's request-interval timeline for scatter-plot rendering"""
+    adapter = GetMyIntervalTimelineAdapter(hours=hours, limit=limit)
+    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
+
+
 @router.get("/providers")
 async def list_available_providers(request: Request, db: Session = Depends(get_db)):
     adapter = ListAvailableProvidersAdapter()
@@ -676,6 +688,27 @@ class GetActiveRequestsAdapter(AuthenticatedApiAdapter):
         return {"requests": requests}
 
 
+@dataclass
+class GetMyIntervalTimelineAdapter(AuthenticatedApiAdapter):
+    """Adapter returning the current user's request-interval timeline"""
+
+    hours: int
+    limit: int
+
+    async def handle(self, context):  # type: ignore[override]
+        db = context.db
+        user = context.user
+
+        result = UsageService.get_interval_timeline(
+            db=db,
+            hours=self.hours,
+            limit=self.limit,
+            user_id=str(user.id),
+        )
+
+        return result
+
+
 class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
     async def handle(self, context):  # type: ignore[override]
         from sqlalchemy.orm import selectinload
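Reviewer note: what a client can do with this endpoint's payload — a minimal scatter-plot sketch with matplotlib. The `payload` dict is hand-written in the documented response shape; fetching the endpoint and authenticating are elided.

```python
# Sketch: render the interval-timeline payload as a scatter plot.
from datetime import datetime

import matplotlib.pyplot as plt

payload = {
    "analysis_period_hours": 168,
    "total_points": 3,
    "points": [
        {"x": "2025-12-10T08:00:00+00:00", "y": 1.5},
        {"x": "2025-12-10T08:12:00+00:00", "y": 12.0},
        {"x": "2025-12-10T09:30:00+00:00", "y": 78.0},
    ],
}

xs = [datetime.fromisoformat(p["x"]) for p in payload["points"]]
ys = [p["y"] for p in payload["points"]]

plt.scatter(xs, ys, s=8)
plt.ylabel("minutes since previous request")
plt.title(f"Request intervals, last {payload['analysis_period_hours']}h")
plt.show()
```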
diff --git a/src/services/cache/aware_scheduler.py b/src/services/cache/aware_scheduler.py
index 873262a..553f363 100644
--- a/src/services/cache/aware_scheduler.py
+++ b/src/services/cache/aware_scheduler.py
@@ -862,13 +862,14 @@ class CacheAwareScheduler:
         # Key-level capability matching
         # Note: model-level capability checks already happened in _check_model_support
-        if capability_requirements:
-            from src.core.key_capabilities import check_capability_match
+        # Always run the check, even when capability_requirements is empty,
+        # because check_capability_match also detects when a key's EXCLUSIVE
+        # capabilities would be wasted on this request
+        from src.core.key_capabilities import check_capability_match
 
-            key_caps: Dict[str, bool] = dict(key.capabilities or {})
-            is_match, skip_reason = check_capability_match(key_caps, capability_requirements)
-            if not is_match:
-                return False, skip_reason
+        key_caps: Dict[str, bool] = dict(key.capabilities or {})
+        is_match, skip_reason = check_capability_match(key_caps, capability_requirements)
+        if not is_match:
+            return False, skip_reason
 
         return True, None
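Reviewer note: the scheduler change above hinges on `check_capability_match` doing more than plain requirement matching. Below is a toy model of that assumed contract; the real implementation lives in `src.core.key_capabilities` and may differ in names and rules, so treat this purely as an illustration of why the call must run even with empty requirements.

```python
# Toy model of the contract the scheduler change relies on: a key whose
# capabilities are "exclusive" should be skipped when the request does not
# need them, so the key stays free for requests that do. All names here are
# illustrative, not the real API.
from typing import Dict, Optional, Tuple

EXCLUSIVE_CAPS = {"vision", "long_context"}  # hypothetical exclusive capabilities

def toy_capability_match(
    key_caps: Dict[str, bool],
    requirements: Dict[str, bool],
) -> Tuple[bool, Optional[str]]:
    # Reject keys that lack a required capability.
    for cap, needed in requirements.items():
        if needed and not key_caps.get(cap, False):
            return False, f"key lacks required capability: {cap}"
    # Reject keys whose exclusive capability would be wasted. This branch is
    # why the check must run even when `requirements` is empty.
    for cap in EXCLUSIVE_CAPS:
        if key_caps.get(cap) and not requirements.get(cap):
            return False, f"exclusive capability {cap} reserved for matching requests"
    return True, None

assert toy_capability_match({"vision": True}, {}) == (
    False,
    "exclusive capability vision reserved for matching requests",
)
```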
diff --git a/src/services/orchestration/error_classifier.py b/src/services/orchestration/error_classifier.py
index b5cd168..30b71e9 100644
--- a/src/services/orchestration/error_classifier.py
+++ b/src/services/orchestration/error_classifier.py
@@ -67,12 +67,13 @@ class ErrorClassifier:
 
     # Keywords that indicate a client-side request error (case-insensitive).
     # These errors are caused by the request itself; switching providers won't help.
+    # Note: error.type values from standard APIs are handled via CLIENT_ERROR_TYPES;
+    # these patterns mainly match non-standard formats and third-party proxy messages.
     CLIENT_ERROR_PATTERNS: Tuple[str, ...] = (
         "could not process image",  # image processing failed
         "image too large",  # image too large
         "invalid image",  # invalid image
         "unsupported image",  # unsupported image format
-        "invalid_request_error",  # generic OpenAI/Claude client error type
         "content_policy_violation",  # content policy violation
         "invalid_api_key",  # invalid API key (distinct from an auth failure)
         "context_length_exceeded",  # context length exceeded
@@ -85,6 +86,7 @@ class ErrorClassifier:
         "image exceeds",  # image exceeds a limit
         "pdf too large",  # PDF too large
         "file too large",  # file too large
+        "tool_use_id",  # tool_result references a missing tool_use (non-standard proxies)
     )
 
     def __init__(
@@ -105,10 +107,22 @@ class ErrorClassifier:
         self.adaptive_manager = adaptive_manager or get_adaptive_manager()
         self.cache_scheduler = cache_scheduler
 
+    # Error types that indicate a client error (case-insensitive).
+    # These types mean the request itself is at fault and should not be retried.
+    CLIENT_ERROR_TYPES: Tuple[str, ...] = (
+        "invalid_request_error",  # standard Claude/OpenAI client error type
+        "invalid_argument",  # Gemini: invalid argument
+        "failed_precondition",  # Gemini: failed precondition
+    )
+
     def _is_client_error(self, error_text: Optional[str]) -> bool:
         """
         Detect whether an error response is a client error (not retryable).
 
+        Logic:
+        1. Check whether error.type is a known client error type
+        2. Check whether the error text contains a known client error pattern
+
         Args:
             error_text: the error response text
 
@@ -118,6 +132,19 @@ class ErrorClassifier:
         if not error_text:
             return False
 
+        # Try to parse the body as JSON and inspect the error type.
+        # The isinstance guard also covers bodies that parse to non-dicts.
+        try:
+            data = json.loads(error_text)
+            if isinstance(data, dict) and isinstance(data.get("error"), dict):
+                error_type = data["error"].get("type", "")
+                if error_type and any(
+                    t.lower() in error_type.lower() for t in self.CLIENT_ERROR_TYPES
+                ):
+                    return True
+        except (json.JSONDecodeError, TypeError, KeyError):
+            pass
+
+        # Fall back to keyword matching
         error_lower = error_text.lower()
         return any(pattern.lower() in error_lower for pattern in self.CLIENT_ERROR_PATTERNS)
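Reviewer note: a standalone mirror of the two-step classification added above, runnable outside the class. The sample body follows the standard `{"error": {"type": ..., "message": ...}}` envelope; the pattern list is trimmed to two entries for brevity.

```python
# Standalone mirror of _is_client_error: structured error.type first,
# substring matching as the fallback for non-standard proxies.
import json

CLIENT_ERROR_TYPES = ("invalid_request_error", "invalid_argument", "failed_precondition")
CLIENT_ERROR_PATTERNS = ("context_length_exceeded", "tool_use_id")

def is_client_error(error_text: str) -> bool:
    # Step 1: trust a structured error.type when the body parses as JSON.
    try:
        data = json.loads(error_text)
        if isinstance(data, dict) and isinstance(data.get("error"), dict):
            error_type = data["error"].get("type", "")
            if any(t in error_type.lower() for t in CLIENT_ERROR_TYPES):
                return True
    except (json.JSONDecodeError, TypeError, KeyError):
        pass
    # Step 2: fall back to substring matching.
    return any(p in error_text.lower() for p in CLIENT_ERROR_PATTERNS)

body = json.dumps({"error": {"type": "invalid_request_error",
                             "message": "prompt is too long"}})
assert is_client_error(body) is True                            # caught by step 1
assert is_client_error("unexpected tool_use_id found") is True  # caught by step 2
assert is_client_error("overloaded, try again") is False        # retryable
```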
diff --git a/src/services/usage/service.py b/src/services/usage/service.py
index cc55fee..60e8a25 100644
--- a/src/services/usage/service.py
+++ b/src/services/usage/service.py
@@ -1394,3 +1394,461 @@ class UsageService:
             }
             for r in records
         ]
+
+    # ========== Cache affinity analysis ==========
+
+    @staticmethod
+    def analyze_cache_affinity_ttl(
+        db: Session,
+        user_id: Optional[str] = None,
+        api_key_id: Optional[str] = None,
+        hours: int = 168,
+    ) -> Dict[str, Any]:
+        """
+        Analyze the distribution of request intervals per user and recommend a
+        suitable cache-affinity TTL.
+
+        By looking at the time between consecutive requests from the same user,
+        we classify usage patterns:
+        - High-frequency users (short intervals): a 5-minute TTL suffices
+        - Medium-frequency users: 15-30 minute TTL
+        - Low-frequency users (long intervals): a 60-minute TTL is needed
+
+        Args:
+            db: database session
+            user_id: user ID to filter by (optional; all users when empty)
+            api_key_id: API key ID to filter by (optional)
+            hours: number of recent hours to analyze
+
+        Returns:
+            A dict with the analysis results
+        """
+        from sqlalchemy import text
+
+        # Compute the time window
+        start_date = datetime.now(timezone.utc) - timedelta(hours=hours)
+
+        # Build the SQL query: a window function computes the gap between
+        # consecutive requests within each group (user_id or api_key_id).
+        # FILTER and PERCENTILE_CONT make this PostgreSQL-specific.
+        group_by_field = "api_key_id" if api_key_id else "user_id"
+
+        # Build the filter clause. group_by_field only ever takes one of the
+        # two hard-coded values above, so interpolating it is safe; the value
+        # itself is bound as :filter_id.
+        filter_clause = ""
+        if user_id or api_key_id:
+            filter_clause = f"AND {group_by_field} = :filter_id"
+
+        sql = text(f"""
+            WITH user_requests AS (
+                SELECT
+                    {group_by_field} as group_id,
+                    created_at,
+                    LAG(created_at) OVER (
+                        PARTITION BY {group_by_field}
+                        ORDER BY created_at
+                    ) as prev_request_at
+                FROM usage
+                WHERE status = 'completed'
+                  AND created_at > :start_date
+                  AND {group_by_field} IS NOT NULL
+                  {filter_clause}
+            ),
+            intervals AS (
+                SELECT
+                    group_id,
+                    EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 as interval_minutes
+                FROM user_requests
+                WHERE prev_request_at IS NOT NULL
+            ),
+            user_stats AS (
+                SELECT
+                    group_id,
+                    COUNT(*) as request_count,
+                    COUNT(*) FILTER (WHERE interval_minutes <= 5) as within_5min,
+                    COUNT(*) FILTER (WHERE interval_minutes > 5 AND interval_minutes <= 15) as within_15min,
+                    COUNT(*) FILTER (WHERE interval_minutes > 15 AND interval_minutes <= 30) as within_30min,
+                    COUNT(*) FILTER (WHERE interval_minutes > 30 AND interval_minutes <= 60) as within_60min,
+                    COUNT(*) FILTER (WHERE interval_minutes > 60) as over_60min,
+                    PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY interval_minutes) as median_interval,
+                    PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY interval_minutes) as p75_interval,
+                    PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY interval_minutes) as p90_interval,
+                    AVG(interval_minutes) as avg_interval,
+                    MIN(interval_minutes) as min_interval,
+                    MAX(interval_minutes) as max_interval
+                FROM intervals
+                GROUP BY group_id
+                HAVING COUNT(*) >= 2
+            )
+            SELECT * FROM user_stats
+            ORDER BY request_count DESC
+        """)
+
+        params: Dict[str, Any] = {
+            "start_date": start_date,
+        }
+        # Match the precedence used for group_by_field: when both filters are
+        # given, api_key_id wins. (The previous user_id-first order silently
+        # compared a user ID against the api_key_id column.)
+        if api_key_id:
+            params["filter_id"] = api_key_id
+        elif user_id:
+            params["filter_id"] = user_id
+
+        result = db.execute(sql, params)
+        rows = result.fetchall()
+
+        # Collect group IDs so user info can be fetched in one query
+        group_ids = [row[0] for row in rows]
+
+        # When grouping by user_id, look up user info
+        user_info_map: Dict[str, Dict[str, str]] = {}
+        if group_by_field == "user_id" and group_ids:
+            users = db.query(User).filter(User.id.in_(group_ids)).all()
+            for user in users:
+                user_info_map[str(user.id)] = {
+                    "username": user.username,
+                    "email": user.email or "",
+                }
+
+        # Process the rows
+        users_analysis = []
+        for row in rows:
+            # Each row is a tuple in SELECT order
+            (
+                group_id,
+                request_count,
+                within_5min,
+                within_15min,
+                within_30min,
+                within_60min,
+                over_60min,
+                median_interval,
+                p75_interval,
+                p90_interval,
+                avg_interval,
+                min_interval,
+                max_interval,
+            ) = row
+
+            # Compute the recommended TTL
+            recommended_ttl = UsageService._calculate_recommended_ttl(
+                p75_interval, p90_interval
+            )
+
+            # Look up user info
+            user_info = user_info_map.get(str(group_id), {})
+
+            # Compute the share of each bucket. Note: COUNT(*) in user_stats
+            # counts intervals (requests minus one per group), so
+            # request_count is really the interval count.
+            total_intervals = request_count
+            users_analysis.append({
+                "group_id": group_id,
+                "username": user_info.get("username"),
+                "email": user_info.get("email"),
+                "request_count": request_count,
+                "interval_distribution": {
+                    "within_5min": within_5min,
+                    "within_15min": within_15min,
+                    "within_30min": within_30min,
+                    "within_60min": within_60min,
+                    "over_60min": over_60min,
+                },
+                "interval_percentages": {
+                    "within_5min": round(within_5min / total_intervals * 100, 1),
+                    "within_15min": round(within_15min / total_intervals * 100, 1),
+                    "within_30min": round(within_30min / total_intervals * 100, 1),
+                    "within_60min": round(within_60min / total_intervals * 100, 1),
+                    "over_60min": round(over_60min / total_intervals * 100, 1),
+                },
+                "percentiles": {
+                    "p50": round(float(median_interval), 2) if median_interval else None,
+                    "p75": round(float(p75_interval), 2) if p75_interval else None,
+                    "p90": round(float(p90_interval), 2) if p90_interval else None,
+                },
+                "avg_interval_minutes": round(float(avg_interval), 2) if avg_interval else None,
+                "min_interval_minutes": round(float(min_interval), 2) if min_interval else None,
+                "max_interval_minutes": round(float(max_interval), 2) if max_interval else None,
+                "recommended_ttl_minutes": recommended_ttl,
+                "recommendation_reason": UsageService._get_ttl_recommendation_reason(
+                    recommended_ttl, p75_interval, p90_interval
+                ),
+            })
+
+        # Aggregate summary
+        ttl_distribution = {"5min": 0, "15min": 0, "30min": 0, "60min": 0}
+        for analysis in users_analysis:
+            ttl = analysis["recommended_ttl_minutes"]
+            if ttl <= 5:
+                ttl_distribution["5min"] += 1
+            elif ttl <= 15:
+                ttl_distribution["15min"] += 1
+            elif ttl <= 30:
+                ttl_distribution["30min"] += 1
+            else:
+                ttl_distribution["60min"] += 1
+
+        return {
+            "analysis_period_hours": hours,
+            "total_users_analyzed": len(users_analysis),
+            "ttl_distribution": ttl_distribution,
+            "users": users_analysis,
+        }
+
+    @staticmethod
+    def _calculate_recommended_ttl(
+        p75_interval: Optional[float],
+        p90_interval: Optional[float],
+    ) -> int:
+        """
+        Compute the recommended cache TTL from the interval distribution.
+
+        Strategy:
+        - If 90% of intervals fall within 5 minutes -> 5-minute TTL
+        - If 75% of intervals fall within 15 minutes -> 15-minute TTL
+        - If 75% of intervals fall within 30 minutes -> 30-minute TTL
+        - Otherwise -> 60-minute TTL
+        """
+        if p90_interval is None or p75_interval is None:
+            return 5  # default
+
+        # 90% of intervals within 5 minutes
+        if p90_interval <= 5:
+            return 5
+
+        # 75% of intervals within 15 minutes
+        if p75_interval <= 15:
+            return 15
+
+        # 75% of intervals within 30 minutes
+        if p75_interval <= 30:
+            return 30
+
+        # Low-frequency user; needs a longer TTL
+        return 60
+
+    @staticmethod
+    def _get_ttl_recommendation_reason(
+        ttl: int,
+        p75_interval: Optional[float],
+        p90_interval: Optional[float],
+    ) -> str:
+        """Generate the reason string for a TTL recommendation"""
+        if p75_interval is None or p90_interval is None:
+            return "Insufficient data; using the default"
+
+        if ttl == 5:
+            return f"High-frequency user: 90% of request intervals fall within {p90_interval:.1f} minutes"
+        elif ttl == 15:
+            return f"Upper-medium-frequency user: 75% of request intervals fall within {p75_interval:.1f} minutes"
+        elif ttl == 30:
+            return f"Medium-frequency user: 75% of request intervals fall within {p75_interval:.1f} minutes"
+        else:
+            return f"Low-frequency user: the 75th-percentile interval is {p75_interval:.1f} minutes; a long TTL is recommended"
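Reviewer note: a worked example of the percentile-to-TTL mapping above, using the stdlib `statistics` module on toy interval data in place of `PERCENTILE_CONT`. The interval values are made up.

```python
# Worked example of the percentile -> TTL mapping in _calculate_recommended_ttl.
import statistics

intervals = [0.5, 1.2, 2.0, 3.1, 4.0, 4.5, 6.0, 9.0, 14.0, 40.0]  # minutes

# quantiles(n=20) yields the 5th, 10th, ..., 95th percentiles;
# index 14 is p75 and index 17 is p90.
qs = statistics.quantiles(intervals, n=20)
p75, p90 = qs[14], qs[17]

if p90 <= 5:
    ttl = 5
elif p75 <= 15:
    ttl = 15
elif p75 <= 30:
    ttl = 30
else:
    ttl = 60

print(f"p75={p75:.1f} min, p90={p90:.1f} min -> recommended TTL {ttl} min")
```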
+    @staticmethod
+    def get_cache_hit_analysis(
+        db: Session,
+        user_id: Optional[str] = None,
+        api_key_id: Optional[str] = None,
+        hours: int = 168,
+    ) -> Dict[str, Any]:
+        """
+        Analyze cache hits.
+
+        Args:
+            db: database session
+            user_id: user ID to filter by (optional)
+            api_key_id: API key ID to filter by (optional)
+            hours: number of recent hours to analyze
+
+        Returns:
+            Cache hit analysis results
+        """
+        start_date = datetime.now(timezone.utc) - timedelta(hours=hours)
+
+        # Base query
+        query = db.query(
+            func.count(Usage.id).label("total_requests"),
+            func.sum(Usage.input_tokens).label("total_input_tokens"),
+            func.sum(Usage.cache_read_input_tokens).label("total_cache_read_tokens"),
+            func.sum(Usage.cache_creation_input_tokens).label("total_cache_creation_tokens"),
+            func.sum(Usage.cache_read_cost_usd).label("total_cache_read_cost"),
+            func.sum(Usage.cache_creation_cost_usd).label("total_cache_creation_cost"),
+        ).filter(
+            Usage.status == "completed",
+            Usage.created_at >= start_date,
+        )
+
+        if user_id:
+            query = query.filter(Usage.user_id == user_id)
+        if api_key_id:
+            query = query.filter(Usage.api_key_id == api_key_id)
+
+        result = query.first()
+
+        total_requests = result.total_requests or 0
+        total_input_tokens = result.total_input_tokens or 0
+        total_cache_read_tokens = result.total_cache_read_tokens or 0
+        total_cache_creation_tokens = result.total_cache_creation_tokens or 0
+        total_cache_read_cost = float(result.total_cache_read_cost or 0)
+        total_cache_creation_cost = float(result.total_cache_creation_cost or 0)
+
+        # Token-level cache hit rate. This assumes input_tokens does NOT
+        # include cache reads, so total input context
+        # = input_tokens + cache_read_tokens, and
+        # hit rate = cache_read / (input + cache_read).
+        total_context_tokens = total_input_tokens + total_cache_read_tokens
+        cache_hit_rate = 0.0
+        if total_context_tokens > 0:
+            cache_hit_rate = total_cache_read_tokens / total_context_tokens * 100
+
+        # Estimated savings: cache reads are billed at 10% of the normal input
+        # price, so the avoided cost is the other 90% -- i.e. nine times what
+        # was actually paid for cache reads.
+        estimated_savings = total_cache_read_cost * 9
+
+        # Count requests with at least one cache hit
+        requests_with_cache_hit = db.query(func.count(Usage.id)).filter(
+            Usage.status == "completed",
+            Usage.created_at >= start_date,
+            Usage.cache_read_input_tokens > 0,
+        )
+        if user_id:
+            requests_with_cache_hit = requests_with_cache_hit.filter(Usage.user_id == user_id)
+        if api_key_id:
+            requests_with_cache_hit = requests_with_cache_hit.filter(Usage.api_key_id == api_key_id)
+        requests_with_cache_hit = requests_with_cache_hit.scalar() or 0
+
+        return {
+            "analysis_period_hours": hours,
+            "total_requests": total_requests,
+            "requests_with_cache_hit": requests_with_cache_hit,
+            "request_cache_hit_rate": (
+                round(requests_with_cache_hit / total_requests * 100, 2)
+                if total_requests > 0
+                else 0
+            ),
+            "total_input_tokens": total_input_tokens,
+            "total_cache_read_tokens": total_cache_read_tokens,
+            "total_cache_creation_tokens": total_cache_creation_tokens,
+            "token_cache_hit_rate": round(cache_hit_rate, 2),
+            "total_cache_read_cost_usd": round(total_cache_read_cost, 4),
+            "total_cache_creation_cost_usd": round(total_cache_creation_cost, 4),
+            "estimated_savings_usd": round(estimated_savings, 4),
+        }
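Reviewer note: the savings arithmetic above in miniature, under the same assumption that cache reads bill at 10% of the normal input price. The price per million tokens here is made up.

```python
# Worked version of the `cache_read_cost * 9` savings estimate.
input_price_per_mtok = 3.00       # hypothetical $/M input tokens
cache_read_tokens = 2_000_000

cache_read_cost = cache_read_tokens / 1e6 * input_price_per_mtok * 0.10  # $0.60 paid
full_price_cost = cache_read_tokens / 1e6 * input_price_per_mtok         # $6.00 avoided

savings = full_price_cost - cache_read_cost        # $5.40
assert abs(savings - cache_read_cost * 9) < 1e-9   # matches the * 9 shortcut

print(f"paid ${cache_read_cost:.2f}, saved ${savings:.2f}")
```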
+    @staticmethod
+    def get_interval_timeline(
+        db: Session,
+        hours: int = 168,
+        limit: int = 1000,
+        user_id: Optional[str] = None,
+        include_user_info: bool = False,
+    ) -> Dict[str, Any]:
+        """
+        Return request-interval timeline data for scatter-plot rendering.
+
+        Args:
+            db: database session
+            hours: number of recent hours to analyze
+            limit: maximum number of data points returned
+            user_id: user ID to filter by (optional; all users when empty)
+            include_user_info: include user info (for the admin multi-user view)
+
+        Returns:
+            A dict with the timeline data points
+        """
+        from sqlalchemy import text
+
+        start_date = datetime.now(timezone.utc) - timedelta(hours=hours)
+
+        # Build the user filter
+        user_filter = "AND u.user_id = :user_id" if user_id else ""
+
+        # Pick the query depending on whether user info is needed
+        if include_user_info and not user_id:
+            # Admin view: points carry user info
+            sql = text(f"""
+                WITH request_intervals AS (
+                    SELECT
+                        u.created_at,
+                        u.user_id,
+                        usr.username,
+                        LAG(u.created_at) OVER (
+                            PARTITION BY u.user_id
+                            ORDER BY u.created_at
+                        ) as prev_request_at
+                    FROM usage u
+                    LEFT JOIN users usr ON u.user_id = usr.id
+                    WHERE u.status = 'completed'
+                      AND u.created_at > :start_date
+                      AND u.user_id IS NOT NULL
+                      {user_filter}
+                )
+                SELECT
+                    created_at,
+                    user_id,
+                    username,
+                    EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 as interval_minutes
+                FROM request_intervals
+                WHERE prev_request_at IS NOT NULL
+                  AND EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 <= 120
+                ORDER BY created_at
+                LIMIT :limit
+            """)
+        else:
+            # Regular view: timestamps and intervals only
+            sql = text(f"""
+                WITH request_intervals AS (
+                    SELECT
+                        u.created_at,
+                        u.user_id,
+                        LAG(u.created_at) OVER (
+                            PARTITION BY u.user_id
+                            ORDER BY u.created_at
+                        ) as prev_request_at
+                    FROM usage u
+                    WHERE u.status = 'completed'
+                      AND u.created_at > :start_date
+                      AND u.user_id IS NOT NULL
+                      {user_filter}
+                )
+                SELECT
+                    created_at,
+                    EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 as interval_minutes
+                FROM request_intervals
+                WHERE prev_request_at IS NOT NULL
+                  AND EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 <= 120
+                ORDER BY created_at
+                LIMIT :limit
+            """)
+
+        params: Dict[str, Any] = {"start_date": start_date, "limit": limit}
+        if user_id:
+            params["user_id"] = user_id
+
+        result = db.execute(sql, params)
+        rows = result.fetchall()
+
+        # Convert rows into timeline points. Intervals over 120 minutes are
+        # excluded in SQL so outliers don't flatten the scatter plot; note
+        # that the ascending ORDER BY means the earliest points in the window
+        # win when more than `limit` rows match.
+        points = []
+        users_map: Dict[str, str] = {}  # user_id -> username
+
+        if include_user_info and not user_id:
+            for row in rows:
+                created_at, row_user_id, username, interval_minutes = row
+                points.append({
+                    "x": created_at.isoformat(),
+                    "y": round(float(interval_minutes), 2),
+                    "user_id": str(row_user_id),
+                })
+                if row_user_id and username:
+                    users_map[str(row_user_id)] = username
+        else:
+            for row in rows:
+                created_at, interval_minutes = row
+                points.append({
+                    "x": created_at.isoformat(),
+                    "y": round(float(interval_minutes), 2),
+                })
+
+        response: Dict[str, Any] = {
+            "analysis_period_hours": hours,
+            "total_points": len(points),
+            "points": points,
+        }
+
+        if include_user_info and not user_id:
+            response["users"] = users_map
+
+        return response
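Reviewer note: for readers unfamiliar with `LAG(...) OVER (PARTITION BY ... ORDER BY ...)`, here is a pure-Python rendition of what the timeline query computes, including the 120-minute cutoff. The rows are toy stand-ins for completed usage records.

```python
# Pure-Python rendition of the per-user interval computation done in SQL:
# partition by user, order by timestamp, diff consecutive rows, drop
# intervals over 120 minutes.
from collections import defaultdict
from datetime import datetime

rows = [  # (user_id, created_at)
    ("u1", datetime(2025, 12, 10, 8, 0)),
    ("u1", datetime(2025, 12, 10, 8, 7)),
    ("u2", datetime(2025, 12, 10, 8, 1)),
    ("u1", datetime(2025, 12, 10, 11, 0)),  # 173-min gap -> filtered out
    ("u2", datetime(2025, 12, 10, 9, 1)),
]

by_user = defaultdict(list)
for user_id, created_at in sorted(rows, key=lambda r: r[1]):
    by_user[user_id].append(created_at)

points = []
for user_id, stamps in by_user.items():
    for prev, cur in zip(stamps, stamps[1:]):
        interval_min = (cur - prev).total_seconds() / 60.0
        if interval_min <= 120:  # same cutoff as the SQL
            points.append({"x": cur.isoformat(), "y": round(interval_min, 2),
                           "user_id": user_id})

print(points)  # u1: 7.0 min; u2: 60.0 min; the 173-min gap is dropped
```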