refactor(backend): optimize usage service and database helpers

This commit is contained in:
fawney19
2025-12-13 22:27:00 +08:00
parent f54127cba5
commit 77613795ed
3 changed files with 46 additions and 27 deletions

View File

@@ -97,6 +97,13 @@ disallow_untyped_defs = true
exclude = "tools/debug" exclude = "tools/debug"
# 忽略项目内部模块的 import-untyped 警告 # 忽略项目内部模块的 import-untyped 警告
ignore_missing_imports = true ignore_missing_imports = true
# SQLAlchemy mypy 插件
plugins = ["sqlalchemy.ext.mypy.plugin"]
# SQLAlchemy 相关模块放宽类型检查(模型未使用 Mapped 注解)
[[tool.mypy.overrides]]
module = ["src.services.usage.service", "src.models.database"]
disable_error_code = ["arg-type", "assignment"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = ["tests"] testpaths = ["tests"]

View File

@@ -4,7 +4,7 @@
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -53,7 +53,7 @@ class UsageService:
@classmethod @classmethod
async def get_cache_prices_async( async def get_cache_prices_async(
cls, db: Session, provider: str, model: str, input_price: float cls, db: Session, provider: str, model: str, input_price: float
) -> tuple[float, float]: ) -> Tuple[Optional[float], Optional[float]]:
"""异步获取模型缓存价格缓存创建价格缓存读取价格每1M tokens""" """异步获取模型缓存价格缓存创建价格缓存读取价格每1M tokens"""
service = ModelCostService(db) service = ModelCostService(db)
return await service.get_cache_prices_async(provider, model, input_price) return await service.get_cache_prices_async(provider, model, input_price)
@@ -61,7 +61,7 @@ class UsageService:
@classmethod @classmethod
def get_cache_prices( def get_cache_prices(
cls, db: Session, provider: str, model: str, input_price: float cls, db: Session, provider: str, model: str, input_price: float
) -> tuple[float, float]: ) -> Tuple[Optional[float], Optional[float]]:
"""获取模型缓存价格缓存创建价格缓存读取价格每1M tokens""" """获取模型缓存价格缓存创建价格缓存读取价格每1M tokens"""
service = ModelCostService(db) service = ModelCostService(db)
return service.get_cache_prices(provider, model, input_price) return service.get_cache_prices(provider, model, input_price)
@@ -703,20 +703,20 @@ class UsageService:
from src.models.database import ApiKey, User from src.models.database import ApiKey, User
# 更新用户使用量(原子操作)- 使用实际费用(免费套餐为 0 # 更新用户使用量(原子操作)- 使用标准计费价格
# 独立Key不计入创建者的使用记录 # 独立Key不计入创建者的使用记录
if user and not (api_key and api_key.is_standalone): if user and not (api_key and api_key.is_standalone):
db.execute( db.execute(
update(User) update(User)
.where(User.id == user.id) .where(User.id == user.id)
.values( .values(
used_usd=User.used_usd + actual_total_cost, used_usd=User.used_usd + total_cost,
total_usd=User.total_usd + actual_total_cost, total_usd=User.total_usd + total_cost,
updated_at=func.now(), updated_at=func.now(),
) )
) )
# 更新API密钥使用量原子操作- 使用实际费用(免费套餐为 0 # 更新API密钥使用量原子操作- 使用标准计费价格
if api_key: if api_key:
# 独立余额Key需要扣除余额 # 独立余额Key需要扣除余额
if api_key.is_standalone: if api_key.is_standalone:
@@ -725,8 +725,8 @@ class UsageService:
.where(ApiKey.id == api_key.id) .where(ApiKey.id == api_key.id)
.values( .values(
total_requests=ApiKey.total_requests + 1, total_requests=ApiKey.total_requests + 1,
total_cost_usd=ApiKey.total_cost_usd + actual_total_cost, total_cost_usd=ApiKey.total_cost_usd + total_cost,
balance_used_usd=ApiKey.balance_used_usd + actual_total_cost, balance_used_usd=ApiKey.balance_used_usd + total_cost,
last_used_at=func.now(), last_used_at=func.now(),
updated_at=func.now(), updated_at=func.now(),
) )
@@ -738,7 +738,7 @@ class UsageService:
.where(ApiKey.id == api_key.id) .where(ApiKey.id == api_key.id)
.values( .values(
total_requests=ApiKey.total_requests + 1, total_requests=ApiKey.total_requests + 1,
total_cost_usd=ApiKey.total_cost_usd + actual_total_cost, total_cost_usd=ApiKey.total_cost_usd + total_cost,
last_used_at=func.now(), last_used_at=func.now(),
updated_at=func.now(), updated_at=func.now(),
) )
@@ -837,8 +837,10 @@ class UsageService:
return True, "OK" return True, "OK"
# 有配额限制,检查是否超额 # 有配额限制,检查是否超额
if user.used_usd + estimated_cost > user.quota_usd: used_usd = float(user.used_usd or 0)
remaining = user.quota_usd - user.used_usd quota_usd = float(user.quota_usd)
if used_usd + estimated_cost > quota_usd:
remaining = quota_usd - used_usd
return False, f"配额不足(剩余: ${remaining:.2f}" return False, f"配额不足(剩余: ${remaining:.2f}"
return True, "OK" return True, "OK"
@@ -871,7 +873,8 @@ class UsageService:
from src.utils.database_helpers import date_trunc_portable from src.utils.database_helpers import date_trunc_portable
# 检测数据库方言 # 检测数据库方言
dialect = db.bind.dialect.name bind = db.bind
dialect = bind.dialect.name if bind is not None else "sqlite"
# 根据分组类型选择日期函数(兼容多种数据库) # 根据分组类型选择日期函数(兼容多种数据库)
if group_by == "day": if group_by == "day":
@@ -976,7 +979,7 @@ class UsageService:
query = query.group_by(day_bucket).order_by(day_bucket) query = query.group_by(day_bucket).order_by(day_bucket)
rows = query.all() rows = query.all()
def normalize_period(value) -> str: def normalize_period(value: Any) -> str:
if value is None: if value is None:
return "" return ""
if isinstance(value, str): if isinstance(value, str):
@@ -1500,8 +1503,8 @@ class UsageService:
users = db.query(User).filter(User.id.in_(group_ids)).all() users = db.query(User).filter(User.id.in_(group_ids)).all()
for user in users: for user in users:
user_info_map[str(user.id)] = { user_info_map[str(user.id)] = {
"username": user.username, "username": str(user.username),
"email": user.email or "", "email": str(user.email) if user.email else "",
} }
# 处理结果 # 处理结果
@@ -1679,6 +1682,14 @@ class UsageService:
result = query.first() result = query.first()
if result is None:
total_requests = 0
total_input_tokens = 0
total_cache_read_tokens = 0
total_cache_creation_tokens = 0
total_cache_read_cost = 0.0
total_cache_creation_cost = 0.0
else:
total_requests = result.total_requests or 0 total_requests = result.total_requests or 0
total_input_tokens = result.total_input_tokens or 0 total_input_tokens = result.total_input_tokens or 0
total_cache_read_tokens = result.total_cache_read_tokens or 0 total_cache_read_tokens = result.total_cache_read_tokens or 0
@@ -1711,13 +1722,13 @@ class UsageService:
requests_with_cache_hit = requests_with_cache_hit.filter(Usage.user_id == user_id) requests_with_cache_hit = requests_with_cache_hit.filter(Usage.user_id == user_id)
if api_key_id: if api_key_id:
requests_with_cache_hit = requests_with_cache_hit.filter(Usage.api_key_id == api_key_id) requests_with_cache_hit = requests_with_cache_hit.filter(Usage.api_key_id == api_key_id)
requests_with_cache_hit = requests_with_cache_hit.scalar() or 0 requests_with_cache_hit_count = int(requests_with_cache_hit.scalar() or 0)
return { return {
"analysis_period_hours": hours, "analysis_period_hours": hours,
"total_requests": total_requests, "total_requests": total_requests,
"requests_with_cache_hit": requests_with_cache_hit, "requests_with_cache_hit": requests_with_cache_hit_count,
"request_cache_hit_rate": round(requests_with_cache_hit / total_requests * 100, 2) if total_requests > 0 else 0, "request_cache_hit_rate": round(requests_with_cache_hit_count / total_requests * 100, 2) if total_requests > 0 else 0,
"total_input_tokens": total_input_tokens, "total_input_tokens": total_input_tokens,
"total_cache_read_tokens": total_cache_read_tokens, "total_cache_read_tokens": total_cache_read_tokens,
"total_cache_creation_tokens": total_cache_creation_tokens, "total_cache_creation_tokens": total_cache_creation_tokens,

View File

@@ -2,11 +2,12 @@
数据库方言兼容性辅助函数 数据库方言兼容性辅助函数
""" """
from typing import Any
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.sql.elements import ClauseElement
def date_trunc_portable(dialect_name: str, interval: str, column) -> ClauseElement: def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any:
""" """
跨数据库的日期截断函数 跨数据库的日期截断函数