mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-10 03:32:26 +08:00
434 lines
14 KiB
Python
434 lines
14 KiB
Python
"""
|
||
用户管理服务
|
||
"""
|
||
|
||
import asyncio
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from sqlalchemy import and_, func
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.core.logger import logger
|
||
from src.core.validators import EmailValidator, PasswordValidator, UsernameValidator
|
||
from src.models.database import ApiKey, GlobalModel, Model, Provider, Usage, User, UserRole
|
||
from src.services.cache.user_cache import UserCacheService
|
||
from src.utils.transaction_manager import retry_on_database_error, transactional
|
||
|
||
|
||
|
||
class UserService:
|
||
"""用户管理服务"""
|
||
|
||
@staticmethod
|
||
@transactional()
|
||
@retry_on_database_error(max_retries=3)
|
||
def create_user(
|
||
db: Session,
|
||
email: str,
|
||
username: str,
|
||
password: str,
|
||
role: UserRole = UserRole.USER,
|
||
quota_usd: Optional[float] = 10.0,
|
||
) -> User:
|
||
"""创建新用户,quota_usd 为 None 表示无限制"""
|
||
|
||
# 验证邮箱格式
|
||
valid, error_msg = EmailValidator.validate(email)
|
||
if not valid:
|
||
raise ValueError(error_msg)
|
||
|
||
# 验证用户名格式
|
||
valid, error_msg = UsernameValidator.validate(username)
|
||
if not valid:
|
||
raise ValueError(error_msg)
|
||
|
||
# 验证密码复杂度
|
||
valid, error_msg = PasswordValidator.validate(password)
|
||
if not valid:
|
||
raise ValueError(error_msg)
|
||
|
||
# 检查邮箱是否已存在
|
||
if db.query(User).filter(User.email == email).first():
|
||
raise ValueError(f"邮箱已存在: {email}")
|
||
|
||
# 检查用户名是否已存在
|
||
if db.query(User).filter(User.username == username).first():
|
||
raise ValueError(f"用户名已存在: {username}")
|
||
|
||
user = User(
|
||
email=email,
|
||
username=username,
|
||
role=role,
|
||
quota_usd=quota_usd,
|
||
is_active=True,
|
||
)
|
||
user.set_password(password)
|
||
|
||
db.add(user)
|
||
db.commit() # 立即提交事务,释放数据库锁
|
||
db.refresh(user)
|
||
|
||
logger.info(f"创建新用户: {email} (ID: {user.id}, 角色: {role.value})")
|
||
return user
|
||
|
||
@staticmethod
|
||
@transactional()
|
||
def create_user_with_api_key(
|
||
db: Session,
|
||
email: str,
|
||
username: str,
|
||
password: str,
|
||
api_key_name: str = "默认密钥",
|
||
role: UserRole = UserRole.USER,
|
||
quota_usd: Optional[float] = 10.0,
|
||
concurrent_limit: int = 5,
|
||
) -> tuple[User, ApiKey]:
|
||
"""
|
||
创建用户并同时创建API密钥(原子操作)
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
email: 邮箱
|
||
username: 用户名
|
||
password: 密码
|
||
api_key_name: API密钥名称
|
||
role: 用户角色
|
||
quota_usd: USD配额,None 表示无限制
|
||
concurrent_limit: 并发限制
|
||
|
||
Returns:
|
||
tuple[User, ApiKey]: 用户对象和API密钥对象
|
||
|
||
Raises:
|
||
ValueError: 当验证失败时
|
||
"""
|
||
# 创建用户
|
||
user = UserService.create_user(
|
||
db=db, email=email, username=username, password=password, role=role, quota_usd=quota_usd
|
||
)
|
||
|
||
# 导入API密钥服务(避免循环导入)
|
||
from .apikey import ApiKeyService
|
||
|
||
# 创建API密钥(返回值是 (api_key, plain_key))
|
||
api_key, plain_key = ApiKeyService.create_api_key(
|
||
db=db, user_id=user.id, name=api_key_name, concurrent_limit=concurrent_limit
|
||
)
|
||
|
||
logger.info(f"创建用户和API密钥完成: {email} (用户ID: {user.id}, 密钥ID: {api_key.id})")
|
||
|
||
# 返回用户对象、API Key对象和明文密钥
|
||
return user, api_key, plain_key
|
||
|
||
@staticmethod
|
||
def get_user(db: Session, user_id: str) -> Optional[User]:
|
||
"""获取用户"""
|
||
import random
|
||
import time
|
||
|
||
# 添加重试机制处理数据库并发问题
|
||
max_retries = 3
|
||
for attempt in range(max_retries):
|
||
try:
|
||
user = db.query(User).filter(User.id == user_id).first()
|
||
return user
|
||
except Exception as e:
|
||
if attempt < max_retries - 1:
|
||
# 添加随机延迟避免并发冲突
|
||
time.sleep(random.uniform(0.01, 0.05))
|
||
db.rollback() # 回滚事务
|
||
continue
|
||
else:
|
||
raise e
|
||
|
||
@staticmethod
|
||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||
"""通过邮箱获取用户"""
|
||
return db.query(User).filter(User.email == email).first()
|
||
|
||
@staticmethod
|
||
def list_users(
|
||
db: Session,
|
||
skip: int = 0,
|
||
limit: int = 100,
|
||
role: Optional[UserRole] = None,
|
||
is_active: Optional[bool] = None,
|
||
) -> List[User]:
|
||
"""列出用户"""
|
||
query = db.query(User)
|
||
|
||
if role:
|
||
query = query.filter(User.role == role)
|
||
if is_active is not None:
|
||
query = query.filter(User.is_active == is_active)
|
||
|
||
return query.offset(skip).limit(limit).all()
|
||
|
||
@staticmethod
|
||
@transactional()
|
||
def update_user(db: Session, user_id: str, **kwargs) -> Optional[User]:
|
||
"""更新用户信息"""
|
||
user = db.query(User).filter(User.id == user_id).first()
|
||
if not user:
|
||
return None
|
||
|
||
# 可更新的字段
|
||
updatable_fields = [
|
||
"email",
|
||
"username",
|
||
"quota_usd",
|
||
"is_active",
|
||
"role",
|
||
# 访问限制字段
|
||
"allowed_providers",
|
||
"allowed_endpoints",
|
||
"allowed_models",
|
||
]
|
||
|
||
# 允许设置为 None 的字段(表示无限制)
|
||
nullable_fields = ["quota_usd", "allowed_providers", "allowed_endpoints", "allowed_models"]
|
||
|
||
for field, value in kwargs.items():
|
||
if field not in updatable_fields:
|
||
continue
|
||
# nullable_fields 中的字段允许设置为 None
|
||
if field in nullable_fields:
|
||
setattr(user, field, value)
|
||
elif value is not None:
|
||
setattr(user, field, value)
|
||
|
||
# 如果提供了新密码
|
||
if "password" in kwargs and kwargs["password"]:
|
||
# 验证新密码复杂度
|
||
valid, error_msg = PasswordValidator.validate(kwargs["password"])
|
||
if not valid:
|
||
raise ValueError(error_msg)
|
||
user.set_password(kwargs["password"])
|
||
|
||
user.updated_at = datetime.now(timezone.utc)
|
||
db.commit() # 立即提交事务,释放数据库锁
|
||
db.refresh(user)
|
||
|
||
# 清除用户缓存
|
||
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
|
||
|
||
logger.debug(f"更新用户信息: {user.email} (ID: {user_id})")
|
||
return user
|
||
|
||
@staticmethod
|
||
@transactional()
|
||
def delete_user(db: Session, user_id: str) -> bool:
|
||
"""删除用户(硬删除)
|
||
|
||
删除流程:
|
||
1. 手动删除关联的子记录(避免 SQLAlchemy ORM 与数据库 CASCADE 冲突)
|
||
2. 删除用户记录
|
||
3. 历史 Usage 记录保留,user_id 会被数据库设为 NULL
|
||
4. 新用户注册时会有新的 UUID,看不到旧用户的记录
|
||
"""
|
||
from src.models.database import (
|
||
AnnouncementRead,
|
||
ApiKey,
|
||
UserPreference,
|
||
UserQuota,
|
||
)
|
||
|
||
user = db.query(User).filter(User.id == user_id).first()
|
||
if not user:
|
||
return False
|
||
|
||
# 记录删除信息用于日志
|
||
email = user.email
|
||
|
||
# 手动删除子记录,避免 SQLAlchemy 的 ORM cascade 与数据库 CASCADE 冲突
|
||
# 这些表的数据库外键已经设置了 ON DELETE CASCADE,但 SQLAlchemy 会先尝试 UPDATE 设置为 NULL
|
||
# 所以我们手动删除来避免这个问题
|
||
db.query(UserPreference).filter(UserPreference.user_id == user_id).delete(
|
||
synchronize_session=False
|
||
)
|
||
db.query(UserQuota).filter(UserQuota.user_id == user_id).delete(synchronize_session=False)
|
||
db.query(AnnouncementRead).filter(AnnouncementRead.user_id == user_id).delete(
|
||
synchronize_session=False
|
||
)
|
||
api_key_count = db.query(ApiKey).filter(ApiKey.user_id == user_id).count()
|
||
db.query(ApiKey).filter(ApiKey.user_id == user_id).delete(synchronize_session=False)
|
||
|
||
# 现在删除用户(Usage, AuditLog, RequestAttempt 会通过数据库 SET NULL 保留)
|
||
db.delete(user)
|
||
db.commit() # 立即提交事务,释放数据库锁
|
||
|
||
# 清除用户缓存
|
||
asyncio.create_task(UserCacheService.invalidate_user_cache(user_id, email))
|
||
|
||
logger.info(f"删除用户: {email} (ID: {user_id}), 同时删除 {api_key_count} 个API密钥")
|
||
return True
|
||
|
||
@staticmethod
|
||
@transactional()
|
||
def change_password(
|
||
db: Session, user_id: str, old_password: str, new_password: str
|
||
) -> tuple[bool, str]:
|
||
"""
|
||
更改用户密码
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
user_id: 用户ID
|
||
old_password: 旧密码
|
||
new_password: 新密码
|
||
|
||
Returns:
|
||
(是否成功, 消息)
|
||
"""
|
||
user = db.query(User).filter(User.id == user_id).first()
|
||
if not user:
|
||
return False, "用户不存在"
|
||
|
||
# 验证旧密码
|
||
if not user.verify_password(old_password):
|
||
logger.warning(f"密码更改失败 - 旧密码错误: 用户ID {user_id}")
|
||
return False, "旧密码错误"
|
||
|
||
# 验证新密码复杂度
|
||
valid, error_msg = PasswordValidator.validate(new_password)
|
||
if not valid:
|
||
return False, error_msg
|
||
|
||
# 检查新密码不能与旧密码相同
|
||
if old_password == new_password:
|
||
return False, "新密码不能与旧密码相同"
|
||
|
||
# 设置新密码
|
||
user.set_password(new_password)
|
||
user.updated_at = datetime.now(timezone.utc)
|
||
|
||
# 清除用户缓存
|
||
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
|
||
|
||
logger.info(f"密码更改成功: 用户ID {user_id}")
|
||
return True, "密码更改成功"
|
||
|
||
@staticmethod
|
||
def update_user_quota(
|
||
db: Session,
|
||
user_id: str,
|
||
quota_usd: Optional[float] = None,
|
||
) -> Optional[User]:
|
||
"""更新用户配额"""
|
||
user = db.query(User).filter(User.id == user_id).first()
|
||
if not user:
|
||
return None
|
||
|
||
if quota_usd is not None:
|
||
user.quota_usd = quota_usd
|
||
|
||
db.commit()
|
||
db.refresh(user)
|
||
|
||
# 清除用户缓存
|
||
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
|
||
|
||
logger.debug(f"更新用户配额: {user.email} (USD: {quota_usd})")
|
||
return user
|
||
|
||
@staticmethod
|
||
def get_user_usage_stats(
|
||
db: Session,
|
||
user_id: str,
|
||
start_date: Optional[datetime] = None,
|
||
end_date: Optional[datetime] = None,
|
||
) -> Dict[str, Any]:
|
||
"""获取用户使用统计"""
|
||
|
||
query = db.query(Usage).filter(Usage.user_id == user_id)
|
||
|
||
if start_date:
|
||
query = query.filter(Usage.created_at >= start_date)
|
||
if end_date:
|
||
query = query.filter(Usage.created_at <= end_date)
|
||
|
||
# 统计数据
|
||
stats = db.query(
|
||
func.count(Usage.id).label("total_requests"),
|
||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
|
||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||
).filter(Usage.user_id == user_id)
|
||
|
||
if start_date:
|
||
stats = stats.filter(Usage.created_at >= start_date)
|
||
if end_date:
|
||
stats = stats.filter(Usage.created_at <= end_date)
|
||
|
||
result = stats.first()
|
||
|
||
# 按模型分组统计
|
||
model_stats = db.query(
|
||
Usage.model,
|
||
func.count(Usage.id).label("requests"),
|
||
func.sum(Usage.total_tokens).label("tokens"),
|
||
func.sum(Usage.total_cost_usd).label("cost_usd"),
|
||
).filter(Usage.user_id == user_id)
|
||
|
||
if start_date:
|
||
model_stats = model_stats.filter(Usage.created_at >= start_date)
|
||
if end_date:
|
||
model_stats = model_stats.filter(Usage.created_at <= end_date)
|
||
|
||
model_stats = model_stats.group_by(Usage.model).all()
|
||
|
||
return {
|
||
"total_requests": result.total_requests or 0,
|
||
"total_tokens": result.total_tokens or 0,
|
||
"total_cost_usd": float(result.total_cost_usd or 0),
|
||
"avg_response_time_ms": float(result.avg_response_time or 0),
|
||
"by_model": [
|
||
{
|
||
"model": stat.model,
|
||
"requests": stat.requests,
|
||
"tokens": stat.tokens,
|
||
"cost_usd": float(stat.cost_usd),
|
||
}
|
||
for stat in model_stats
|
||
],
|
||
}
|
||
|
||
@staticmethod
|
||
def get_user_available_models(db: Session, user: User) -> List[Model]:
|
||
"""获取用户可用的模型
|
||
|
||
新架构:通过 GlobalModel + Model 关联查询用户可用模型
|
||
逻辑:用户可用提供商 → Provider 的 Model 实现 → 关联的 GlobalModel
|
||
"""
|
||
# 获取用户可用的提供商
|
||
if user.role == UserRole.ADMIN:
|
||
# 管理员可以使用所有活动提供商
|
||
provider_ids = [
|
||
p.id for p in db.query(Provider.id).filter(Provider.is_active == True).all()
|
||
]
|
||
else:
|
||
# 普通用户使用关联的提供商
|
||
provider_ids = [p.id for p in user.providers]
|
||
|
||
if not provider_ids:
|
||
return []
|
||
|
||
# 查询这些提供商的所有活跃 Model(关联 GlobalModel)
|
||
models = (
|
||
db.query(Model)
|
||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||
.filter(
|
||
and_(
|
||
Model.provider_id.in_(provider_ids),
|
||
Model.is_active == True,
|
||
GlobalModel.is_active == True,
|
||
)
|
||
)
|
||
.all()
|
||
)
|
||
|
||
logger.debug(f"用户 {user.email} 可用模型: {len(models)} 个 (提供商数: {len(provider_ids)})")
|
||
|
||
return models
|