refactor: optimize provider query and stats aggregation logic

This commit is contained in:
fawney19
2025-12-17 16:41:10 +08:00
parent 50abb55c94
commit 1dac4cb156
21 changed files with 1753 additions and 592 deletions

View File

@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
from .endpoints import router as endpoints_router
from .models import router as models_router
from .monitoring import router as monitoring_router
from .provider_query import router as provider_query_router
from .provider_strategy import router as provider_strategy_router
from .providers import router as providers_router
from .security import router as security_router
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
router.include_router(adaptive_router)
router.include_router(models_router)
router.include_router(security_router)
router.include_router(provider_query_router)
__all__ = ["router"]

View File

@@ -1,46 +1,28 @@
"""
Provider Query API endpoints
Used to query provider balance, usage records, etc.
Used to query provider information such as model lists
"""
from datetime import datetime
import asyncio
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
import httpx
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, joinedload
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
# Initialize adapter registration
from src.plugins.provider_query import init # noqa
from src.plugins.provider_query import get_query_registry
from src.plugins.provider_query.base import QueryCapability
from src.models.database import Provider, ProviderEndpoint, User
from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
# ============ Request/Response Models ============
class BalanceQueryRequest(BaseModel):
"""余额查询请求"""
provider_id: str
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
class UsageSummaryQueryRequest(BaseModel):
"""使用汇总查询请求"""
provider_id: str
api_key_id: Optional[str] = None
period: str = "month" # day, week, month, year
class ModelsQueryRequest(BaseModel):
"""模型列表查询请求"""
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
# ============ API Endpoints ============
@router.get("/adapters")
async def list_adapters(
current_user: User = Depends(get_current_user),
):
"""
Get all available query adapters
async def _fetch_openai_models(
client: httpx.AsyncClient,
base_url: str,
api_key: str,
api_format: str,
extra_headers: Optional[dict] = None,
) -> tuple[list, Optional[str]]:
"""获取 OpenAI 格式的模型列表
Returns:
适配器列表
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
registry = get_query_registry()
adapters = registry.list_adapters()
headers = {"Authorization": f"Bearer {api_key}"}
if extra_headers:
# Prevent extra_headers from overriding Authorization
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
headers.update(safe_headers)
return {"success": True, "data": adapters}
# Build the /v1/models URL
if base_url.endswith("/v1"):
models_url = f"{base_url}/models"
else:
models_url = f"{base_url}/v1/models"
try:
response = await client.get(models_url, headers=headers)
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
models = []
if "data" in data:
models = data["data"]
elif isinstance(data, list):
models = data
# Add an api_format field to each model
for m in models:
m["api_format"] = api_format
return models, None
else:
# Log detailed error information
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch models from {models_url}: {e}")
return [], error_msg
@router.get("/capabilities/{provider_id}")
async def get_provider_capabilities(
provider_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Get the query capabilities supported by a provider
Args:
provider_id: Provider ID
async def _fetch_claude_models(
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
"""获取 Claude 格式的模型列表
Returns:
支持的查询能力列表
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
# 获取提供商
from sqlalchemy import select
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
capabilities = registry.get_capabilities_for_provider(provider.name)
if capabilities is None:
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [],
"has_adapter": False,
"message": "No query adapter available for this provider",
},
}
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [c.name for c in capabilities],
"has_adapter": True,
},
headers = {
"x-api-key": api_key,
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
}
# Build the /v1/models URL
if base_url.endswith("/v1"):
models_url = f"{base_url}/models"
else:
models_url = f"{base_url}/v1/models"
@router.post("/balance")
async def query_balance(
request: BalanceQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Query provider balance
try:
response = await client.get(models_url, headers=headers)
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
models = []
if "data" in data:
models = data["data"]
elif isinstance(data, list):
models = data
# Add an api_format field to each model
for m in models:
m["api_format"] = api_format
return models, None
else:
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
return [], error_msg
Args:
request: Query request
async def _fetch_gemini_models(
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
"""获取 Gemini 格式的模型列表
Returns:
余额信息
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# Handle the case where base_url already contains /v1beta
base_url_clean = base_url.rstrip("/")
if base_url_clean.endswith("/v1beta"):
models_url = f"{base_url_clean}/models?key={api_key}"
else:
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
# Fetch the provider and its endpoints
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# Get the API Key
api_key_value = None
endpoint_config = None
if request.api_key_id:
# Look up the specified API Key
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
try:
response = await client.get(models_url)
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
if "models" in data:
# Convert to a unified format
return [
{
"id": m.get("name", "").replace("models/", ""),
"owned_by": "google",
"display_name": m.get("displayName", ""),
"api_format": api_format,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# Use the first available API Key
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# Query balance
registry = get_query_registry()
query_result = await registry.query_provider_balance(
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
)
if not query_result.success:
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.post("/usage-summary")
async def query_usage_summary(
request: UsageSummaryQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Query provider usage summary
Args:
request: Query request
Returns:
Usage summary information
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# Fetch the provider and its endpoints
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# Get the API Key (same logic as above)
api_key_value = None
endpoint_config = None
if request.api_key_id:
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# Query usage summary
registry = get_query_registry()
query_result = await registry.query_provider_usage(
provider_type=provider.name,
api_key=api_key_value,
period=request.period,
endpoint_config=endpoint_config,
)
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
for m in data["models"]
], None
return [], None
else:
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
return [], error_msg
@router.post("/models")
async def query_available_models(
request: ModelsQueryRequest,
db: AsyncSession = Depends(get_db),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Query available models for a provider
Iterate over all active endpoints and pick the right request style based on each endpoint's API format:
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
Args:
request: Query request
Returns:
Model list
Merged model list from all endpoints
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# Fetch the provider and its endpoints
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.filter(Provider.id == request.provider_id)
.first()
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# Get the API Key
api_key_value = None
endpoint_config = None
# Collect configurations from all active endpoints
endpoint_configs: list[dict] = []
if request.api_key_id:
# A specific API Key was requested; use only the endpoint that owns that Key
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break
if api_key_value:
if endpoint_configs:
break
if not api_key_value:
if not endpoint_configs:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# Iterate over all active endpoints, taking the first available Key from each
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not endpoint.is_active or not endpoint.api_keys:
continue
if not api_key_value:
# Find the first available Key
for api_key in endpoint.api_keys:
if api_key.is_active:
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
continue # Try the next Key
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break # Take only the first available Key
if not endpoint_configs:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# Query models
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
# Fetch model lists from all endpoints concurrently
all_models: list = []
errors: list[str] = []
if not adapter:
raise HTTPException(
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
async def fetch_endpoint_models(
client: httpx.AsyncClient, config: dict
) -> tuple[list, Optional[str]]:
base_url = config["base_url"]
if not base_url:
return [], None
base_url = base_url.rstrip("/")
api_format = config["api_format"]
api_key_value = config["api_key"]
extra_headers = config["extra_headers"]
try:
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
elif api_format in ["GEMINI", "GEMINI_CLI"]:
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
else:
return await _fetch_openai_models(
client, base_url, api_key_value, api_format, extra_headers
)
except Exception as e:
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
return [], f"{api_format}: {str(e)}"
async with httpx.AsyncClient(timeout=30.0) as client:
results = await asyncio.gather(
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
)
for models, error in results:
all_models.extend(models)
if error:
errors.append(error)
query_result = await adapter.query_available_models(
api_key=api_key_value, endpoint_config=endpoint_config
)
# Deduplicate by model id (keep the first occurrence)
seen_ids: set[str] = set()
unique_models: list = []
for model in all_models:
model_id = model.get("id")
if model_id and model_id not in seen_ids:
seen_ids.add(model_id)
unique_models.append(model)
error = "; ".join(errors) if errors else None
if not unique_models and not error:
error = "No models returned from any endpoint"
return {
"success": query_result.success,
"data": query_result.to_dict(),
"success": len(unique_models) > 0,
"data": {"models": unique_models, "error": error},
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
provider_id: str,
api_key_id: Optional[str] = None,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Clear the query cache
Args:
provider_id: Provider ID
api_key_id: Optional; clear the cache for a specific API Key
Returns:
Result of the clear operation
"""
from sqlalchemy import select
# Fetch the provider
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
if adapter:
if api_key_id:
# Fetch the API Key value to clear its cache
from sqlalchemy.orm import selectinload
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
api_key = result.scalar_one_or_none()
if api_key:
adapter.clear_cache(api_key.api_key)
else:
adapter.clear_cache()
return {"success": True, "message": "Cache cleared successfully"}

View File

@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
)
# stats_daily.date stores the UTC start time of the business date
# Convert back to the business timezone before taking the date so it matches the date series
def _to_business_date_str(value: datetime) -> str:
if value.tzinfo is None:
value_utc = value.replace(tzinfo=timezone.utc)
else:
value_utc = value.astimezone(timezone.utc)
return value_utc.astimezone(app_tz).date().isoformat()
stats_map = {
stat.date.replace(tzinfo=timezone.utc).astimezone(app_tz).date().isoformat(): {
_to_business_date_str(stat.date): {
"requests": stat.total_requests,
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
"cost": stat.total_cost,
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
"unique_providers": today_unique_providers,
"fallback_count": today_fallback_count,
}
# Fallback when historical pre-aggregation is missing: compute in real time over the business-day range (only backfill the most recent few gaps to avoid a full-table scan)
yesterday_date = today_local.date() - timedelta(days=1)
historical_end = min(end_date_local.date(), yesterday_date)
missing_dates: list[str] = []
cursor = start_date_local.date()
while cursor <= historical_end:
date_str = cursor.isoformat()
if date_str not in stats_map:
missing_dates.append(date_str)
cursor += timedelta(days=1)
if missing_dates:
for date_str in missing_dates[-7:]:
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
stats_map[date_str] = {
"requests": computed["total_requests"],
"tokens": (
computed["input_tokens"]
+ computed["output_tokens"]
+ computed["cache_creation_tokens"]
+ computed["cache_read_tokens"]
),
"cost": computed["total_cost"],
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
if computed["avg_response_time_ms"]
else 0,
"unique_models": computed["unique_models"],
"unique_providers": computed["unique_providers"],
"fallback_count": computed["fallback_count"],
}
else:
# Regular users: still query in real time (per-user pre-aggregation is optional)
query = db.query(Usage).filter(

View File

@@ -266,8 +266,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
if mapping and mapping.model:
# Use select_provider_model_name to support the alias feature
# Pass api_key.id as the affinity_key so the same user consistently selects the same alias
# Pass api_format to filter aliases by applicable scope
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
return mapped_name

View File

@@ -155,8 +155,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
if mapping and mapping.model:
# Use select_provider_model_name to support the alias feature
# Pass api_key.id as the affinity_key so the same user consistently selects the same alias
# Pass api_format to filter aliases by applicable scope
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name

View File

@@ -813,7 +813,9 @@ class Model(Base):
def get_effective_supports_image_generation(self) -> bool:
return self._get_effective_capability("supports_image_generation", False)
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
def select_provider_model_name(
self, affinity_key: Optional[str] = None, api_format: Optional[str] = None
) -> str:
"""按优先级选择要使用的 Provider 模型名称
如果配置了 provider_model_aliases按优先级选择数字越小越优先
@@ -822,6 +824,7 @@ class Model(Base):
Args:
affinity_key: Affinity key used for hash distribution (e.g. a hash of the user's API Key), so the same user consistently gets the same alias
api_format: API format of the current request (e.g. CLAUDE, OPENAI), used to filter applicable aliases
"""
import hashlib
@@ -840,6 +843,13 @@ class Model(Base):
if not isinstance(name, str) or not name.strip():
continue
# Check the api_formats scope (if configured and an api_format is present)
alias_api_formats = raw.get("api_formats")
if api_format and alias_api_formats:
# If a scope is configured, the alias applies only when it matches
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
continue
raw_priority = raw.get("priority", 1)
try:
priority = int(raw_priority)
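To illustrate the new scope check, a hypothetical provider_model_aliases value; the keys (name, priority, api_formats) are exactly what the selection loop reads, while the alias names themselves are made up.

provider_model_aliases = [
    # Applies only to Claude-format requests; lower priority numbers win
    {"name": "claude-sonnet-alias", "priority": 0, "api_formats": ["CLAUDE", "CLAUDE_CLI"]},
    # No api_formats key: a candidate for every request format
    {"name": "default-alias", "priority": 1},
]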

View File

@@ -35,6 +35,7 @@ class CleanupScheduler:
def __init__(self):
self.running = False
self._interval_tasks = []
self._stats_aggregation_lock = asyncio.Lock()
async def start(self):
"""启动调度器"""
@@ -56,6 +57,14 @@ class CleanupScheduler:
job_id="stats_aggregation",
name="统计数据聚合",
)
# Stats aggregation compensation task - checks for gaps and backfills every 30 minutes
scheduler.add_interval_job(
self._scheduled_stats_aggregation,
minutes=30,
job_id="stats_aggregation_backfill",
name="统计数据聚合补偿",
backfill=True,
)
# Cleanup task - runs at 3 AM
scheduler.add_cron_job(
@@ -115,9 +124,9 @@ class CleanupScheduler:
# ========== Task functions (APScheduler calls async functions directly) ==========
async def _scheduled_stats_aggregation(self):
async def _scheduled_stats_aggregation(self, backfill: bool = False):
"""统计聚合任务(定时调用)"""
await self._perform_stats_aggregation()
await self._perform_stats_aggregation(backfill=backfill)
async def _scheduled_cleanup(self):
"""清理任务(定时调用)"""
@@ -144,136 +153,157 @@ class CleanupScheduler:
Args:
backfill: Whether to backfill historical data (check for missing dates at startup)
"""
db = create_session()
try:
# Check whether stats aggregation is enabled
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("Stats aggregation disabled; skipping the task")
return
if self._stats_aggregation_lock.locked():
logger.info("统计聚合任务正在运行,跳过本次触发")
return
logger.info("开始执行统计数据聚合...")
from src.models.database import StatsDaily, User as DBUser
from src.services.system.scheduler import APP_TIMEZONE
from zoneinfo import ZoneInfo
# Compute dates in the business timezone so they match the scheduled trigger time
# The scheduled job fires at 1 AM Asia/Shanghai, so it should aggregate "yesterday" in Asia/Shanghai
app_tz = ZoneInfo(APP_TIMEZONE)
now_local = datetime.now(app_tz)
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
if backfill:
# On startup, check for and backfill missing dates
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
# First run: backfill all historical data
logger.info("First run detected; backfilling historical stats...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"历史数据回填完成,共 {count}")
async with self._stats_aggregation_lock:
db = create_session()
try:
# Check whether stats aggregation is enabled
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("Stats aggregation disabled; skipping the task")
return
# Not the first run: check whether recent dates are missing and need backfill
latest_stat = (
db.query(StatsDaily)
.order_by(StatsDaily.date.desc())
.first()
)
logger.info("开始执行统计数据聚合...")
if latest_stat:
latest_date_utc = latest_stat.date
if latest_date_utc.tzinfo is None:
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
else:
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
from src.models.database import StatsDaily, User as DBUser
from src.services.system.scheduler import APP_TIMEZONE
from zoneinfo import ZoneInfo
# Compute the missing range with business dates (avoids date shifts from UTC Y/M/D and is safer across DST)
latest_business_date = latest_date_utc.astimezone(app_tz).date()
yesterday_business_date = (today_local.date() - timedelta(days=1))
missing_start_date = latest_business_date + timedelta(days=1)
# Compute dates in the business timezone so they match the scheduled trigger time
# The scheduled job fires at 1 AM Asia/Shanghai, so it should aggregate "yesterday" in Asia/Shanghai
app_tz = ZoneInfo(APP_TIMEZONE)
now_local = datetime.now(app_tz)
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
if missing_start_date <= yesterday_business_date:
missing_days = (yesterday_business_date - missing_start_date).days + 1
logger.info(
f"检测到缺失 {missing_days} 天的统计数据 "
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
if backfill:
# On startup, check for and backfill missing dates
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
# First run: backfill all historical data
logger.info("First run detected; backfilling historical stats...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"历史数据回填完成,共 {count}")
return
current_date = missing_start_date
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
# Not the first run: check whether recent dates are missing and need backfill
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
while current_date <= yesterday_business_date:
try:
current_date_local = datetime.combine(
current_date, datetime.min.time(), tzinfo=app_tz
if latest_stat:
latest_date_utc = latest_stat.date
if latest_date_utc.tzinfo is None:
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
else:
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
# Compute the missing range with business dates (avoids date shifts from UTC Y/M/D and is safer across DST)
latest_business_date = latest_date_utc.astimezone(app_tz).date()
yesterday_business_date = today_local.date() - timedelta(days=1)
missing_start_date = latest_business_date + timedelta(days=1)
if missing_start_date <= yesterday_business_date:
missing_days = (
yesterday_business_date - missing_start_date
).days + 1
# Cap backfill days so a long outage doesn't trigger a huge one-shot backfill
max_backfill_days: int = SystemConfigService.get_config(
db, "max_stats_backfill_days", 30
) or 30
if missing_days > max_backfill_days:
logger.warning(
f"缺失 {missing_days} 天数据超过最大回填限制 "
f"{max_backfill_days} 天,只回填最近 {max_backfill_days}"
)
StatsAggregatorService.aggregate_daily_stats(db, current_date_local)
# Aggregate per-user data
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, current_date_local
)
except Exception as e:
logger.warning(
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
)
try:
db.rollback()
except Exception:
pass
except Exception as e:
logger.warning(f"回填日期 {current_date} 失败: {e}")
missing_start_date = yesterday_business_date - timedelta(
days=max_backfill_days - 1
)
missing_days = max_backfill_days
logger.info(
f"检测到缺失 {missing_days} 天的统计数据 "
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
)
current_date = missing_start_date
users = (
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
)
while current_date <= yesterday_business_date:
try:
db.rollback()
except Exception:
pass
current_date_local = datetime.combine(
current_date, datetime.min.time(), tzinfo=app_tz
)
StatsAggregatorService.aggregate_daily_stats(
db, current_date_local
)
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, current_date_local
)
except Exception as e:
logger.warning(
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
)
try:
db.rollback()
except Exception:
pass
except Exception as e:
logger.warning(f"回填日期 {current_date} 失败: {e}")
try:
db.rollback()
except Exception:
pass
current_date += timedelta(days=1)
current_date += timedelta(days=1)
# Update the global summary
StatsAggregatorService.update_summary(db)
logger.info(f"Missing-data backfill complete: {missing_days} days")
else:
logger.info("统计数据已是最新,无需回填")
return
StatsAggregatorService.update_summary(db)
logger.info(f"缺失数据回填完成,共 {missing_days}")
else:
logger.info("统计数据已是最新,无需回填")
return
# Scheduled run: aggregate yesterday's data
# Note: aggregate_daily_stats expects a business-timezone date, not UTC
yesterday_local = today_local - timedelta(days=1)
# Scheduled run: aggregate yesterday's data
yesterday_local = today_local - timedelta(days=1)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
# Aggregate yesterday's data for all users
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday_local)
except Exception as e:
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
# 回滚当前用户的失败操作,继续处理其他用户
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
db.rollback()
except Exception:
pass
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, yesterday_local
)
except Exception as e:
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
try:
db.rollback()
except Exception:
pass
# Update the global summary
StatsAggregatorService.update_summary(db)
StatsAggregatorService.update_summary(db)
logger.info("Stats aggregation complete")
logger.info("Stats aggregation complete")
except Exception as e:
logger.exception(f"Stats aggregation task failed: {e}")
db.rollback()
finally:
db.close()
except Exception as e:
logger.exception(f"Stats aggregation task failed: {e}")
try:
db.rollback()
except Exception:
pass
finally:
db.close()
async def _perform_pending_cleanup(self):
"""执行 pending 状态清理"""

View File

@@ -56,65 +56,44 @@ class StatsAggregatorService:
"""统计数据聚合服务"""
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
"""聚合指定日期的统计数据
Args:
db: 数据库会话
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
Returns:
StatsDaily 记录
"""
# 将业务日期转换为 UTC 时间范围
def compute_daily_stats(db: Session, date: datetime) -> dict:
"""计算指定业务日期的统计数据(不写入数据库)"""
day_start, day_end = _get_business_day_range(date)
# stats_daily.date stores the UTC start time of the business date
# Check whether a record already exists for this date
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
if existing:
stats = existing
else:
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
# Base request statistics
base_query = db.query(Usage).filter(
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
)
total_requests = base_query.count()
# If there were no requests, return an empty record immediately
if total_requests == 0:
stats.total_requests = 0
stats.success_requests = 0
stats.error_requests = 0
stats.input_tokens = 0
stats.output_tokens = 0
stats.cache_creation_tokens = 0
stats.cache_read_tokens = 0
stats.total_cost = 0.0
stats.actual_total_cost = 0.0
stats.input_cost = 0.0
stats.output_cost = 0.0
stats.cache_creation_cost = 0.0
stats.cache_read_cost = 0.0
stats.avg_response_time_ms = 0.0
stats.fallback_count = 0
return {
"day_start": day_start,
"total_requests": 0,
"success_requests": 0,
"error_requests": 0,
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
"total_cost": 0.0,
"actual_total_cost": 0.0,
"input_cost": 0.0,
"output_cost": 0.0,
"cache_creation_cost": 0.0,
"cache_read_cost": 0.0,
"avg_response_time_ms": 0.0,
"fallback_count": 0,
"unique_models": 0,
"unique_providers": 0,
}
if not existing:
db.add(stats)
db.commit()
return stats
# Error request count
error_requests = (
base_query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
)
# Token and cost aggregation
aggregated = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
@@ -157,7 +136,6 @@ class StatsAggregatorService:
or 0
)
# Usage dimension statistics
unique_models = (
db.query(func.count(func.distinct(Usage.model)))
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
@@ -171,31 +149,74 @@ class StatsAggregatorService:
or 0
)
return {
"day_start": day_start,
"total_requests": total_requests,
"success_requests": total_requests - error_requests,
"error_requests": error_requests,
"input_tokens": int(aggregated.input_tokens or 0) if aggregated else 0,
"output_tokens": int(aggregated.output_tokens or 0) if aggregated else 0,
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0) if aggregated else 0,
"cache_read_tokens": int(aggregated.cache_read_tokens or 0) if aggregated else 0,
"total_cost": float(aggregated.total_cost or 0) if aggregated else 0.0,
"actual_total_cost": float(aggregated.actual_total_cost or 0) if aggregated else 0.0,
"input_cost": float(aggregated.input_cost or 0) if aggregated else 0.0,
"output_cost": float(aggregated.output_cost or 0) if aggregated else 0.0,
"cache_creation_cost": float(aggregated.cache_creation_cost or 0) if aggregated else 0.0,
"cache_read_cost": float(aggregated.cache_read_cost or 0) if aggregated else 0.0,
"avg_response_time_ms": float(aggregated.avg_response_time or 0) if aggregated else 0.0,
"fallback_count": fallback_count,
"unique_models": unique_models,
"unique_providers": unique_providers,
}
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
"""聚合指定日期的统计数据
Args:
db: 数据库会话
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
Returns:
StatsDaily 记录
"""
computed = StatsAggregatorService.compute_daily_stats(db, date)
day_start = computed["day_start"]
# stats_daily.date stores the UTC start time of the business date
# Check whether a record already exists for this date
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
if existing:
stats = existing
else:
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
# Update the stats record
stats.total_requests = total_requests
stats.success_requests = total_requests - error_requests
stats.error_requests = error_requests
stats.input_tokens = int(aggregated.input_tokens or 0)
stats.output_tokens = int(aggregated.output_tokens or 0)
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
stats.total_cost = float(aggregated.total_cost or 0)
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
stats.input_cost = float(aggregated.input_cost or 0)
stats.output_cost = float(aggregated.output_cost or 0)
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
stats.fallback_count = fallback_count
stats.unique_models = unique_models
stats.unique_providers = unique_providers
stats.total_requests = computed["total_requests"]
stats.success_requests = computed["success_requests"]
stats.error_requests = computed["error_requests"]
stats.input_tokens = computed["input_tokens"]
stats.output_tokens = computed["output_tokens"]
stats.cache_creation_tokens = computed["cache_creation_tokens"]
stats.cache_read_tokens = computed["cache_read_tokens"]
stats.total_cost = computed["total_cost"]
stats.actual_total_cost = computed["actual_total_cost"]
stats.input_cost = computed["input_cost"]
stats.output_cost = computed["output_cost"]
stats.cache_creation_cost = computed["cache_creation_cost"]
stats.cache_read_cost = computed["cache_read_cost"]
stats.avg_response_time_ms = computed["avg_response_time_ms"]
stats.fallback_count = computed["fallback_count"]
stats.unique_models = computed["unique_models"]
stats.unique_providers = computed["unique_providers"]
if not existing:
db.add(stats)
db.commit()
# Log the business date (the input argument), not the UTC date
logger.info(f"[StatsAggregator] Aggregation for {date.date()} complete: {total_requests} requests")
logger.info(f"[StatsAggregator] Aggregation for {date.date()} complete: {computed['total_requests']} requests")
return stats
@staticmethod