From 77613795ede58082eec69576da7b5c35fc6f7b4c Mon Sep 17 00:00:00 2001 From: fawney19 Date: Sat, 13 Dec 2025 22:27:00 +0800 Subject: [PATCH] refactor(backend): optimize usage service and database helpers --- pyproject.toml | 7 ++++ src/services/usage/service.py | 61 +++++++++++++++++++++-------------- src/utils/database_helpers.py | 5 +-- 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2be9233..864546f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,13 @@ disallow_untyped_defs = true exclude = "tools/debug" # 忽略项目内部模块的 import-untyped 警告 ignore_missing_imports = true +# SQLAlchemy mypy 插件 +plugins = ["sqlalchemy.ext.mypy.plugin"] + +# SQLAlchemy 相关模块放宽类型检查(模型未使用 Mapped 注解) +[[tool.mypy.overrides]] +module = ["src.services.usage.service", "src.models.database"] +disable_error_code = ["arg-type", "assignment"] [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/src/services/usage/service.py b/src/services/usage/service.py index 6605616..b4a14cf 100644 --- a/src/services/usage/service.py +++ b/src/services/usage/service.py @@ -4,7 +4,7 @@ import uuid from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from sqlalchemy import func from sqlalchemy.orm import Session @@ -53,7 +53,7 @@ class UsageService: @classmethod async def get_cache_prices_async( cls, db: Session, provider: str, model: str, input_price: float - ) -> tuple[float, float]: + ) -> Tuple[Optional[float], Optional[float]]: """异步获取模型缓存价格(缓存创建价格,缓存读取价格)每1M tokens""" service = ModelCostService(db) return await service.get_cache_prices_async(provider, model, input_price) @@ -61,7 +61,7 @@ class UsageService: @classmethod def get_cache_prices( cls, db: Session, provider: str, model: str, input_price: float - ) -> tuple[float, float]: + ) -> Tuple[Optional[float], Optional[float]]: """获取模型缓存价格(缓存创建价格,缓存读取价格)每1M tokens""" service = ModelCostService(db) return service.get_cache_prices(provider, model, input_price) @@ -703,20 +703,20 @@ class UsageService: from src.models.database import ApiKey, User - # 更新用户使用量(原子操作)- 使用实际费用(免费套餐为 0) + # 更新用户使用量(原子操作)- 使用标准计费价格 # 独立Key不计入创建者的使用记录 if user and not (api_key and api_key.is_standalone): db.execute( update(User) .where(User.id == user.id) .values( - used_usd=User.used_usd + actual_total_cost, - total_usd=User.total_usd + actual_total_cost, + used_usd=User.used_usd + total_cost, + total_usd=User.total_usd + total_cost, updated_at=func.now(), ) ) - # 更新API密钥使用量(原子操作)- 使用实际费用(免费套餐为 0) + # 更新API密钥使用量(原子操作)- 使用标准计费价格 if api_key: # 独立余额Key需要扣除余额 if api_key.is_standalone: @@ -725,8 +725,8 @@ class UsageService: .where(ApiKey.id == api_key.id) .values( total_requests=ApiKey.total_requests + 1, - total_cost_usd=ApiKey.total_cost_usd + actual_total_cost, - balance_used_usd=ApiKey.balance_used_usd + actual_total_cost, + total_cost_usd=ApiKey.total_cost_usd + total_cost, + balance_used_usd=ApiKey.balance_used_usd + total_cost, last_used_at=func.now(), updated_at=func.now(), ) @@ -738,7 +738,7 @@ class UsageService: .where(ApiKey.id == api_key.id) .values( total_requests=ApiKey.total_requests + 1, - total_cost_usd=ApiKey.total_cost_usd + actual_total_cost, + total_cost_usd=ApiKey.total_cost_usd + total_cost, last_used_at=func.now(), updated_at=func.now(), ) @@ -837,8 +837,10 @@ class UsageService: return True, "OK" # 有配额限制,检查是否超额 - if user.used_usd + estimated_cost > user.quota_usd: - remaining = user.quota_usd - user.used_usd + used_usd = float(user.used_usd or 0) + quota_usd = float(user.quota_usd) + if used_usd + estimated_cost > quota_usd: + remaining = quota_usd - used_usd return False, f"配额不足(剩余: ${remaining:.2f})" return True, "OK" @@ -871,7 +873,8 @@ class UsageService: from src.utils.database_helpers import date_trunc_portable # 检测数据库方言 - dialect = db.bind.dialect.name + bind = db.bind + dialect = bind.dialect.name if bind is not None else "sqlite" # 根据分组类型选择日期函数(兼容多种数据库) if group_by == "day": @@ -976,7 +979,7 @@ class UsageService: query = query.group_by(day_bucket).order_by(day_bucket) rows = query.all() - def normalize_period(value) -> str: + def normalize_period(value: Any) -> str: if value is None: return "" if isinstance(value, str): @@ -1500,8 +1503,8 @@ class UsageService: users = db.query(User).filter(User.id.in_(group_ids)).all() for user in users: user_info_map[str(user.id)] = { - "username": user.username, - "email": user.email or "", + "username": str(user.username), + "email": str(user.email) if user.email else "", } # 处理结果 @@ -1679,12 +1682,20 @@ class UsageService: result = query.first() - total_requests = result.total_requests or 0 - total_input_tokens = result.total_input_tokens or 0 - total_cache_read_tokens = result.total_cache_read_tokens or 0 - total_cache_creation_tokens = result.total_cache_creation_tokens or 0 - total_cache_read_cost = float(result.total_cache_read_cost or 0) - total_cache_creation_cost = float(result.total_cache_creation_cost or 0) + if result is None: + total_requests = 0 + total_input_tokens = 0 + total_cache_read_tokens = 0 + total_cache_creation_tokens = 0 + total_cache_read_cost = 0.0 + total_cache_creation_cost = 0.0 + else: + total_requests = result.total_requests or 0 + total_input_tokens = result.total_input_tokens or 0 + total_cache_read_tokens = result.total_cache_read_tokens or 0 + total_cache_creation_tokens = result.total_cache_creation_tokens or 0 + total_cache_read_cost = float(result.total_cache_read_cost or 0) + total_cache_creation_cost = float(result.total_cache_creation_cost or 0) # 计算缓存命中率(按 token 数) # 总输入上下文 = input_tokens + cache_read_tokens(因为 input_tokens 不含 cache_read) @@ -1711,13 +1722,13 @@ class UsageService: requests_with_cache_hit = requests_with_cache_hit.filter(Usage.user_id == user_id) if api_key_id: requests_with_cache_hit = requests_with_cache_hit.filter(Usage.api_key_id == api_key_id) - requests_with_cache_hit = requests_with_cache_hit.scalar() or 0 + requests_with_cache_hit_count = int(requests_with_cache_hit.scalar() or 0) return { "analysis_period_hours": hours, "total_requests": total_requests, - "requests_with_cache_hit": requests_with_cache_hit, - "request_cache_hit_rate": round(requests_with_cache_hit / total_requests * 100, 2) if total_requests > 0 else 0, + "requests_with_cache_hit": requests_with_cache_hit_count, + "request_cache_hit_rate": round(requests_with_cache_hit_count / total_requests * 100, 2) if total_requests > 0 else 0, "total_input_tokens": total_input_tokens, "total_cache_read_tokens": total_cache_read_tokens, "total_cache_creation_tokens": total_cache_creation_tokens, diff --git a/src/utils/database_helpers.py b/src/utils/database_helpers.py index f3a4251..d974238 100644 --- a/src/utils/database_helpers.py +++ b/src/utils/database_helpers.py @@ -2,11 +2,12 @@ 数据库方言兼容性辅助函数 """ +from typing import Any + from sqlalchemy import func -from sqlalchemy.sql.elements import ClauseElement -def date_trunc_portable(dialect_name: str, interval: str, column) -> ClauseElement: +def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any: """ 跨数据库的日期截断函数