Files
Aether/src/services/user/service.py
2025-12-10 20:52:44 +08:00

434 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
用户管理服务
"""
import asyncio
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.core.validators import EmailValidator, PasswordValidator, UsernameValidator
from src.models.database import ApiKey, GlobalModel, Model, Provider, Usage, User, UserRole
from src.services.cache.user_cache import UserCacheService
from src.utils.transaction_manager import retry_on_database_error, transactional
class UserService:
"""用户管理服务"""
@staticmethod
@transactional()
@retry_on_database_error(max_retries=3)
def create_user(
db: Session,
email: str,
username: str,
password: str,
role: UserRole = UserRole.USER,
quota_usd: Optional[float] = 10.0,
) -> User:
"""创建新用户quota_usd 为 None 表示无限制"""
# 验证邮箱格式
valid, error_msg = EmailValidator.validate(email)
if not valid:
raise ValueError(error_msg)
# 验证用户名格式
valid, error_msg = UsernameValidator.validate(username)
if not valid:
raise ValueError(error_msg)
# 验证密码复杂度
valid, error_msg = PasswordValidator.validate(password)
if not valid:
raise ValueError(error_msg)
# 检查邮箱是否已存在
if db.query(User).filter(User.email == email).first():
raise ValueError(f"邮箱已存在: {email}")
# 检查用户名是否已存在
if db.query(User).filter(User.username == username).first():
raise ValueError(f"用户名已存在: {username}")
user = User(
email=email,
username=username,
role=role,
quota_usd=quota_usd,
is_active=True,
)
user.set_password(password)
db.add(user)
db.commit() # 立即提交事务,释放数据库锁
db.refresh(user)
logger.info(f"创建新用户: {email} (ID: {user.id}, 角色: {role.value})")
return user
@staticmethod
@transactional()
def create_user_with_api_key(
db: Session,
email: str,
username: str,
password: str,
api_key_name: str = "默认密钥",
role: UserRole = UserRole.USER,
quota_usd: Optional[float] = 10.0,
concurrent_limit: int = 5,
) -> tuple[User, ApiKey]:
"""
创建用户并同时创建API密钥原子操作
Args:
db: 数据库会话
email: 邮箱
username: 用户名
password: 密码
api_key_name: API密钥名称
role: 用户角色
quota_usd: USD配额None 表示无限制
concurrent_limit: 并发限制
Returns:
tuple[User, ApiKey]: 用户对象和API密钥对象
Raises:
ValueError: 当验证失败时
"""
# 创建用户
user = UserService.create_user(
db=db, email=email, username=username, password=password, role=role, quota_usd=quota_usd
)
# 导入API密钥服务避免循环导入
from .apikey import ApiKeyService
# 创建API密钥返回值是 (api_key, plain_key)
api_key, plain_key = ApiKeyService.create_api_key(
db=db, user_id=user.id, name=api_key_name, concurrent_limit=concurrent_limit
)
logger.info(f"创建用户和API密钥完成: {email} (用户ID: {user.id}, 密钥ID: {api_key.id})")
# 返回用户对象、API Key对象和明文密钥
return user, api_key, plain_key
@staticmethod
def get_user(db: Session, user_id: str) -> Optional[User]:
"""获取用户"""
import random
import time
# 添加重试机制处理数据库并发问题
max_retries = 3
for attempt in range(max_retries):
try:
user = db.query(User).filter(User.id == user_id).first()
return user
except Exception as e:
if attempt < max_retries - 1:
# 添加随机延迟避免并发冲突
time.sleep(random.uniform(0.01, 0.05))
db.rollback() # 回滚事务
continue
else:
raise e
@staticmethod
def get_user_by_email(db: Session, email: str) -> Optional[User]:
"""通过邮箱获取用户"""
return db.query(User).filter(User.email == email).first()
@staticmethod
def list_users(
db: Session,
skip: int = 0,
limit: int = 100,
role: Optional[UserRole] = None,
is_active: Optional[bool] = None,
) -> List[User]:
"""列出用户"""
query = db.query(User)
if role:
query = query.filter(User.role == role)
if is_active is not None:
query = query.filter(User.is_active == is_active)
return query.offset(skip).limit(limit).all()
@staticmethod
@transactional()
def update_user(db: Session, user_id: str, **kwargs) -> Optional[User]:
"""更新用户信息"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
return None
# 可更新的字段
updatable_fields = [
"email",
"username",
"quota_usd",
"is_active",
"role",
# 访问限制字段
"allowed_providers",
"allowed_endpoints",
"allowed_models",
]
# 允许设置为 None 的字段(表示无限制)
nullable_fields = ["quota_usd", "allowed_providers", "allowed_endpoints", "allowed_models"]
for field, value in kwargs.items():
if field not in updatable_fields:
continue
# nullable_fields 中的字段允许设置为 None
if field in nullable_fields:
setattr(user, field, value)
elif value is not None:
setattr(user, field, value)
# 如果提供了新密码
if "password" in kwargs and kwargs["password"]:
# 验证新密码复杂度
valid, error_msg = PasswordValidator.validate(kwargs["password"])
if not valid:
raise ValueError(error_msg)
user.set_password(kwargs["password"])
user.updated_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁
db.refresh(user)
# 清除用户缓存
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
logger.debug(f"更新用户信息: {user.email} (ID: {user_id})")
return user
@staticmethod
@transactional()
def delete_user(db: Session, user_id: str) -> bool:
"""删除用户(硬删除)
删除流程:
1. 手动删除关联的子记录(避免 SQLAlchemy ORM 与数据库 CASCADE 冲突)
2. 删除用户记录
3. 历史 Usage 记录保留user_id 会被数据库设为 NULL
4. 新用户注册时会有新的 UUID看不到旧用户的记录
"""
from src.models.database import (
AnnouncementRead,
ApiKey,
UserPreference,
UserQuota,
)
user = db.query(User).filter(User.id == user_id).first()
if not user:
return False
# 记录删除信息用于日志
email = user.email
# 手动删除子记录,避免 SQLAlchemy 的 ORM cascade 与数据库 CASCADE 冲突
# 这些表的数据库外键已经设置了 ON DELETE CASCADE但 SQLAlchemy 会先尝试 UPDATE 设置为 NULL
# 所以我们手动删除来避免这个问题
db.query(UserPreference).filter(UserPreference.user_id == user_id).delete(
synchronize_session=False
)
db.query(UserQuota).filter(UserQuota.user_id == user_id).delete(synchronize_session=False)
db.query(AnnouncementRead).filter(AnnouncementRead.user_id == user_id).delete(
synchronize_session=False
)
api_key_count = db.query(ApiKey).filter(ApiKey.user_id == user_id).count()
db.query(ApiKey).filter(ApiKey.user_id == user_id).delete(synchronize_session=False)
# 现在删除用户Usage, AuditLog, RequestAttempt 会通过数据库 SET NULL 保留)
db.delete(user)
db.commit() # 立即提交事务,释放数据库锁
# 清除用户缓存
asyncio.create_task(UserCacheService.invalidate_user_cache(user_id, email))
logger.info(f"删除用户: {email} (ID: {user_id}), 同时删除 {api_key_count} 个API密钥")
return True
@staticmethod
@transactional()
def change_password(
db: Session, user_id: str, old_password: str, new_password: str
) -> tuple[bool, str]:
"""
更改用户密码
Args:
db: 数据库会话
user_id: 用户ID
old_password: 旧密码
new_password: 新密码
Returns:
(是否成功, 消息)
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
return False, "用户不存在"
# 验证旧密码
if not user.verify_password(old_password):
logger.warning(f"密码更改失败 - 旧密码错误: 用户ID {user_id}")
return False, "旧密码错误"
# 验证新密码复杂度
valid, error_msg = PasswordValidator.validate(new_password)
if not valid:
return False, error_msg
# 检查新密码不能与旧密码相同
if old_password == new_password:
return False, "新密码不能与旧密码相同"
# 设置新密码
user.set_password(new_password)
user.updated_at = datetime.now(timezone.utc)
# 清除用户缓存
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
logger.info(f"密码更改成功: 用户ID {user_id}")
return True, "密码更改成功"
@staticmethod
def update_user_quota(
db: Session,
user_id: str,
quota_usd: Optional[float] = None,
) -> Optional[User]:
"""更新用户配额"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
return None
if quota_usd is not None:
user.quota_usd = quota_usd
db.commit()
db.refresh(user)
# 清除用户缓存
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
logger.debug(f"更新用户配额: {user.email} (USD: {quota_usd})")
return user
@staticmethod
def get_user_usage_stats(
db: Session,
user_id: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> Dict[str, Any]:
"""获取用户使用统计"""
query = db.query(Usage).filter(Usage.user_id == user_id)
if start_date:
query = query.filter(Usage.created_at >= start_date)
if end_date:
query = query.filter(Usage.created_at <= end_date)
# 统计数据
stats = db.query(
func.count(Usage.id).label("total_requests"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
func.avg(Usage.response_time_ms).label("avg_response_time"),
).filter(Usage.user_id == user_id)
if start_date:
stats = stats.filter(Usage.created_at >= start_date)
if end_date:
stats = stats.filter(Usage.created_at <= end_date)
result = stats.first()
# 按模型分组统计
model_stats = db.query(
Usage.model,
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("tokens"),
func.sum(Usage.total_cost_usd).label("cost_usd"),
).filter(Usage.user_id == user_id)
if start_date:
model_stats = model_stats.filter(Usage.created_at >= start_date)
if end_date:
model_stats = model_stats.filter(Usage.created_at <= end_date)
model_stats = model_stats.group_by(Usage.model).all()
return {
"total_requests": result.total_requests or 0,
"total_tokens": result.total_tokens or 0,
"total_cost_usd": float(result.total_cost_usd or 0),
"avg_response_time_ms": float(result.avg_response_time or 0),
"by_model": [
{
"model": stat.model,
"requests": stat.requests,
"tokens": stat.tokens,
"cost_usd": float(stat.cost_usd),
}
for stat in model_stats
],
}
@staticmethod
def get_user_available_models(db: Session, user: User) -> List[Model]:
"""获取用户可用的模型
新架构:通过 GlobalModel + Model 关联查询用户可用模型
逻辑:用户可用提供商 → Provider 的 Model 实现 → 关联的 GlobalModel
"""
# 获取用户可用的提供商
if user.role == UserRole.ADMIN:
# 管理员可以使用所有活动提供商
provider_ids = [
p.id for p in db.query(Provider.id).filter(Provider.is_active == True).all()
]
else:
# 普通用户使用关联的提供商
provider_ids = [p.id for p in user.providers]
if not provider_ids:
return []
# 查询这些提供商的所有活跃 Model关联 GlobalModel
models = (
db.query(Model)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
and_(
Model.provider_id.in_(provider_ids),
Model.is_active == True,
GlobalModel.is_active == True,
)
)
.all()
)
logger.debug(f"用户 {user.email} 可用模型: {len(models)} 个 (提供商数: {len(provider_ids)})")
return models