refactor(backend): optimize usage service and database helpers

This commit is contained in:
fawney19
2025-12-13 22:27:00 +08:00
parent f54127cba5
commit 77613795ed
3 changed files with 46 additions and 27 deletions

View File

@@ -97,6 +97,13 @@ disallow_untyped_defs = true
exclude = "tools/debug"
# 忽略项目内部模块的 import-untyped 警告
ignore_missing_imports = true
# SQLAlchemy mypy 插件
plugins = ["sqlalchemy.ext.mypy.plugin"]
# SQLAlchemy 相关模块放宽类型检查(模型未使用 Mapped 注解)
[[tool.mypy.overrides]]
module = ["src.services.usage.service", "src.models.database"]
disable_error_code = ["arg-type", "assignment"]
[tool.pytest.ini_options]
testpaths = ["tests"]

View File

@@ -4,7 +4,7 @@
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -53,7 +53,7 @@ class UsageService:
@classmethod
async def get_cache_prices_async(
cls, db: Session, provider: str, model: str, input_price: float
) -> tuple[float, float]:
) -> Tuple[Optional[float], Optional[float]]:
"""异步获取模型缓存价格缓存创建价格缓存读取价格每1M tokens"""
service = ModelCostService(db)
return await service.get_cache_prices_async(provider, model, input_price)
@@ -61,7 +61,7 @@ class UsageService:
@classmethod
def get_cache_prices(
cls, db: Session, provider: str, model: str, input_price: float
) -> tuple[float, float]:
) -> Tuple[Optional[float], Optional[float]]:
"""获取模型缓存价格缓存创建价格缓存读取价格每1M tokens"""
service = ModelCostService(db)
return service.get_cache_prices(provider, model, input_price)
@@ -703,20 +703,20 @@ class UsageService:
from src.models.database import ApiKey, User
# 更新用户使用量(原子操作)- 使用实际费用(免费套餐为 0
# 更新用户使用量(原子操作)- 使用标准计费价格
# 独立Key不计入创建者的使用记录
if user and not (api_key and api_key.is_standalone):
db.execute(
update(User)
.where(User.id == user.id)
.values(
used_usd=User.used_usd + actual_total_cost,
total_usd=User.total_usd + actual_total_cost,
used_usd=User.used_usd + total_cost,
total_usd=User.total_usd + total_cost,
updated_at=func.now(),
)
)
# 更新API密钥使用量原子操作- 使用实际费用(免费套餐为 0
# 更新API密钥使用量原子操作- 使用标准计费价格
if api_key:
# 独立余额Key需要扣除余额
if api_key.is_standalone:
@@ -725,8 +725,8 @@ class UsageService:
.where(ApiKey.id == api_key.id)
.values(
total_requests=ApiKey.total_requests + 1,
total_cost_usd=ApiKey.total_cost_usd + actual_total_cost,
balance_used_usd=ApiKey.balance_used_usd + actual_total_cost,
total_cost_usd=ApiKey.total_cost_usd + total_cost,
balance_used_usd=ApiKey.balance_used_usd + total_cost,
last_used_at=func.now(),
updated_at=func.now(),
)
@@ -738,7 +738,7 @@ class UsageService:
.where(ApiKey.id == api_key.id)
.values(
total_requests=ApiKey.total_requests + 1,
total_cost_usd=ApiKey.total_cost_usd + actual_total_cost,
total_cost_usd=ApiKey.total_cost_usd + total_cost,
last_used_at=func.now(),
updated_at=func.now(),
)
@@ -837,8 +837,10 @@ class UsageService:
return True, "OK"
# 有配额限制,检查是否超额
if user.used_usd + estimated_cost > user.quota_usd:
remaining = user.quota_usd - user.used_usd
used_usd = float(user.used_usd or 0)
quota_usd = float(user.quota_usd)
if used_usd + estimated_cost > quota_usd:
remaining = quota_usd - used_usd
return False, f"配额不足(剩余: ${remaining:.2f}"
return True, "OK"
@@ -871,7 +873,8 @@ class UsageService:
from src.utils.database_helpers import date_trunc_portable
# 检测数据库方言
dialect = db.bind.dialect.name
bind = db.bind
dialect = bind.dialect.name if bind is not None else "sqlite"
# 根据分组类型选择日期函数(兼容多种数据库)
if group_by == "day":
@@ -976,7 +979,7 @@ class UsageService:
query = query.group_by(day_bucket).order_by(day_bucket)
rows = query.all()
def normalize_period(value) -> str:
def normalize_period(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
@@ -1500,8 +1503,8 @@ class UsageService:
users = db.query(User).filter(User.id.in_(group_ids)).all()
for user in users:
user_info_map[str(user.id)] = {
"username": user.username,
"email": user.email or "",
"username": str(user.username),
"email": str(user.email) if user.email else "",
}
# 处理结果
@@ -1679,6 +1682,14 @@ class UsageService:
result = query.first()
if result is None:
total_requests = 0
total_input_tokens = 0
total_cache_read_tokens = 0
total_cache_creation_tokens = 0
total_cache_read_cost = 0.0
total_cache_creation_cost = 0.0
else:
total_requests = result.total_requests or 0
total_input_tokens = result.total_input_tokens or 0
total_cache_read_tokens = result.total_cache_read_tokens or 0
@@ -1711,13 +1722,13 @@ class UsageService:
requests_with_cache_hit = requests_with_cache_hit.filter(Usage.user_id == user_id)
if api_key_id:
requests_with_cache_hit = requests_with_cache_hit.filter(Usage.api_key_id == api_key_id)
requests_with_cache_hit = requests_with_cache_hit.scalar() or 0
requests_with_cache_hit_count = int(requests_with_cache_hit.scalar() or 0)
return {
"analysis_period_hours": hours,
"total_requests": total_requests,
"requests_with_cache_hit": requests_with_cache_hit,
"request_cache_hit_rate": round(requests_with_cache_hit / total_requests * 100, 2) if total_requests > 0 else 0,
"requests_with_cache_hit": requests_with_cache_hit_count,
"request_cache_hit_rate": round(requests_with_cache_hit_count / total_requests * 100, 2) if total_requests > 0 else 0,
"total_input_tokens": total_input_tokens,
"total_cache_read_tokens": total_cache_read_tokens,
"total_cache_creation_tokens": total_cache_creation_tokens,

View File

@@ -2,11 +2,12 @@
数据库方言兼容性辅助函数
"""
from typing import Any
from sqlalchemy import func
from sqlalchemy.sql.elements import ClauseElement
def date_trunc_portable(dialect_name: str, interval: str, column) -> ClauseElement:
def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any:
"""
跨数据库的日期截断函数