mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
refactor(backend): optimize usage service and database helpers
This commit is contained in:
@@ -97,6 +97,13 @@ disallow_untyped_defs = true
|
||||
exclude = "tools/debug"
|
||||
# 忽略项目内部模块的 import-untyped 警告
|
||||
ignore_missing_imports = true
|
||||
# SQLAlchemy mypy 插件
|
||||
plugins = ["sqlalchemy.ext.mypy.plugin"]
|
||||
|
||||
# SQLAlchemy 相关模块放宽类型检查(模型未使用 Mapped 注解)
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["src.services.usage.service", "src.models.database"]
|
||||
disable_error_code = ["arg-type", "assignment"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -53,7 +53,7 @@ class UsageService:
|
||||
@classmethod
|
||||
async def get_cache_prices_async(
|
||||
cls, db: Session, provider: str, model: str, input_price: float
|
||||
) -> tuple[float, float]:
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""异步获取模型缓存价格(缓存创建价格,缓存读取价格)每1M tokens"""
|
||||
service = ModelCostService(db)
|
||||
return await service.get_cache_prices_async(provider, model, input_price)
|
||||
@@ -61,7 +61,7 @@ class UsageService:
|
||||
@classmethod
|
||||
def get_cache_prices(
|
||||
cls, db: Session, provider: str, model: str, input_price: float
|
||||
) -> tuple[float, float]:
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""获取模型缓存价格(缓存创建价格,缓存读取价格)每1M tokens"""
|
||||
service = ModelCostService(db)
|
||||
return service.get_cache_prices(provider, model, input_price)
|
||||
@@ -703,20 +703,20 @@ class UsageService:
|
||||
|
||||
from src.models.database import ApiKey, User
|
||||
|
||||
# 更新用户使用量(原子操作)- 使用实际费用(免费套餐为 0)
|
||||
# 更新用户使用量(原子操作)- 使用标准计费价格
|
||||
# 独立Key不计入创建者的使用记录
|
||||
if user and not (api_key and api_key.is_standalone):
|
||||
db.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id)
|
||||
.values(
|
||||
used_usd=User.used_usd + actual_total_cost,
|
||||
total_usd=User.total_usd + actual_total_cost,
|
||||
used_usd=User.used_usd + total_cost,
|
||||
total_usd=User.total_usd + total_cost,
|
||||
updated_at=func.now(),
|
||||
)
|
||||
)
|
||||
|
||||
# 更新API密钥使用量(原子操作)- 使用实际费用(免费套餐为 0)
|
||||
# 更新API密钥使用量(原子操作)- 使用标准计费价格
|
||||
if api_key:
|
||||
# 独立余额Key需要扣除余额
|
||||
if api_key.is_standalone:
|
||||
@@ -725,8 +725,8 @@ class UsageService:
|
||||
.where(ApiKey.id == api_key.id)
|
||||
.values(
|
||||
total_requests=ApiKey.total_requests + 1,
|
||||
total_cost_usd=ApiKey.total_cost_usd + actual_total_cost,
|
||||
balance_used_usd=ApiKey.balance_used_usd + actual_total_cost,
|
||||
total_cost_usd=ApiKey.total_cost_usd + total_cost,
|
||||
balance_used_usd=ApiKey.balance_used_usd + total_cost,
|
||||
last_used_at=func.now(),
|
||||
updated_at=func.now(),
|
||||
)
|
||||
@@ -738,7 +738,7 @@ class UsageService:
|
||||
.where(ApiKey.id == api_key.id)
|
||||
.values(
|
||||
total_requests=ApiKey.total_requests + 1,
|
||||
total_cost_usd=ApiKey.total_cost_usd + actual_total_cost,
|
||||
total_cost_usd=ApiKey.total_cost_usd + total_cost,
|
||||
last_used_at=func.now(),
|
||||
updated_at=func.now(),
|
||||
)
|
||||
@@ -837,8 +837,10 @@ class UsageService:
|
||||
return True, "OK"
|
||||
|
||||
# 有配额限制,检查是否超额
|
||||
if user.used_usd + estimated_cost > user.quota_usd:
|
||||
remaining = user.quota_usd - user.used_usd
|
||||
used_usd = float(user.used_usd or 0)
|
||||
quota_usd = float(user.quota_usd)
|
||||
if used_usd + estimated_cost > quota_usd:
|
||||
remaining = quota_usd - used_usd
|
||||
return False, f"配额不足(剩余: ${remaining:.2f})"
|
||||
|
||||
return True, "OK"
|
||||
@@ -871,7 +873,8 @@ class UsageService:
|
||||
from src.utils.database_helpers import date_trunc_portable
|
||||
|
||||
# 检测数据库方言
|
||||
dialect = db.bind.dialect.name
|
||||
bind = db.bind
|
||||
dialect = bind.dialect.name if bind is not None else "sqlite"
|
||||
|
||||
# 根据分组类型选择日期函数(兼容多种数据库)
|
||||
if group_by == "day":
|
||||
@@ -976,7 +979,7 @@ class UsageService:
|
||||
query = query.group_by(day_bucket).order_by(day_bucket)
|
||||
rows = query.all()
|
||||
|
||||
def normalize_period(value) -> str:
|
||||
def normalize_period(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
@@ -1500,8 +1503,8 @@ class UsageService:
|
||||
users = db.query(User).filter(User.id.in_(group_ids)).all()
|
||||
for user in users:
|
||||
user_info_map[str(user.id)] = {
|
||||
"username": user.username,
|
||||
"email": user.email or "",
|
||||
"username": str(user.username),
|
||||
"email": str(user.email) if user.email else "",
|
||||
}
|
||||
|
||||
# 处理结果
|
||||
@@ -1679,6 +1682,14 @@ class UsageService:
|
||||
|
||||
result = query.first()
|
||||
|
||||
if result is None:
|
||||
total_requests = 0
|
||||
total_input_tokens = 0
|
||||
total_cache_read_tokens = 0
|
||||
total_cache_creation_tokens = 0
|
||||
total_cache_read_cost = 0.0
|
||||
total_cache_creation_cost = 0.0
|
||||
else:
|
||||
total_requests = result.total_requests or 0
|
||||
total_input_tokens = result.total_input_tokens or 0
|
||||
total_cache_read_tokens = result.total_cache_read_tokens or 0
|
||||
@@ -1711,13 +1722,13 @@ class UsageService:
|
||||
requests_with_cache_hit = requests_with_cache_hit.filter(Usage.user_id == user_id)
|
||||
if api_key_id:
|
||||
requests_with_cache_hit = requests_with_cache_hit.filter(Usage.api_key_id == api_key_id)
|
||||
requests_with_cache_hit = requests_with_cache_hit.scalar() or 0
|
||||
requests_with_cache_hit_count = int(requests_with_cache_hit.scalar() or 0)
|
||||
|
||||
return {
|
||||
"analysis_period_hours": hours,
|
||||
"total_requests": total_requests,
|
||||
"requests_with_cache_hit": requests_with_cache_hit,
|
||||
"request_cache_hit_rate": round(requests_with_cache_hit / total_requests * 100, 2) if total_requests > 0 else 0,
|
||||
"requests_with_cache_hit": requests_with_cache_hit_count,
|
||||
"request_cache_hit_rate": round(requests_with_cache_hit_count / total_requests * 100, 2) if total_requests > 0 else 0,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_cache_read_tokens": total_cache_read_tokens,
|
||||
"total_cache_creation_tokens": total_cache_creation_tokens,
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
数据库方言兼容性辅助函数
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.sql.elements import ClauseElement
|
||||
|
||||
|
||||
def date_trunc_portable(dialect_name: str, interval: str, column) -> ClauseElement:
|
||||
def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any:
|
||||
"""
|
||||
跨数据库的日期截断函数
|
||||
|
||||
|
||||
Reference in New Issue
Block a user