Aether/src/services/usage/service.py
fawney19 7b932d7afb refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
2025-12-18 19:07:20 +08:00

"""
Usage statistics and quota management service.
"""
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.core.enums import ProviderBillingType
from src.core.logger import logger
from src.models.database import ApiKey, Provider, ProviderAPIKey, Usage, User, UserRole
from src.services.model.cost import ModelCostService
from src.services.system.config import SystemConfigService
class UsageService:
"""用量统计服务"""
# ==================== 内部数据类 ====================
@staticmethod
def _build_usage_params(
*,
db: Session,
user: Optional[User],
api_key: Optional[ApiKey],
provider: str,
model: str,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int,
cache_read_input_tokens: int,
request_type: str,
api_format: Optional[str],
is_stream: bool,
response_time_ms: Optional[int],
first_byte_time_ms: Optional[int],
status_code: int,
error_message: Optional[str],
metadata: Optional[Dict[str, Any]],
request_headers: Optional[Dict[str, Any]],
request_body: Optional[Any],
provider_request_headers: Optional[Dict[str, Any]],
response_headers: Optional[Dict[str, Any]],
response_body: Optional[Any],
request_id: str,
provider_id: Optional[str],
provider_endpoint_id: Optional[str],
provider_api_key_id: Optional[str],
status: str,
target_model: Optional[str],
# Cost calculation results
input_cost: float,
output_cost: float,
cache_creation_cost: float,
cache_read_cost: float,
cache_cost: float,
request_cost: float,
total_cost: float,
# Price information
input_price: float,
output_price: float,
cache_creation_price: Optional[float],
cache_read_price: Optional[float],
request_price: Optional[float],
# Rate multiplier
actual_rate_multiplier: float,
is_free_tier: bool,
) -> Dict[str, Any]:
"""构建 Usage 记录的参数字典(内部方法,避免代码重复)"""
# 根据配置决定是否记录请求详情
should_log_headers = SystemConfigService.should_log_headers(db)
should_log_body = SystemConfigService.should_log_body(db)
# Process request headers (may need masking)
processed_request_headers = None
if should_log_headers and request_headers:
processed_request_headers = SystemConfigService.mask_sensitive_headers(
db, request_headers
)
# Process provider request headers (may need masking)
processed_provider_request_headers = None
if should_log_headers and provider_request_headers:
processed_provider_request_headers = SystemConfigService.mask_sensitive_headers(
db, provider_request_headers
)
# Process request/response bodies (may need truncation)
processed_request_body = None
processed_response_body = None
if should_log_body:
if request_body:
processed_request_body = SystemConfigService.truncate_body(
db, request_body, is_request=True
)
if response_body:
processed_response_body = SystemConfigService.truncate_body(
db, response_body, is_request=False
)
# Process response headers
processed_response_headers = None
if should_log_headers and response_headers:
processed_response_headers = SystemConfigService.mask_sensitive_headers(
db, response_headers
)
# Compute the actual cost (nominal cost * multiplier); free-tier requests cost 0
if is_free_tier:
actual_input_cost = 0.0
actual_output_cost = 0.0
actual_cache_creation_cost = 0.0
actual_cache_read_cost = 0.0
actual_request_cost = 0.0
actual_total_cost = 0.0
else:
actual_input_cost = input_cost * actual_rate_multiplier
actual_output_cost = output_cost * actual_rate_multiplier
actual_cache_creation_cost = cache_creation_cost * actual_rate_multiplier
actual_cache_read_cost = cache_read_cost * actual_rate_multiplier
actual_request_cost = request_cost * actual_rate_multiplier
actual_total_cost = total_cost * actual_rate_multiplier
return {
"user_id": user.id if user else None,
"api_key_id": api_key.id if api_key else None,
"request_id": request_id,
"provider": provider,
"model": model,
"target_model": target_model,
"provider_id": provider_id,
"provider_endpoint_id": provider_endpoint_id,
"provider_api_key_id": provider_api_key_id,
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
"cache_creation_input_tokens": cache_creation_input_tokens,
"cache_read_input_tokens": cache_read_input_tokens,
"input_cost_usd": input_cost,
"output_cost_usd": output_cost,
"cache_cost_usd": cache_cost,
"cache_creation_cost_usd": cache_creation_cost,
"cache_read_cost_usd": cache_read_cost,
"request_cost_usd": request_cost,
"total_cost_usd": total_cost,
"actual_input_cost_usd": actual_input_cost,
"actual_output_cost_usd": actual_output_cost,
"actual_cache_creation_cost_usd": actual_cache_creation_cost,
"actual_cache_read_cost_usd": actual_cache_read_cost,
"actual_request_cost_usd": actual_request_cost,
"actual_total_cost_usd": actual_total_cost,
"rate_multiplier": actual_rate_multiplier,
"input_price_per_1m": input_price,
"output_price_per_1m": output_price,
"cache_creation_price_per_1m": cache_creation_price,
"cache_read_price_per_1m": cache_read_price,
"price_per_request": request_price,
"request_type": request_type,
"api_format": api_format,
"is_stream": is_stream,
"status_code": status_code,
"error_message": error_message,
"response_time_ms": response_time_ms,
"first_byte_time_ms": first_byte_time_ms,
"status": status,
"request_metadata": metadata,
"request_headers": processed_request_headers,
"request_body": processed_request_body,
"provider_request_headers": processed_provider_request_headers,
"response_headers": processed_response_headers,
"response_body": processed_response_body,
}
@classmethod
async def _get_rate_multiplier_and_free_tier(
cls,
db: Session,
provider_api_key_id: Optional[str],
provider_id: Optional[str],
) -> Tuple[float, bool]:
"""获取费率倍数和是否免费套餐"""
actual_rate_multiplier = 1.0
if provider_api_key_id:
provider_key = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.id == provider_api_key_id).first()
)
if provider_key and provider_key.rate_multiplier:
actual_rate_multiplier = provider_key.rate_multiplier
is_free_tier = False
if provider_id:
provider_obj = db.query(Provider).filter(Provider.id == provider_id).first()
if provider_obj and provider_obj.billing_type == ProviderBillingType.FREE_TIER:
is_free_tier = True
return actual_rate_multiplier, is_free_tier
@classmethod
async def _calculate_costs(
cls,
db: Session,
provider: str,
model: str,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int,
cache_read_input_tokens: int,
api_format: Optional[str],
cache_ttl_minutes: Optional[int],
use_tiered_pricing: bool,
is_failed_request: bool,
) -> Tuple[float, float, float, float, float, float, float, float, float,
Optional[float], Optional[float], Optional[float], Optional[int]]:
"""计算所有成本相关数据
Returns:
(input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, tier_index)
"""
# Fetch model price information
input_price, output_price = await cls.get_model_price_async(db, provider, model)
cache_creation_price, cache_read_price = await cls.get_cache_prices_async(
db, provider, model, input_price
)
request_price = await cls.get_request_price_async(db, provider, model)
effective_request_price = None if is_failed_request else request_price
# Initialize cost variables
input_cost = 0.0
output_cost = 0.0
cache_creation_cost = 0.0
cache_read_cost = 0.0
cache_cost = 0.0
request_cost = 0.0
total_cost = 0.0
tier_index = None
if use_tiered_pricing:
(
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
tier_index,
) = await cls.calculate_cost_with_strategy_async(
db=db,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
)
if is_failed_request:
total_cost = total_cost - request_cost
request_cost = 0.0
else:
(
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
) = cls.calculate_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_price_per_1m=input_price,
output_price_per_1m=output_price,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_creation_price_per_1m=cache_creation_price,
cache_read_price_per_1m=cache_read_price,
price_per_request=effective_request_price,
)
return (
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, tier_index
)
@staticmethod
def _update_existing_usage(
existing_usage: Usage,
usage_params: Dict[str, Any],
target_model: Optional[str],
) -> None:
"""更新已存在的 Usage 记录(内部方法)"""
# 更新关键字段
existing_usage.provider = usage_params["provider"]
existing_usage.status = usage_params["status"]
existing_usage.status_code = usage_params["status_code"]
existing_usage.error_message = usage_params["error_message"]
existing_usage.response_time_ms = usage_params["response_time_ms"]
existing_usage.first_byte_time_ms = usage_params["first_byte_time_ms"]
# Update request headers/body if new values are present
if usage_params["request_headers"] is not None:
existing_usage.request_headers = usage_params["request_headers"]
if usage_params["request_body"] is not None:
existing_usage.request_body = usage_params["request_body"]
if usage_params["provider_request_headers"] is not None:
existing_usage.provider_request_headers = usage_params["provider_request_headers"]
existing_usage.response_body = usage_params["response_body"]
existing_usage.response_headers = usage_params["response_headers"]
# Update token and cost fields
existing_usage.input_tokens = usage_params["input_tokens"]
existing_usage.output_tokens = usage_params["output_tokens"]
existing_usage.total_tokens = usage_params["total_tokens"]
existing_usage.cache_creation_input_tokens = usage_params["cache_creation_input_tokens"]
existing_usage.cache_read_input_tokens = usage_params["cache_read_input_tokens"]
existing_usage.input_cost_usd = usage_params["input_cost_usd"]
existing_usage.output_cost_usd = usage_params["output_cost_usd"]
existing_usage.cache_cost_usd = usage_params["cache_cost_usd"]
existing_usage.cache_creation_cost_usd = usage_params["cache_creation_cost_usd"]
existing_usage.cache_read_cost_usd = usage_params["cache_read_cost_usd"]
existing_usage.request_cost_usd = usage_params["request_cost_usd"]
existing_usage.total_cost_usd = usage_params["total_cost_usd"]
existing_usage.actual_input_cost_usd = usage_params["actual_input_cost_usd"]
existing_usage.actual_output_cost_usd = usage_params["actual_output_cost_usd"]
existing_usage.actual_cache_creation_cost_usd = usage_params["actual_cache_creation_cost_usd"]
existing_usage.actual_cache_read_cost_usd = usage_params["actual_cache_read_cost_usd"]
existing_usage.actual_request_cost_usd = usage_params["actual_request_cost_usd"]
existing_usage.actual_total_cost_usd = usage_params["actual_total_cost_usd"]
existing_usage.rate_multiplier = usage_params["rate_multiplier"]
# Update provider-side tracing fields
existing_usage.provider_id = usage_params["provider_id"]
existing_usage.provider_endpoint_id = usage_params["provider_endpoint_id"]
existing_usage.provider_api_key_id = usage_params["provider_api_key_id"]
# Update model mapping
if target_model is not None:
existing_usage.target_model = target_model
# ==================== Public API ====================
@classmethod
async def get_model_price_async(
cls, db: Session, provider: str, model: str
) -> tuple[float, float]:
"""异步获取模型价格输入价格输出价格每1M tokens
查找逻辑:
1. 直接通过 GlobalModel.name 匹配
2. 查找该 Provider 的 Model 实现并获取价格
3. 如果找不到则使用系统默认价格
"""
service = ModelCostService(db)
return await service.get_model_price_async(provider, model)
@classmethod
def get_model_price(cls, db: Session, provider: str, model: str) -> tuple[float, float]:
"""获取模型价格输入价格输出价格每1M tokens
查找逻辑:
1. 直接通过 GlobalModel.name 匹配
2. 查找该 Provider 的 Model 实现并获取价格
3. 如果找不到则使用系统默认价格
"""
service = ModelCostService(db)
return service.get_model_price(provider, model)
@classmethod
async def get_cache_prices_async(
cls, db: Session, provider: str, model: str, input_price: float
) -> Tuple[Optional[float], Optional[float]]:
"""异步获取模型缓存价格缓存创建价格缓存读取价格每1M tokens"""
service = ModelCostService(db)
return await service.get_cache_prices_async(provider, model, input_price)
@classmethod
def get_cache_prices(
cls, db: Session, provider: str, model: str, input_price: float
) -> Tuple[Optional[float], Optional[float]]:
"""获取模型缓存价格缓存创建价格缓存读取价格每1M tokens"""
service = ModelCostService(db)
return service.get_cache_prices(provider, model, input_price)
@classmethod
async def get_request_price_async(
cls, db: Session, provider: str, model: str
) -> Optional[float]:
"""异步获取模型按次计费价格"""
service = ModelCostService(db)
return await service.get_request_price_async(provider, model)
@classmethod
def get_request_price(cls, db: Session, provider: str, model: str) -> Optional[float]:
"""获取模型按次计费价格"""
service = ModelCostService(db)
return service.get_request_price(provider, model)
@staticmethod
def calculate_cost(
input_tokens: int,
output_tokens: int,
input_price_per_1m: float,
output_price_per_1m: float,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
cache_creation_price_per_1m: Optional[float] = None,
cache_read_price_per_1m: Optional[float] = None,
price_per_request: Optional[float] = None,
) -> tuple[float, float, float, float, float, float, float]:
"""计算成本价格是每百万tokens- 固定价格模式
Returns:
Tuple of (input_cost, output_cost, cache_creation_cost,
cache_read_cost, cache_cost, request_cost, total_cost)
"""
return ModelCostService.compute_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_price_per_1m=input_price_per_1m,
output_price_per_1m=output_price_per_1m,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
cache_creation_price_per_1m=cache_creation_price_per_1m,
cache_read_price_per_1m=cache_read_price_per_1m,
price_per_request=price_per_request,
)
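# Worked example (hypothetical prices; a sketch of the fixed-price arithmetic
# that ModelCostService.compute_cost is expected to perform):
#
#     costs = UsageService.calculate_cost(
#         input_tokens=1_000_000, output_tokens=200_000,
#         input_price_per_1m=3.0, output_price_per_1m=15.0,
#         cache_read_input_tokens=500_000, cache_read_price_per_1m=0.3,
#     )
#     # input_cost  = 1_000_000 / 1e6 * 3.0  = 3.00
#     # output_cost =   200_000 / 1e6 * 15.0 = 3.00
#     # cache_read  =   500_000 / 1e6 * 0.3  = 0.15
#     # total_cost  = 3.00 + 3.00 + 0.15     = 6.15 USD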
@classmethod
async def calculate_cost_with_strategy_async(
cls,
db: Session,
provider: str,
model: str,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
api_format: Optional[str] = None,
cache_ttl_minutes: Optional[int] = None,
) -> tuple[float, float, float, float, float, float, float, Optional[int]]:
"""使用策略模式计算成本(支持阶梯计费)
根据 api_format 选择对应的计费策略,支持阶梯计费和 TTL 差异化。
Returns:
Tuple of (input_cost, output_cost, cache_creation_cost,
cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
"""
service = ModelCostService(db)
return await service.compute_cost_with_strategy_async(
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
)
@classmethod
async def record_usage_async(
cls,
db: Session,
user: Optional[User],
api_key: Optional[ApiKey],
provider: str,
model: str,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
request_type: str = "chat",
api_format: Optional[str] = None,
is_stream: bool = False,
response_time_ms: Optional[int] = None,
first_byte_time_ms: Optional[int] = None,
status_code: int = 200,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
request_headers: Optional[Dict[str, Any]] = None,
request_body: Optional[Any] = None,
provider_request_headers: Optional[Dict[str, Any]] = None,
response_headers: Optional[Dict[str, Any]] = None,
response_body: Optional[Any] = None,
request_id: Optional[str] = None,
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
provider_api_key_id: Optional[str] = None,
status: str = "completed",
cache_ttl_minutes: Optional[int] = None,
use_tiered_pricing: bool = True,
target_model: Optional[str] = None,
) -> Usage:
"""异步记录使用量(简化版,仅插入新记录)
此方法用于快速记录使用量,不更新用户/API Key 统计,不支持更新已存在的记录。
适用于不需要更新统计信息的场景。
如需完整功能(更新用户统计、支持更新已存在记录),请使用 record_usage()。
"""
# Generate a request_id if absent
if request_id is None:
request_id = str(uuid.uuid4())[:8]
# Fetch the rate multiplier and free-tier flag
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
db, provider_api_key_id, provider_id
)
# Compute costs
is_failed_request = status_code >= 400 or error_message is not None
(
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, tier_index
) = await cls._calculate_costs(
db=db,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
use_tiered_pricing=use_tiered_pricing,
is_failed_request=is_failed_request,
)
# Build Usage parameters
usage_params = cls._build_usage_params(
db=db,
user=user,
api_key=api_key,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
request_type=request_type,
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms,
status_code=status_code,
error_message=error_message,
metadata=metadata,
request_headers=request_headers,
request_body=request_body,
provider_request_headers=provider_request_headers,
response_headers=response_headers,
response_body=response_body,
request_id=request_id,
provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id,
status=status,
target_model=target_model,
input_cost=input_cost,
output_cost=output_cost,
cache_creation_cost=cache_creation_cost,
cache_read_cost=cache_read_cost,
cache_cost=cache_cost,
request_cost=request_cost,
total_cost=total_cost,
input_price=input_price,
output_price=output_price,
cache_creation_price=cache_creation_price,
cache_read_price=cache_read_price,
request_price=request_price,
actual_rate_multiplier=actual_rate_multiplier,
is_free_tier=is_free_tier,
)
# Create the Usage record
usage = Usage(**usage_params)
db.add(usage)
# Atomically bump the GlobalModel usage counter
from sqlalchemy import update
from src.models.database import GlobalModel
db.execute(
update(GlobalModel)
.where(GlobalModel.name == model)
.values(usage_count=GlobalModel.usage_count + 1)
)
# Atomically add to the provider's monthly spend
if provider_id:
actual_total_cost = usage_params["actual_total_cost_usd"]
db.execute(
update(Provider)
.where(Provider.id == provider_id)
.values(monthly_used_usd=Provider.monthly_used_usd + actual_total_cost)
)
db.commit()  # Commit immediately to release database locks
return usage
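# Example (an illustrative sketch; assumes a SQLAlchemy session `db` and loaded
# `user` / `api_key` ORM objects, all hypothetical): the insert-only fast path.
#
#     usage = await UsageService.record_usage_async(
#         db=db, user=user, api_key=api_key,
#         provider="anthropic", model="claude-sonnet",
#         input_tokens=1200, output_tokens=350,
#         is_stream=True, response_time_ms=2400,
#     )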
@classmethod
async def record_usage(
cls,
db: Session,
user: Optional[User],
api_key: Optional[ApiKey],
provider: str,
model: str,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
request_type: str = "chat",
api_format: Optional[str] = None,
is_stream: bool = False,
response_time_ms: Optional[int] = None,
first_byte_time_ms: Optional[int] = None,
status_code: int = 200,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
request_headers: Optional[Dict[str, Any]] = None,
request_body: Optional[Any] = None,
provider_request_headers: Optional[Dict[str, Any]] = None,
response_headers: Optional[Dict[str, Any]] = None,
response_body: Optional[Any] = None,
request_id: Optional[str] = None,
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
provider_api_key_id: Optional[str] = None,
status: str = "completed",
cache_ttl_minutes: Optional[int] = None,
use_tiered_pricing: bool = True,
target_model: Optional[str] = None,
) -> Usage:
"""记录使用量(完整版,支持更新已存在记录和用户统计)
此方法支持:
- 检查是否已存在相同 request_id 的记录(更新 vs 插入)
- 更新用户/API Key 使用统计
- 阶梯计费
如只需简单插入新记录,可使用 record_usage_async()。
"""
# Generate a request_id if absent
if request_id is None:
request_id = str(uuid.uuid4())[:8]
# Fetch the rate multiplier and free-tier flag
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
db, provider_api_key_id, provider_id
)
# Compute costs
is_failed_request = status_code >= 400 or error_message is not None
(
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, _tier_index
) = await cls._calculate_costs(
db=db,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
use_tiered_pricing=use_tiered_pricing,
is_failed_request=is_failed_request,
)
# Build Usage parameters
usage_params = cls._build_usage_params(
db=db,
user=user,
api_key=api_key,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
request_type=request_type,
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms,
status_code=status_code,
error_message=error_message,
metadata=metadata,
request_headers=request_headers,
request_body=request_body,
provider_request_headers=provider_request_headers,
response_headers=response_headers,
response_body=response_body,
request_id=request_id,
provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id,
status=status,
target_model=target_model,
input_cost=input_cost,
output_cost=output_cost,
cache_creation_cost=cache_creation_cost,
cache_read_cost=cache_read_cost,
cache_cost=cache_cost,
request_cost=request_cost,
total_cost=total_cost,
input_price=input_price,
output_price=output_price,
cache_creation_price=cache_creation_price,
cache_read_price=cache_read_price,
request_price=request_price,
actual_rate_multiplier=actual_rate_multiplier,
is_free_tier=is_free_tier,
)
# Check for an existing record with the same request_id
existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
if existing_usage:
logger.debug(
f"request_id {request_id} 已存在,更新现有记录 "
f"(status: {existing_usage.status} -> {status})"
)
cls._update_existing_usage(existing_usage, usage_params, target_model)
usage = existing_usage
else:
usage = Usage(**usage_params)
db.add(usage)
# Ensure user and api_key are attached to the session
if user and not db.object_session(user):
user = db.merge(user)
if api_key and not db.object_session(api_key):
api_key = db.merge(api_key)
# Use atomic updates to avoid concurrent race conditions
from sqlalchemy import func, update
from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel
# Update user spend (standalone keys are not billed to their creator)
if user and not (api_key and api_key.is_standalone):
db.execute(
update(UserModel)
.where(UserModel.id == user.id)
.values(
used_usd=UserModel.used_usd + total_cost,
total_usd=UserModel.total_usd + total_cost,
updated_at=func.now(),
)
)
# Update API key usage
if api_key:
if api_key.is_standalone:
db.execute(
update(ApiKeyModel)
.where(ApiKeyModel.id == api_key.id)
.values(
total_requests=ApiKeyModel.total_requests + 1,
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
last_used_at=func.now(),
updated_at=func.now(),
)
)
else:
db.execute(
update(ApiKeyModel)
.where(ApiKeyModel.id == api_key.id)
.values(
total_requests=ApiKeyModel.total_requests + 1,
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
last_used_at=func.now(),
updated_at=func.now(),
)
)
# Bump the GlobalModel usage counter
db.execute(
update(GlobalModel)
.where(GlobalModel.name == model)
.values(usage_count=GlobalModel.usage_count + 1)
)
# Add to the provider's monthly spend
if provider_id:
actual_total_cost = usage_params["actual_total_cost_usd"]
db.execute(
update(Provider)
.where(Provider.id == provider_id)
.values(monthly_used_usd=Provider.monthly_used_usd + actual_total_cost)
)
# Commit the transaction
try:
db.commit()
except Exception as e:
logger.error(f"提交使用记录时出错: {e}")
db.rollback()
raise
return usage
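# Example (an illustrative sketch): finalizing a request that was first created
# as a pending record. Because record_usage() updates an existing row with the
# same request_id, passing that id finalizes the pending record instead of
# inserting a duplicate. The objects below are hypothetical.
#
#     usage = await UsageService.record_usage(
#         db=db, user=user, api_key=api_key,
#         provider="anthropic", model="claude-sonnet",
#         input_tokens=1200, output_tokens=350,
#         request_id=request_id, status="completed",
#     )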
@staticmethod
def check_user_quota(
db: Session,
user: User,
estimated_tokens: int = 0,
estimated_cost: float = 0,
api_key: Optional[ApiKey] = None,
) -> tuple[bool, str]:
"""检查用户配额或独立Key余额
Args:
db: 数据库会话
user: 用户对象
estimated_tokens: 预估token数
estimated_cost: 预估费用
api_key: API Key对象用于检查独立余额Key
Returns:
(是否通过, 消息)
"""
# 如果是独立余额Key检查Key的余额而不是用户配额
if api_key and api_key.is_standalone:
# Import ApiKeyService for the unified balance calculation
from src.services.user.apikey import ApiKeyService
# NULL means unlimited
if api_key.current_balance_usd is None:
return True, "OK"
# Use the unified balance calculation
remaining_balance = ApiKeyService.get_remaining_balance(api_key)
if remaining_balance is None:
return True, "OK"
# Check whether the balance is sufficient
if remaining_balance < estimated_cost:
return (
False,
f"Key余额不足剩余: ${remaining_balance:.2f},需要: ${estimated_cost:.2f}",
)
return True, "OK"
# Regular keys: check the user quota
# Admins are unlimited
if user.role == UserRole.ADMIN:
return True, "OK"
# NULL means unlimited
if user.quota_usd is None:
return True, "OK"
# Quota is set; check whether it would be exceeded
used_usd = float(user.used_usd or 0)
quota_usd = float(user.quota_usd)
if used_usd + estimated_cost > quota_usd:
remaining = quota_usd - used_usd
return False, f"Insufficient quota (remaining: ${remaining:.2f})"
return True, "OK"
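# Example (hypothetical objects): a pre-flight quota check before routing a
# request. QuotaExceededError is a hypothetical exception type, not part of
# this module.
#
#     ok, message = UsageService.check_user_quota(
#         db, user, estimated_cost=0.05, api_key=api_key
#     )
#     if not ok:
#         raise QuotaExceededError(message)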
@staticmethod
def get_usage_summary(
db: Session,
user_id: Optional[str] = None,
api_key_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
group_by: str = "day", # day, week, month
) -> List[Dict[str, Any]]:
"""获取使用汇总"""
query = db.query(Usage)
# Exclude pending/streaming requests (unfinished requests should not count toward statistics)
query = query.filter(Usage.status.notin_(["pending", "streaming"]))
if user_id:
query = query.filter(Usage.user_id == user_id)
if api_key_id:
query = query.filter(Usage.api_key_id == api_key_id)
if start_date:
query = query.filter(Usage.created_at >= start_date)
if end_date:
query = query.filter(Usage.created_at <= end_date)
# Use a cross-database-compatible date function
from src.utils.database_helpers import date_trunc_portable
# Detect the database dialect
bind = db.bind
dialect = bind.dialect.name if bind is not None else "sqlite"
# Choose the date function based on the grouping (portable across databases)
if group_by == "day":
date_func = date_trunc_portable(dialect, "day", Usage.created_at)
elif group_by == "week":
date_func = date_trunc_portable(dialect, "week", Usage.created_at)
elif group_by == "month":
date_func = date_trunc_portable(dialect, "month", Usage.created_at)
else:
# Default to daily grouping
date_func = date_trunc_portable(dialect, "day", Usage.created_at)
# Aggregate query
summary = db.query(
date_func.label("period"),
Usage.provider,
Usage.model,
func.count(Usage.id).label("requests"),
func.sum(Usage.input_tokens).label("input_tokens"),
func.sum(Usage.output_tokens).label("output_tokens"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
func.avg(Usage.response_time_ms).label("avg_response_time"),
)
if user_id:
summary = summary.filter(Usage.user_id == user_id)
if api_key_id:
summary = summary.filter(Usage.api_key_id == api_key_id)
if start_date:
summary = summary.filter(Usage.created_at >= start_date)
if end_date:
summary = summary.filter(Usage.created_at <= end_date)
summary = summary.group_by(date_func, Usage.provider, Usage.model).all()
return [
{
"period": row.period,
"provider": row.provider,
"model": row.model,
"requests": row.requests,
"input_tokens": row.input_tokens,
"output_tokens": row.output_tokens,
"total_tokens": row.total_tokens,
"total_cost_usd": float(row.total_cost_usd),
"avg_response_time_ms": (
float(row.avg_response_time) if row.avg_response_time else 0
),
}
for row in summary
]
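# Example (hypothetical session and user): daily per-model summary for one
# user over the last week.
#
#     rows = UsageService.get_usage_summary(
#         db, user_id=user.id,
#         start_date=datetime.now(timezone.utc) - timedelta(days=7),
#         group_by="day",
#     )
#     # [{"period": ..., "provider": ..., "model": ..., "requests": 42, ...}, ...]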
@staticmethod
def get_daily_activity(
db: Session,
user_id: Optional[int] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
window_days: int = 365,
include_actual_cost: bool = False,
) -> Dict[str, Any]:
"""按天统计请求活跃度,用于渲染热力图。"""
def ensure_timezone(value: datetime) -> datetime:
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
# Default to the most recent window_days if the caller gives no range
now = datetime.now(timezone.utc)
end_dt = ensure_timezone(end_date) if end_date else now
start_dt = (
ensure_timezone(start_date) if start_date else end_dt - timedelta(days=window_days - 1)
)
# Align to calendar-day boundaries to avoid losing edge data
start_dt = start_dt.replace(hour=0, minute=0, second=0, microsecond=0)
end_dt = end_dt.replace(hour=23, minute=59, second=59, microsecond=999999)
from src.utils.database_helpers import date_trunc_portable
bind = db.get_bind()
dialect = bind.dialect.name if bind is not None else "sqlite"
day_bucket = date_trunc_portable(dialect, "day", Usage.created_at).label("day")
columns = [
day_bucket,
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
]
if include_actual_cost:
columns.append(func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"))
query = db.query(*columns).filter(Usage.created_at >= start_dt, Usage.created_at <= end_dt)
if user_id:
query = query.filter(Usage.user_id == user_id)
query = query.group_by(day_bucket).order_by(day_bucket)
rows = query.all()
def normalize_period(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value[:10]
if isinstance(value, datetime):
return value.date().isoformat()
return str(value)
aggregated: Dict[str, Dict[str, Any]] = {}
for row in rows:
key = normalize_period(row.day)
aggregated[key] = {
"requests": int(row.requests or 0),
"total_tokens": int(row.total_tokens or 0),
"total_cost_usd": float(row.total_cost_usd or 0.0),
}
if include_actual_cost:
aggregated[key]["actual_total_cost_usd"] = float(row.actual_total_cost_usd or 0.0)
days: List[Dict[str, Any]] = []
cursor = start_dt.date()
end_date_only = end_dt.date()
max_requests = 0
while cursor <= end_date_only:
iso_date = cursor.isoformat()
stats = aggregated.get(iso_date, {})
requests = stats.get("requests", 0)
total_tokens = stats.get("total_tokens", 0)
total_cost = stats.get("total_cost_usd", 0.0)
entry: Dict[str, Any] = {
"date": iso_date,
"requests": requests,
"total_tokens": total_tokens,
"total_cost": total_cost,
}
if include_actual_cost:
entry["actual_total_cost"] = stats.get("actual_total_cost_usd", 0.0)
days.append(entry)
max_requests = max(max_requests, requests)
cursor += timedelta(days=1)
return {
"start_date": start_dt.date().isoformat(),
"end_date": end_dt.date().isoformat(),
"total_days": len(days),
"max_requests": max_requests,
"days": days,
}
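# Example (hypothetical session): one year of daily activity for a heatmap.
#
#     activity = UsageService.get_daily_activity(db, user_id=user.id, window_days=365)
#     # activity["days"] -> [{"date": "2025-01-01", "requests": 0, ...}, ...]
#     # activity["max_requests"] can be used to scale heatmap cell colors.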
@staticmethod
def get_top_users(
db: Session,
limit: int = 10,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
order_by: str = "cost", # cost, tokens, requests
) -> List[Dict[str, Any]]:
"""获取使用量最高的用户"""
query = (
db.query(
User.id,
User.email,
User.username,
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("tokens"),
func.sum(Usage.total_cost_usd).label("cost_usd"),
)
.join(Usage, User.id == Usage.user_id)
.filter(Usage.user_id.isnot(None))
)
if start_date:
query = query.filter(Usage.created_at >= start_date)
if end_date:
query = query.filter(Usage.created_at <= end_date)
query = query.group_by(User.id, User.email, User.username)
# Sort
if order_by == "cost":
query = query.order_by(func.sum(Usage.total_cost_usd).desc())
elif order_by == "tokens":
query = query.order_by(func.sum(Usage.total_tokens).desc())
else:
query = query.order_by(func.count(Usage.id).desc())
results = query.limit(limit).all()
return [
{
"user_id": row.id,
"email": row.email,
"username": row.username,
"requests": row.requests,
"tokens": row.tokens,
"cost_usd": float(row.cost_usd),
}
for row in results
]
@staticmethod
def cleanup_old_usage_records(db: Session, days_to_keep: int = 90) -> int:
"""清理旧的使用记录"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
# Delete old records
deleted = db.query(Usage).filter(Usage.created_at < cutoff_date).delete()
db.commit()
logger.info(f"清理使用记录: 删除 {deleted} 条超过 {days_to_keep} 天的记录")
return deleted
# ========== Request status tracking ==========
@classmethod
def create_pending_usage(
cls,
db: Session,
request_id: str,
user: Optional[User],
api_key: Optional[ApiKey],
model: str,
is_stream: bool = False,
api_format: Optional[str] = None,
request_headers: Optional[Dict[str, Any]] = None,
request_body: Optional[Any] = None,
) -> Usage:
"""
Create a pending usage record (called when a request starts)
Args:
db: database session
request_id: request ID
user: user object
api_key: API key object
model: model name
is_stream: whether the request is streaming
api_format: API format
request_headers: request headers
request_body: request body
Returns:
The created Usage record
"""
# Decide from system config whether to log request details
should_log_headers = SystemConfigService.should_log_headers(db)
should_log_body = SystemConfigService.should_log_body(db)
# Process request headers
processed_request_headers = None
if should_log_headers and request_headers:
processed_request_headers = SystemConfigService.mask_sensitive_headers(
db, request_headers
)
# Process request body
processed_request_body = None
if should_log_body and request_body:
processed_request_body = SystemConfigService.truncate_body(
db, request_body, is_request=True
)
usage = Usage(
user_id=user.id if user else None,
api_key_id=api_key.id if api_key else None,
request_id=request_id,
provider="pending", # 尚未确定 provider
model=model,
input_tokens=0,
output_tokens=0,
total_tokens=0,
total_cost_usd=0.0,
request_type="chat",
api_format=api_format,
is_stream=is_stream,
status="pending",
request_headers=processed_request_headers,
request_body=processed_request_body,
)
db.add(usage)
db.commit()
logger.debug(f"创建 pending 使用记录: request_id={request_id}, model={model}")
return usage
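# Example (an illustrative sketch of the request status lifecycle): a record
# is created as "pending" when the request arrives, promoted to "streaming"
# once a provider is chosen, then finalized via record_usage() with the same
# request_id. `rid`, `user`, and `api_key` below are hypothetical.
#
#     UsageService.create_pending_usage(db, request_id=rid, user=user,
#                                       api_key=api_key, model="claude-sonnet")
#     UsageService.update_usage_status(db, rid, status="streaming",
#                                      provider="anthropic")
#     # ... on completion: record_usage(..., request_id=rid, status="completed")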
@classmethod
def update_usage_status(
cls,
db: Session,
request_id: str,
status: str,
error_message: Optional[str] = None,
provider: Optional[str] = None,
target_model: Optional[str] = None,
) -> Optional[Usage]:
"""
Quickly update a usage record's status
Args:
db: database session
request_id: request ID
status: new status (pending, streaming, completed, failed)
error_message: error message (only used for the failed status)
provider: provider name (optional; updated when entering the streaming status)
target_model: mapped target model name (optional)
Returns:
The updated Usage record, or None if not found
"""
usage = db.query(Usage).filter(Usage.request_id == request_id).first()
if not usage:
logger.warning(f"未找到 request_id={request_id} 的使用记录,无法更新状态")
return None
old_status = usage.status
usage.status = status
if error_message:
usage.error_message = error_message
if provider:
usage.provider = provider
if target_model:
usage.target_model = target_model
db.commit()
logger.debug(f"更新使用记录状态: request_id={request_id}, {old_status} -> {status}")
return usage
@classmethod
def get_active_requests(
cls,
db: Session,
user_id: Optional[str] = None,
limit: int = 50,
) -> List[Usage]:
"""
Get active requests (pending or streaming status)
Args:
db: database session
user_id: user ID (optional, for filtering)
limit: maximum number of results
Returns:
List of active-request Usage records
"""
query = db.query(Usage).filter(Usage.status.in_(["pending", "streaming"]))
if user_id:
query = query.filter(Usage.user_id == user_id)
return query.order_by(Usage.created_at.desc()).limit(limit).all()
@classmethod
def cleanup_stale_pending_requests(
cls,
db: Session,
timeout_minutes: int = 10,
) -> int:
"""
Clean up timed-out pending/streaming requests
Marks requests still in pending or streaming status after the given time as failed.
Such requests typically failed to finish due to network problems, service restarts, or other anomalies.
Args:
db: database session
timeout_minutes: timeout in minutes (default 10)
Returns:
Number of records cleaned up
"""
cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=timeout_minutes)
# Find timed-out requests
stale_requests = (
db.query(Usage)
.filter(
Usage.status.in_(["pending", "streaming"]),
Usage.created_at < cutoff_time,
)
.all()
)
count = 0
for usage in stale_requests:
old_status = usage.status
usage.status = "failed"
usage.error_message = f"Request timed out: status '{old_status}' unresolved after {timeout_minutes} minutes"
usage.status_code = 504 # Gateway Timeout
count += 1
if count > 0:
db.commit()
logger.info(f"清理超时请求: 将 {count} 条超过 {timeout_minutes} 分钟的 pending/streaming 请求标记为 failed")
return count
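# Example (a sketch, assuming an asyncio-based scheduler; the reaper loop and
# `session_factory` below are hypothetical and not part of this module):
# periodically mark stale requests as failed so they stop showing as active.
#
#     async def stale_request_reaper(session_factory, interval_seconds=60):
#         while True:
#             with session_factory() as db:
#                 UsageService.cleanup_stale_pending_requests(db, timeout_minutes=10)
#             await asyncio.sleep(interval_seconds)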
@classmethod
def get_stale_pending_count(
cls,
db: Session,
timeout_minutes: int = 10,
) -> int:
"""
Get the number of timed-out pending/streaming requests (for monitoring)
Args:
db: database session
timeout_minutes: timeout in minutes
Returns:
Number of timed-out requests
"""
cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=timeout_minutes)
return (
db.query(Usage)
.filter(
Usage.status.in_(["pending", "streaming"]),
Usage.created_at < cutoff_time,
)
.count()
)
@classmethod
def get_active_requests_status(
cls,
db: Session,
ids: Optional[List[str]] = None,
user_id: Optional[str] = None,
default_timeout_seconds: int = 300,
) -> List[Dict[str, Any]]:
"""
Get active request statuses (for frontend polling), auto-cleaning timed-out pending/streaming requests
Unlike get_active_requests, this method:
1. Returns lightweight status dicts rather than full Usage objects
2. Automatically detects and cleans up timed-out pending/streaming requests
3. Supports querying specific requests by ID list
Args:
db: database session
ids: list of request IDs to query (optional)
user_id: restrict results to this user's requests (optional; for the regular-user endpoint)
default_timeout_seconds: default timeout in seconds, used when the endpoint has none configured
Returns:
List of request statuses
"""
from src.models.database import ProviderEndpoint
now = datetime.now(timezone.utc)
# Build the base query, including the endpoint's timeout setting
query = db.query(
Usage.id,
Usage.status,
Usage.input_tokens,
Usage.output_tokens,
Usage.total_cost_usd,
Usage.response_time_ms,
Usage.first_byte_time_ms,  # Time to first byte (TTFB)
Usage.created_at,
Usage.provider_endpoint_id,
ProviderEndpoint.timeout.label("endpoint_timeout"),
).outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
if ids:
query = query.filter(Usage.id.in_(ids))
if user_id:
query = query.filter(Usage.user_id == user_id)
else:
# Query all active requests
query = query.filter(Usage.status.in_(["pending", "streaming"]))
if user_id:
query = query.filter(Usage.user_id == user_id)
query = query.order_by(Usage.created_at.desc()).limit(50)
records = query.all()
# Check for timed-out pending/streaming requests
timeout_ids = []
for r in records:
if r.status in ("pending", "streaming") and r.created_at:
# Use the endpoint's configured timeout, falling back to the default
timeout_seconds = r.endpoint_timeout or default_timeout_seconds
# Timezone handling: assume UTC if created_at is naive
created_at = r.created_at
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
elapsed = (now - created_at).total_seconds()
if elapsed > timeout_seconds:
timeout_ids.append(r.id)
# Batch-update timed-out requests
if timeout_ids:
db.query(Usage).filter(Usage.id.in_(timeout_ids)).update(
{"status": "failed", "error_message": "请求超时(服务器可能已重启)"},
synchronize_session=False,
)
db.commit()
return [
{
"id": r.id,
"status": "failed" if r.id in timeout_ids else r.status,
"input_tokens": r.input_tokens,
"output_tokens": r.output_tokens,
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
"response_time_ms": r.response_time_ms,
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
}
for r in records
]
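# Example (hypothetical ids): the lightweight shape returned to the frontend
# poller.
#
#     statuses = UsageService.get_active_requests_status(db, ids=[usage_id])
#     # [{"id": ..., "status": "streaming", "input_tokens": 1200,
#     #   "output_tokens": 87, "cost": 0.0042, "response_time_ms": None,
#     #   "first_byte_time_ms": 310}]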
# ========== Cache affinity analysis ==========
@staticmethod
def analyze_cache_affinity_ttl(
db: Session,
user_id: Optional[str] = None,
api_key_id: Optional[str] = None,
hours: int = 168,
) -> Dict[str, Any]:
"""
Analyze the distribution of a user's request intervals and recommend a suitable cache-affinity TTL
By analyzing the gaps between a user's consecutive requests, we can classify their usage pattern:
- High-frequency users (short gaps): a 5-minute TTL suffices
- Medium-frequency users: a 15-30 minute TTL
- Low-frequency users (long gaps): need a 60-minute TTL
Args:
db: database session
user_id: specific user ID (optional; analyzes all users when empty)
api_key_id: specific API key ID (optional)
hours: how many recent hours of data to analyze
Returns:
Dict containing the analysis results
"""
from sqlalchemy import text
# Compute the time range
start_date = datetime.now(timezone.utc) - timedelta(hours=hours)
# Build the SQL query - use a window function to compute request intervals
# Group by user_id or api_key_id and compute the time delta between consecutive requests within a group
group_by_field = "api_key_id" if api_key_id else "user_id"
# Build the filter clause
filter_clause = ""
if user_id or api_key_id:
filter_clause = f"AND {group_by_field} = :filter_id"
sql = text(f"""
WITH user_requests AS (
SELECT
{group_by_field} as group_id,
created_at,
LAG(created_at) OVER (
PARTITION BY {group_by_field}
ORDER BY created_at
) as prev_request_at
FROM usage
WHERE status = 'completed'
AND created_at > :start_date
AND {group_by_field} IS NOT NULL
{filter_clause}
),
intervals AS (
SELECT
group_id,
EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 as interval_minutes
FROM user_requests
WHERE prev_request_at IS NOT NULL
),
user_stats AS (
SELECT
group_id,
COUNT(*) as request_count,
COUNT(*) FILTER (WHERE interval_minutes <= 5) as within_5min,
COUNT(*) FILTER (WHERE interval_minutes > 5 AND interval_minutes <= 15) as within_15min,
COUNT(*) FILTER (WHERE interval_minutes > 15 AND interval_minutes <= 30) as within_30min,
COUNT(*) FILTER (WHERE interval_minutes > 30 AND interval_minutes <= 60) as within_60min,
COUNT(*) FILTER (WHERE interval_minutes > 60) as over_60min,
PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY interval_minutes) as median_interval,
PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY interval_minutes) as p75_interval,
PERCENTILE_CONT(0.90) WITHIN GROUP (ORDER BY interval_minutes) as p90_interval,
AVG(interval_minutes) as avg_interval,
MIN(interval_minutes) as min_interval,
MAX(interval_minutes) as max_interval
FROM intervals
GROUP BY group_id
HAVING COUNT(*) >= 2
)
SELECT * FROM user_stats
ORDER BY request_count DESC
""")
params: Dict[str, Any] = {
"start_date": start_date,
}
if user_id:
params["filter_id"] = user_id
elif api_key_id:
params["filter_id"] = api_key_id
result = db.execute(sql, params)
rows = result.fetchall()
# Collect all group ids so user info can be fetched in one query
group_ids = [row[0] for row in rows]
# When grouping by user_id, look up user info
user_info_map: Dict[str, Dict[str, str]] = {}
if group_by_field == "user_id" and group_ids:
users = db.query(User).filter(User.id.in_(group_ids)).all()
for user in users:
user_info_map[str(user.id)] = {
"username": str(user.username),
"email": str(user.email) if user.email else "",
}
# Process results
users_analysis = []
for row in rows:
# row is a tuple; access fields in query order
(
group_id,
request_count,
within_5min,
within_15min,
within_30min,
within_60min,
over_60min,
median_interval,
p75_interval,
p90_interval,
avg_interval,
min_interval,
max_interval,
) = row
# Compute the recommended TTL
recommended_ttl = UsageService._calculate_recommended_ttl(
p75_interval, p90_interval
)
# Look up user info
user_info = user_info_map.get(str(group_id), {})
# Compute the share of each interval bucket
total_intervals = request_count
users_analysis.append({
"group_id": group_id,
"username": user_info.get("username"),
"email": user_info.get("email"),
"request_count": request_count,
"interval_distribution": {
"within_5min": within_5min,
"within_15min": within_15min,
"within_30min": within_30min,
"within_60min": within_60min,
"over_60min": over_60min,
},
"interval_percentages": {
"within_5min": round(within_5min / total_intervals * 100, 1),
"within_15min": round(within_15min / total_intervals * 100, 1),
"within_30min": round(within_30min / total_intervals * 100, 1),
"within_60min": round(within_60min / total_intervals * 100, 1),
"over_60min": round(over_60min / total_intervals * 100, 1),
},
"percentiles": {
"p50": round(float(median_interval), 2) if median_interval else None,
"p75": round(float(p75_interval), 2) if p75_interval else None,
"p90": round(float(p90_interval), 2) if p90_interval else None,
},
"avg_interval_minutes": round(float(avg_interval), 2) if avg_interval else None,
"min_interval_minutes": round(float(min_interval), 2) if min_interval else None,
"max_interval_minutes": round(float(max_interval), 2) if max_interval else None,
"recommended_ttl_minutes": recommended_ttl,
"recommendation_reason": UsageService._get_ttl_recommendation_reason(
recommended_ttl, p75_interval, p90_interval
),
})
# Aggregate statistics
ttl_distribution = {"5min": 0, "15min": 0, "30min": 0, "60min": 0}
for analysis in users_analysis:
ttl = analysis["recommended_ttl_minutes"]
if ttl <= 5:
ttl_distribution["5min"] += 1
elif ttl <= 15:
ttl_distribution["15min"] += 1
elif ttl <= 30:
ttl_distribution["30min"] += 1
else:
ttl_distribution["60min"] += 1
return {
"analysis_period_hours": hours,
"total_users_analyzed": len(users_analysis),
"ttl_distribution": ttl_distribution,
"users": users_analysis,
}
@staticmethod
def _calculate_recommended_ttl(
p75_interval: Optional[float],
p90_interval: Optional[float],
) -> int:
"""
Compute the recommended cache TTL from the request-interval distribution
Strategy:
- If 90% of request intervals are within 5 minutes → 5-minute TTL
- If 75% of request intervals are within 15 minutes → 15-minute TTL
- If 75% of request intervals are within 30 minutes → 30-minute TTL
- Otherwise → 60-minute TTL
"""
if p90_interval is None or p75_interval is None:
return 5  # Default
# If 90% of intervals are within 5 minutes
if p90_interval <= 5:
return 5
# If 75% of intervals are within 15 minutes
if p75_interval <= 15:
return 15
# If 75% of intervals are within 30 minutes
if p75_interval <= 30:
return 30
# Low-frequency user; needs a longer TTL
return 60
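# Worked example: with p90_interval=4.2 the first branch fires (90% of
# intervals within 5 minutes) and the method returns 5. With p75_interval=22.0
# and p90_interval=55.0, neither the 5- nor 15-minute branch matches, the
# 30-minute branch does (22.0 <= 30), so it returns 30.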
@staticmethod
def _get_ttl_recommendation_reason(
ttl: int,
p75_interval: Optional[float],
p90_interval: Optional[float],
) -> str:
"""生成 TTL 推荐理由"""
if p75_interval is None or p90_interval is None:
return "数据不足,使用默认值"
if ttl == 5:
return f"高频用户90% 的请求间隔在 {p90_interval:.1f} 分钟内"
elif ttl == 15:
return f"中高频用户75% 的请求间隔在 {p75_interval:.1f} 分钟内"
elif ttl == 30:
return f"中频用户75% 的请求间隔在 {p75_interval:.1f} 分钟内"
else:
return f"低频用户75% 的请求间隔为 {p75_interval:.1f} 分钟,建议使用长 TTL"
@staticmethod
def get_cache_hit_analysis(
db: Session,
user_id: Optional[str] = None,
api_key_id: Optional[str] = None,
hours: int = 168,
) -> Dict[str, Any]:
"""
Analyze cache hit behavior
Args:
db: database session
user_id: specific user ID (optional)
api_key_id: specific API key ID (optional)
hours: how many recent hours of data to analyze
Returns:
Cache hit analysis results
"""
start_date = datetime.now(timezone.utc) - timedelta(hours=hours)
# Base query
query = db.query(
func.count(Usage.id).label("total_requests"),
func.sum(Usage.input_tokens).label("total_input_tokens"),
func.sum(Usage.cache_read_input_tokens).label("total_cache_read_tokens"),
func.sum(Usage.cache_creation_input_tokens).label("total_cache_creation_tokens"),
func.sum(Usage.cache_read_cost_usd).label("total_cache_read_cost"),
func.sum(Usage.cache_creation_cost_usd).label("total_cache_creation_cost"),
).filter(
Usage.status == "completed",
Usage.created_at >= start_date,
)
if user_id:
query = query.filter(Usage.user_id == user_id)
if api_key_id:
query = query.filter(Usage.api_key_id == api_key_id)
result = query.first()
if result is None:
total_requests = 0
total_input_tokens = 0
total_cache_read_tokens = 0
total_cache_creation_tokens = 0
total_cache_read_cost = 0.0
total_cache_creation_cost = 0.0
else:
total_requests = result.total_requests or 0
total_input_tokens = result.total_input_tokens or 0
total_cache_read_tokens = result.total_cache_read_tokens or 0
total_cache_creation_tokens = result.total_cache_creation_tokens or 0
total_cache_read_cost = float(result.total_cache_read_cost or 0)
total_cache_creation_cost = float(result.total_cache_creation_cost or 0)
# Compute the cache hit rate (by token count)
# Total input context = input_tokens + cache_read_tokens (input_tokens excludes cache reads)
# Alternatively, if input_tokens already included cache reads, we would use input_tokens directly
# Here we assume cache_read_tokens are additional, so hit rate = cache_read / (input + cache_read)
total_context_tokens = total_input_tokens + total_cache_read_tokens
cache_hit_rate = 0.0
if total_context_tokens > 0:
cache_hit_rate = total_cache_read_tokens / total_context_tokens * 100
# Estimate the cost savings
# Cache reads are priced at 10% of the normal input price, so 90% is saved
# Savings = cache_read_tokens * (normal price - cache price) = cache_read_cost * 9
# (cache_read_cost was computed at the 10% rate; at the 100% rate it would be 10x)
estimated_savings = total_cache_read_cost * 9  # 90% saved
# Count requests with at least one cache hit
requests_with_cache_hit = db.query(func.count(Usage.id)).filter(
Usage.status == "completed",
Usage.created_at >= start_date,
Usage.cache_read_input_tokens > 0,
)
if user_id:
requests_with_cache_hit = requests_with_cache_hit.filter(Usage.user_id == user_id)
if api_key_id:
requests_with_cache_hit = requests_with_cache_hit.filter(Usage.api_key_id == api_key_id)
requests_with_cache_hit_count = int(requests_with_cache_hit.scalar() or 0)
return {
"analysis_period_hours": hours,
"total_requests": total_requests,
"requests_with_cache_hit": requests_with_cache_hit_count,
"request_cache_hit_rate": round(requests_with_cache_hit_count / total_requests * 100, 2) if total_requests > 0 else 0,
"total_input_tokens": total_input_tokens,
"total_cache_read_tokens": total_cache_read_tokens,
"total_cache_creation_tokens": total_cache_creation_tokens,
"token_cache_hit_rate": round(cache_hit_rate, 2),
"total_cache_read_cost_usd": round(total_cache_read_cost, 4),
"total_cache_creation_cost_usd": round(total_cache_creation_cost, 4),
"estimated_savings_usd": round(estimated_savings, 4),
}
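# Worked example (hypothetical totals): with total_input_tokens=900_000 and
# total_cache_read_tokens=100_000, token_cache_hit_rate =
# 100_000 / (900_000 + 100_000) * 100 = 10.0%. If total_cache_read_cost is
# $0.03 (charged at the assumed 10% rate), estimated_savings = 0.03 * 9 = $0.27.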
@staticmethod
def get_interval_timeline(
db: Session,
hours: int = 24,
limit: int = 10000,
user_id: Optional[str] = None,
include_user_info: bool = False,
) -> Dict[str, Any]:
"""
Get request-interval timeline data for a scatter plot
Args:
db: database session
hours: how many recent hours of data to analyze (default 24)
limit: maximum number of data points to return (default 10000)
user_id: specific user ID (optional; returns all users when empty)
include_user_info: whether to include user info (for the admin multi-user view)
Returns:
Dict of timeline data points; each point carries a model field for per-model coloring
"""
from sqlalchemy import text
start_date = datetime.now(timezone.utc) - timedelta(hours=hours)
# Build the user filter clause
user_filter = "AND u.user_id = :user_id" if user_id else ""
# Choose the query depending on whether user info is needed
if include_user_info and not user_id:
# Admin view: return data points with user info
# Sample proportionally so each user's share of points stays constant
sql = text(f"""
WITH request_intervals AS (
SELECT
u.created_at,
u.user_id,
u.model,
usr.username,
LAG(u.created_at) OVER (
PARTITION BY u.user_id
ORDER BY u.created_at
) as prev_request_at
FROM usage u
LEFT JOIN users usr ON u.user_id = usr.id
WHERE u.status = 'completed'
AND u.created_at > :start_date
AND u.user_id IS NOT NULL
{user_filter}
),
filtered_intervals AS (
SELECT
created_at,
user_id,
model,
username,
EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 as interval_minutes,
ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY created_at) as rn
FROM request_intervals
WHERE prev_request_at IS NOT NULL
AND EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 <= 120
),
total_count AS (
SELECT COUNT(*) as cnt FROM filtered_intervals
),
user_totals AS (
SELECT user_id, COUNT(*) as user_cnt FROM filtered_intervals GROUP BY user_id
),
user_limits AS (
SELECT
ut.user_id,
CASE WHEN tc.cnt <= :limit THEN ut.user_cnt
ELSE GREATEST(CEIL(ut.user_cnt::float * :limit / tc.cnt), 1)::int
END as user_limit
FROM user_totals ut, total_count tc
)
SELECT
fi.created_at,
fi.user_id,
fi.model,
fi.username,
fi.interval_minutes
FROM filtered_intervals fi
JOIN user_limits ul ON fi.user_id = ul.user_id
WHERE fi.rn <= ul.user_limit
ORDER BY fi.created_at
""")
else:
# Regular view: return timestamp, interval, and model
sql = text(f"""
WITH request_intervals AS (
SELECT
u.created_at,
u.user_id,
u.model,
LAG(u.created_at) OVER (
PARTITION BY u.user_id
ORDER BY u.created_at
) as prev_request_at
FROM usage u
WHERE u.status = 'completed'
AND u.created_at > :start_date
AND u.user_id IS NOT NULL
{user_filter}
)
SELECT
created_at,
model,
EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 as interval_minutes
FROM request_intervals
WHERE prev_request_at IS NOT NULL
AND EXTRACT(EPOCH FROM (created_at - prev_request_at)) / 60.0 <= 120
ORDER BY created_at
LIMIT :limit
""")
params: Dict[str, Any] = {"start_date": start_date, "limit": limit}
if user_id:
params["user_id"] = user_id
result = db.execute(sql, params)
rows = result.fetchall()
# Convert rows into timeline data points
points = []
users_map: Dict[str, str] = {} # user_id -> username
models_set: set = set()  # Collect every model seen
if include_user_info and not user_id:
for row in rows:
created_at, row_user_id, model, username, interval_minutes = row
point_data: Dict[str, Any] = {
"x": created_at.isoformat(),
"y": round(float(interval_minutes), 2),
"user_id": str(row_user_id),
}
if model:
point_data["model"] = model
models_set.add(model)
points.append(point_data)
if row_user_id and username:
users_map[str(row_user_id)] = username
else:
for row in rows:
created_at, model, interval_minutes = row
point_data = {
"x": created_at.isoformat(),
"y": round(float(interval_minutes), 2)
}
if model:
point_data["model"] = model
models_set.add(model)
points.append(point_data)
response: Dict[str, Any] = {
"analysis_period_hours": hours,
"total_points": len(points),
"points": points,
}
if include_user_info and not user_id:
response["users"] = users_map
# If any models were seen, include the sorted list
if models_set:
response["models"] = sorted(models_set)
return response
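# Example (hypothetical session): fetch the last 24 hours of interval points
# for a scatter plot, with per-user proportional sampling in the admin view.
#
#     timeline = UsageService.get_interval_timeline(
#         db, hours=24, limit=5000, include_user_info=True
#     )
#     # timeline["points"] -> [{"x": "...", "y": 1.25, "user_id": "...", "model": "..."}, ...]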