feat: add daily model statistics aggregation with stats_daily_model table

This commit is contained in:
fawney19
2025-12-20 02:39:10 +08:00
parent e2e7996a54
commit 4e1aed9976
22 changed files with 561 additions and 202 deletions

View File

@@ -589,14 +589,14 @@ class CacheAwareScheduler:
target_format = normalize_api_format(api_format)
# 0. 解析 model_name 到 GlobalModel支持直接匹配和名匹配,使用 ModelCacheService
# 0. 解析 model_name 到 GlobalModel支持直接匹配和映射名匹配,使用 ModelCacheService
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
if not global_model:
logger.warning(f"GlobalModel not found: {model_name}")
raise ModelNotSupportedException(model=model_name)
# 使用 GlobalModel.id 作为缓存亲和性的模型标识,确保名和规范名都能命中同一个缓存
# 使用 GlobalModel.id 作为缓存亲和性的模型标识,确保映射名和规范名都能命中同一个缓存
global_model_id: str = str(global_model.id)
requested_model_name = model_name
resolved_model_name = str(global_model.name)
@@ -751,19 +751,19 @@ class CacheAwareScheduler:
支持两种匹配方式:
1. 直接匹配 GlobalModel.name
2. 通过 ModelCacheService 匹配名(全局查找)
2. 通过 ModelCacheService 匹配映射名(全局查找)
Args:
db: 数据库会话
provider: Provider 对象
model_name: 模型名称(可以是 GlobalModel.name 或名)
model_name: 模型名称(可以是 GlobalModel.name 或映射名)
is_stream: 是否是流式请求,如果为 True 则同时检查流式支持
capability_requirements: 能力需求(可选),用于检查模型是否支持所需能力
Returns:
(is_supported, skip_reason, supported_capabilities) - 是否支持、跳过原因、模型支持的能力列表
"""
# 使用 ModelCacheService 解析模型名称(支持名)
# 使用 ModelCacheService 解析模型名称(支持映射名)
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
if not global_model:
@@ -914,7 +914,7 @@ class CacheAwareScheduler:
db: 数据库会话
providers: Provider 列表
target_format: 目标 API 格式
model_name: 模型名称(用户请求的名称,可能是名)
model_name: 模型名称(用户请求的名称,可能是映射名)
affinity_key: 亲和性标识符通常为API Key ID
resolved_model_name: 解析后的 GlobalModel.name用于 Key.allowed_models 校验)
max_candidates: 最大候选数

View File

@@ -198,7 +198,7 @@ class ModelCacheService:
provider_id: Optional[str] = None,
global_model_id: Optional[str] = None,
provider_model_name: Optional[str] = None,
provider_model_aliases: Optional[list] = None,
provider_model_mappings: Optional[list] = None,
) -> None:
"""清除 Model 缓存
@@ -207,7 +207,7 @@ class ModelCacheService:
provider_id: Provider ID用于清除 provider_global 缓存)
global_model_id: GlobalModel ID用于清除 provider_global 缓存)
provider_model_name: provider_model_name用于清除 resolve 缓存)
provider_model_aliases: 映射名称列表(用于清除 resolve 缓存)
provider_model_mappings: 映射名称列表(用于清除 resolve 缓存)
"""
# 清除 model:id 缓存
await CacheService.delete(f"model:id:{model_id}")
@@ -222,16 +222,16 @@ class ModelCacheService:
else:
logger.debug(f"Model 缓存已清除: {model_id}")
# 清除 resolve 缓存provider_model_name 和 aliases 可能都被用作解析 key
# 清除 resolve 缓存provider_model_name 和 mappings 可能都被用作解析 key
resolve_keys_to_clear = []
if provider_model_name:
resolve_keys_to_clear.append(provider_model_name)
if provider_model_aliases:
for alias_entry in provider_model_aliases:
if isinstance(alias_entry, dict):
alias_name = alias_entry.get("name", "").strip()
if alias_name:
resolve_keys_to_clear.append(alias_name)
if provider_model_mappings:
for mapping_entry in provider_model_mappings:
if isinstance(mapping_entry, dict):
mapping_name = mapping_entry.get("name", "").strip()
if mapping_name:
resolve_keys_to_clear.append(mapping_name)
for key in resolve_keys_to_clear:
await CacheService.delete(f"global_model:resolve:{key}")
@@ -261,8 +261,8 @@ class ModelCacheService:
2. 通过 provider_model_name 匹配(查询 Model 表)
3. 直接匹配 GlobalModel.name兜底
注意:此方法不使用 provider_model_aliases 进行全局解析。
provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
注意:此方法不使用 provider_model_mappings 进行全局解析。
provider_model_mappings 是 Provider 级别的映射配置,只在特定 Provider 上下文中生效,
由 resolve_provider_model() 处理。
Args:
@@ -301,9 +301,9 @@ class ModelCacheService:
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
return ModelCacheService._dict_to_global_model(cached_data)
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases
# 重要provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
# 全局解析不应该受到某个 Provider 别名配置的影响
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_mappings
# 重要provider_model_mappings 是 Provider 级别的映射配置,只在特定 Provider 上下文中生效
# 全局解析不应该受到某个 Provider 映射配置的影响
# 例如Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
from src.models.database import Provider
@@ -401,7 +401,7 @@ class ModelCacheService:
"provider_id": model.provider_id,
"global_model_id": model.global_model_id,
"provider_model_name": model.provider_model_name,
"provider_model_aliases": getattr(model, "provider_model_aliases", None),
"provider_model_mappings": getattr(model, "provider_model_mappings", None),
"is_active": model.is_active,
"is_available": model.is_available if hasattr(model, "is_available") else True,
"price_per_request": (
@@ -424,7 +424,7 @@ class ModelCacheService:
provider_id=model_dict["provider_id"],
global_model_id=model_dict["global_model_id"],
provider_model_name=model_dict["provider_model_name"],
provider_model_aliases=model_dict.get("provider_model_aliases"),
provider_model_mappings=model_dict.get("provider_model_mappings"),
is_active=model_dict["is_active"],
is_available=model_dict.get("is_available", True),
price_per_request=model_dict.get("price_per_request"),

View File

@@ -443,7 +443,7 @@ class ModelCostService:
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或名)
model: 用户请求的模型名(可能是 GlobalModel.name 或映射名)
Returns:
按次计费价格,如果没有配置则返回 None

View File

@@ -84,11 +84,11 @@ class ModelMapperMiddleware:
获取模型映射
简化后的逻辑:
1. 通过 GlobalModel.name 或名解析 GlobalModel
1. 通过 GlobalModel.name 或映射名解析 GlobalModel
2. 找到 GlobalModel 后,查找该 Provider 的 Model 实现
Args:
source_model: 用户请求的模型名(可以是 GlobalModel.name 或名)
source_model: 用户请求的模型名(可以是 GlobalModel.name 或映射名)
provider_id: 提供商ID (UUID)
Returns:
@@ -101,7 +101,7 @@ class ModelMapperMiddleware:
mapping = None
# 步骤 1: 解析 GlobalModel支持名)
# 步骤 1: 解析 GlobalModel支持映射名)
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(
self.db, source_model
)

View File

@@ -51,7 +51,7 @@ class ModelService:
provider_id=provider_id,
global_model_id=model_data.global_model_id,
provider_model_name=model_data.provider_model_name,
provider_model_aliases=model_data.provider_model_aliases,
provider_model_mappings=model_data.provider_model_mappings,
price_per_request=model_data.price_per_request,
tiered_pricing=model_data.tiered_pricing,
supports_vision=model_data.supports_vision,
@@ -153,9 +153,9 @@ class ModelService:
if not model:
raise NotFoundException(f"模型 {model_id} 不存在")
# 保存旧的别名,用于清除缓存
# 保存旧的映射,用于清除缓存
old_provider_model_name = model.provider_model_name
old_provider_model_aliases = model.provider_model_aliases
old_provider_model_mappings = model.provider_model_mappings
# 更新字段
update_data = model_data.model_dump(exclude_unset=True)
@@ -174,26 +174,26 @@ class ModelService:
db.refresh(model)
# 清除 Redis 缓存(异步执行,不阻塞返回)
# 先清除旧的别名缓存
# 先清除旧的映射缓存
asyncio.create_task(
ModelCacheService.invalidate_model_cache(
model_id=model.id,
provider_id=model.provider_id,
global_model_id=model.global_model_id,
provider_model_name=old_provider_model_name,
provider_model_aliases=old_provider_model_aliases,
provider_model_mappings=old_provider_model_mappings,
)
)
# 再清除新的别名缓存(如果有变化)
# 再清除新的映射缓存(如果有变化)
if (model.provider_model_name != old_provider_model_name or
model.provider_model_aliases != old_provider_model_aliases):
model.provider_model_mappings != old_provider_model_mappings):
asyncio.create_task(
ModelCacheService.invalidate_model_cache(
model_id=model.id,
provider_id=model.provider_id,
global_model_id=model.global_model_id,
provider_model_name=model.provider_model_name,
provider_model_aliases=model.provider_model_aliases,
provider_model_mappings=model.provider_model_mappings,
)
)
@@ -246,7 +246,7 @@ class ModelService:
"provider_id": model.provider_id,
"global_model_id": model.global_model_id,
"provider_model_name": model.provider_model_name,
"provider_model_aliases": model.provider_model_aliases,
"provider_model_mappings": model.provider_model_mappings,
}
try:
@@ -260,7 +260,7 @@ class ModelService:
provider_id=cache_info["provider_id"],
global_model_id=cache_info["global_model_id"],
provider_model_name=cache_info["provider_model_name"],
provider_model_aliases=cache_info["provider_model_aliases"],
provider_model_mappings=cache_info["provider_model_mappings"],
)
)
@@ -297,7 +297,7 @@ class ModelService:
provider_id=model.provider_id,
global_model_id=model.global_model_id,
provider_model_name=model.provider_model_name,
provider_model_aliases=model.provider_model_aliases,
provider_model_mappings=model.provider_model_mappings,
)
)
@@ -390,7 +390,7 @@ class ModelService:
provider_id=model.provider_id,
global_model_id=model.global_model_id,
provider_model_name=model.provider_model_name,
provider_model_aliases=model.provider_model_aliases,
provider_model_mappings=model.provider_model_mappings,
# 原始配置值(可能为空)
price_per_request=model.price_per_request,
tiered_pricing=model.tiered_pricing,

View File

@@ -259,6 +259,9 @@ class CleanupScheduler:
StatsAggregatorService.aggregate_daily_stats(
db, current_date_local
)
StatsAggregatorService.aggregate_daily_model_stats(
db, current_date_local
)
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
@@ -291,6 +294,7 @@ class CleanupScheduler:
yesterday_local = today_local - timedelta(days=1)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
StatsAggregatorService.aggregate_daily_model_stats(db, yesterday_local)
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:

View File

@@ -16,6 +16,7 @@ from src.models.database import (
ApiKey,
RequestCandidate,
StatsDaily,
StatsDailyModel,
StatsSummary,
StatsUserDaily,
Usage,
@@ -219,6 +220,120 @@ class StatsAggregatorService:
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
return stats
@staticmethod
def aggregate_daily_model_stats(db: Session, date: datetime) -> list[StatsDailyModel]:
"""聚合指定日期的模型维度统计数据
Args:
db: 数据库会话
date: 要聚合的业务日期
Returns:
StatsDailyModel 记录列表
"""
day_start, day_end = _get_business_day_range(date)
# 按模型分组统计
model_stats = (
db.query(
Usage.model,
func.count(Usage.id).label("total_requests"),
func.sum(Usage.input_tokens).label("input_tokens"),
func.sum(Usage.output_tokens).label("output_tokens"),
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
func.avg(Usage.response_time_ms).label("avg_response_time"),
)
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
.group_by(Usage.model)
.all()
)
results = []
for stat in model_stats:
if not stat.model:
continue
existing = (
db.query(StatsDailyModel)
.filter(and_(StatsDailyModel.date == day_start, StatsDailyModel.model == stat.model))
.first()
)
if existing:
record = existing
else:
record = StatsDailyModel(
id=str(uuid.uuid4()), date=day_start, model=stat.model
)
record.total_requests = stat.total_requests or 0
record.input_tokens = int(stat.input_tokens or 0)
record.output_tokens = int(stat.output_tokens or 0)
record.cache_creation_tokens = int(stat.cache_creation_tokens or 0)
record.cache_read_tokens = int(stat.cache_read_tokens or 0)
record.total_cost = float(stat.total_cost or 0)
record.avg_response_time_ms = float(stat.avg_response_time or 0)
if not existing:
db.add(record)
results.append(record)
db.commit()
logger.info(
f"[StatsAggregator] 聚合日期 {date.date()} 模型统计完成: {len(results)} 个模型"
)
return results
@staticmethod
def get_daily_model_stats(db: Session, start_date: datetime, end_date: datetime) -> list[dict]:
"""获取日期范围内的模型统计数据(优先使用预聚合)
Args:
db: 数据库会话
start_date: 开始日期 (UTC)
end_date: 结束日期 (UTC)
Returns:
模型统计数据列表
"""
from zoneinfo import ZoneInfo
app_tz = ZoneInfo(APP_TIMEZONE)
# 从预聚合表获取历史数据
stats = (
db.query(StatsDailyModel)
.filter(and_(StatsDailyModel.date >= start_date, StatsDailyModel.date < end_date))
.order_by(StatsDailyModel.date.asc(), StatsDailyModel.total_cost.desc())
.all()
)
# 转换为字典格式,按日期分组
result = []
for stat in stats:
# 转换日期为业务时区
if stat.date.tzinfo is None:
date_utc = stat.date.replace(tzinfo=timezone.utc)
else:
date_utc = stat.date.astimezone(timezone.utc)
date_str = date_utc.astimezone(app_tz).date().isoformat()
result.append({
"date": date_str,
"model": stat.model,
"requests": stat.total_requests,
"tokens": (
stat.input_tokens + stat.output_tokens +
stat.cache_creation_tokens + stat.cache_read_tokens
),
"cost": stat.total_cost,
"avg_response_time": stat.avg_response_time_ms / 1000.0 if stat.avg_response_time_ms else 0,
})
return result
@staticmethod
def aggregate_user_daily_stats(
db: Session, user_id: str, date: datetime
@@ -497,6 +612,7 @@ class StatsAggregatorService:
current_date = start_date
while current_date < today_local:
StatsAggregatorService.aggregate_daily_stats(db, current_date)
StatsAggregatorService.aggregate_daily_model_stats(db, current_date)
count += 1
current_date += timedelta(days=1)