mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 19:22:26 +08:00
Initial commit
This commit is contained in:
15
src/services/user/__init__.py
Normal file
15
src/services/user/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
用户服务模块
|
||||
|
||||
包含用户管理、API Key 管理等功能。
|
||||
"""
|
||||
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
from src.services.user.preference import PreferenceService
|
||||
from src.services.user.service import UserService
|
||||
|
||||
__all__ = [
|
||||
"UserService",
|
||||
"ApiKeyService",
|
||||
"PreferenceService",
|
||||
]
|
||||
393
src/services/user/apikey.py
Normal file
393
src/services/user/apikey.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
API密钥管理服务
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, Usage, User
|
||||
|
||||
|
||||
|
||||
class ApiKeyService:
|
||||
"""API密钥管理服务"""
|
||||
|
||||
@staticmethod
|
||||
def create_api_key(
|
||||
db: Session,
|
||||
user_id: str, # UUID
|
||||
name: Optional[str] = None,
|
||||
allowed_providers: Optional[List[str]] = None,
|
||||
allowed_api_formats: Optional[List[str]] = None,
|
||||
allowed_models: Optional[List[str]] = None,
|
||||
rate_limit: int = 100,
|
||||
concurrent_limit: int = 5,
|
||||
expire_days: Optional[int] = None,
|
||||
initial_balance_usd: Optional[float] = None,
|
||||
is_standalone: bool = False,
|
||||
auto_delete_on_expiry: bool = False,
|
||||
) -> tuple[ApiKey, str]:
|
||||
"""创建新的API密钥,返回密钥对象和明文密钥
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
name: 密钥名称
|
||||
allowed_providers: 允许的提供商列表
|
||||
allowed_api_formats: 允许的 API 格式列表
|
||||
allowed_models: 允许的模型列表
|
||||
rate_limit: 速率限制
|
||||
concurrent_limit: 并发限制
|
||||
expire_days: 过期天数,None = 永不过期
|
||||
initial_balance_usd: 初始余额(USD),仅用于独立Key,None = 无限制
|
||||
is_standalone: 是否为独立余额Key(仅管理员可创建)
|
||||
auto_delete_on_expiry: 过期后是否自动删除(True=物理删除,False=仅禁用)
|
||||
"""
|
||||
|
||||
# 生成密钥
|
||||
key = ApiKey.generate_key()
|
||||
key_hash = ApiKey.hash_key(key)
|
||||
key_encrypted = crypto_service.encrypt(key) # 加密存储密钥
|
||||
|
||||
# 计算过期时间
|
||||
expires_at = None
|
||||
if expire_days:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
||||
|
||||
api_key = ApiKey(
|
||||
user_id=user_id,
|
||||
key_hash=key_hash,
|
||||
key_encrypted=key_encrypted,
|
||||
name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
|
||||
allowed_providers=allowed_providers,
|
||||
allowed_api_formats=allowed_api_formats,
|
||||
allowed_models=allowed_models,
|
||||
rate_limit=rate_limit,
|
||||
concurrent_limit=concurrent_limit,
|
||||
expires_at=expires_at,
|
||||
balance_used_usd=0.0,
|
||||
current_balance_usd=initial_balance_usd, # 直接使用初始余额,None = 无限制
|
||||
is_standalone=is_standalone,
|
||||
auto_delete_on_expiry=auto_delete_on_expiry,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
db.add(api_key)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
logger.info(f"创建API密钥: 用户ID {user_id}, 密钥名 {api_key.name}, "
|
||||
f"独立Key={is_standalone}, 初始余额={initial_balance_usd}")
|
||||
return api_key, key # 返回密钥对象和明文密钥
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(db: Session, key_id: str) -> Optional[ApiKey]: # UUID
|
||||
"""获取API密钥"""
|
||||
return db.query(ApiKey).filter(ApiKey.id == key_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_by_key(db: Session, key: str) -> Optional[ApiKey]:
|
||||
"""通过密钥字符串获取API密钥"""
|
||||
key_hash = ApiKey.hash_key(key)
|
||||
return db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
|
||||
|
||||
@staticmethod
|
||||
def list_user_api_keys(
|
||||
db: Session, user_id: str, is_active: Optional[bool] = None # UUID
|
||||
) -> List[ApiKey]:
|
||||
"""列出用户的所有API密钥(不包括独立Key)"""
|
||||
query = db.query(ApiKey).filter(
|
||||
ApiKey.user_id == user_id, ApiKey.is_standalone == False # 排除独立Key
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(ApiKey.is_active == is_active)
|
||||
|
||||
return query.order_by(ApiKey.created_at.desc()).all()
|
||||
|
||||
@staticmethod
|
||||
def list_standalone_api_keys(db: Session, is_active: Optional[bool] = None) -> List[ApiKey]:
|
||||
"""列出所有独立余额Key(仅管理员可用)"""
|
||||
query = db.query(ApiKey).filter(ApiKey.is_standalone == True)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(ApiKey.is_active == is_active)
|
||||
|
||||
return query.order_by(ApiKey.created_at.desc()).all()
|
||||
|
||||
@staticmethod
|
||||
def update_api_key(db: Session, key_id: str, **kwargs) -> Optional[ApiKey]: # UUID
|
||||
"""更新API密钥"""
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
# 可更新的字段
|
||||
updatable_fields = [
|
||||
"name",
|
||||
"allowed_providers",
|
||||
"allowed_api_formats",
|
||||
"allowed_models",
|
||||
"rate_limit",
|
||||
"concurrent_limit",
|
||||
"is_active",
|
||||
"expires_at",
|
||||
"balance_limit_usd",
|
||||
"auto_delete_on_expiry",
|
||||
]
|
||||
|
||||
for field, value in kwargs.items():
|
||||
if field in updatable_fields and value is not None:
|
||||
setattr(api_key, field, value)
|
||||
|
||||
api_key.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
logger.debug(f"更新API密钥: ID {key_id}")
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def delete_api_key(db: Session, key_id: str) -> bool: # UUID
|
||||
"""删除API密钥(禁用)"""
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
api_key.is_active = False
|
||||
api_key.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"删除API密钥: ID {key_id}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_remaining_balance(api_key: ApiKey) -> Optional[float]:
|
||||
"""计算剩余余额(仅用于独立Key)
|
||||
|
||||
Returns:
|
||||
剩余余额,None 表示无限制或非独立Key
|
||||
"""
|
||||
if not api_key.is_standalone:
|
||||
return None
|
||||
|
||||
if api_key.current_balance_usd is None:
|
||||
return None
|
||||
|
||||
# 剩余余额 = 当前余额 - 已使用余额
|
||||
remaining = api_key.current_balance_usd - (api_key.balance_used_usd or 0)
|
||||
return max(0, remaining) # 不能为负数
|
||||
|
||||
@staticmethod
|
||||
def check_balance(api_key: ApiKey) -> tuple[bool, Optional[float]]:
|
||||
"""检查余额限制(仅用于独立Key)
|
||||
|
||||
Returns:
|
||||
(is_allowed, remaining_balance): 是否允许请求,剩余余额(None表示无限制)
|
||||
"""
|
||||
if not api_key.is_standalone:
|
||||
# 非独立Key不检查余额
|
||||
return True, None
|
||||
|
||||
# 使用新的预付费模式: current_balance_usd
|
||||
if api_key.current_balance_usd is None:
|
||||
# 无余额限制
|
||||
return True, None
|
||||
|
||||
# 使用统一的余额计算方法
|
||||
remaining = ApiKeyService.get_remaining_balance(api_key)
|
||||
is_allowed = remaining > 0 if remaining is not None else True
|
||||
|
||||
if not is_allowed:
|
||||
logger.warning(f"API密钥余额不足: Key ID {api_key.id}, " f"剩余余额 ${remaining:.4f}")
|
||||
|
||||
return is_allowed, remaining
|
||||
|
||||
@staticmethod
|
||||
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
||||
"""检查速率限制"""
|
||||
|
||||
# 计算时间窗口
|
||||
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
||||
|
||||
# 统计窗口内的请求数
|
||||
request_count = (
|
||||
db.query(func.count(Usage.id))
|
||||
.filter(Usage.api_key_id == api_key.id, Usage.created_at >= window_start)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# 检查是否超限
|
||||
is_allowed = request_count < api_key.rate_limit
|
||||
|
||||
if not is_allowed:
|
||||
logger.warning(f"API密钥速率限制: Key ID {api_key.id}, 请求数 {request_count}/{api_key.rate_limit}")
|
||||
|
||||
return is_allowed, api_key.rate_limit - request_count
|
||||
|
||||
@staticmethod
|
||||
def add_balance(db: Session, key_id: str, amount_usd: float) -> Optional[ApiKey]:
|
||||
"""为独立余额Key调整余额
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
key_id: API Key ID
|
||||
amount_usd: 要调整的余额金额(USD),正数为增加,负数为扣除
|
||||
|
||||
Returns:
|
||||
更新后的API Key对象,如果Key不存在或不是独立Key则返回None
|
||||
"""
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
|
||||
if not api_key:
|
||||
logger.warning(f"余额调整失败: Key ID {key_id} 不存在")
|
||||
return None
|
||||
|
||||
if not api_key.is_standalone:
|
||||
logger.warning(f"余额调整失败: Key ID {key_id} 不是独立余额Key")
|
||||
return None
|
||||
|
||||
if amount_usd == 0:
|
||||
logger.warning(f"余额调整失败: 调整金额不能为0,当前值 ${amount_usd}")
|
||||
return None
|
||||
|
||||
# 如果是扣除(负数),检查是否超过当前余额
|
||||
if amount_usd < 0:
|
||||
current = api_key.current_balance_usd or 0
|
||||
if abs(amount_usd) > current:
|
||||
logger.warning(f"余额扣除失败: 扣除金额 ${abs(amount_usd):.4f} 超过当前余额 ${current:.4f}")
|
||||
return None
|
||||
|
||||
# 调整当前余额
|
||||
if api_key.current_balance_usd is None:
|
||||
api_key.current_balance_usd = amount_usd if amount_usd > 0 else 0
|
||||
else:
|
||||
api_key.current_balance_usd = max(0, api_key.current_balance_usd + amount_usd)
|
||||
|
||||
api_key.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
action = "增加" if amount_usd > 0 else "扣除"
|
||||
logger.info(f"余额调整成功: Key ID {key_id}, {action} ${abs(amount_usd):.4f}, "
|
||||
f"新余额 ${api_key.current_balance_usd:.4f}")
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def cleanup_expired_keys(db: Session, auto_delete: bool = False) -> int:
|
||||
"""清理过期的API密钥
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
auto_delete: 全局默认行为(True=物理删除,False=仅禁用)
|
||||
单个Key的 auto_delete_on_expiry 字段会覆盖此设置
|
||||
|
||||
Returns:
|
||||
int: 清理的密钥数量
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expired_keys = (
|
||||
db.query(ApiKey)
|
||||
.filter(ApiKey.expires_at <= now, ApiKey.is_active == True) # 只处理仍然活跃的
|
||||
.all()
|
||||
)
|
||||
|
||||
count = 0
|
||||
for api_key in expired_keys:
|
||||
# 优先使用Key自身的auto_delete_on_expiry设置,否则使用全局设置
|
||||
should_delete = (
|
||||
api_key.auto_delete_on_expiry
|
||||
if api_key.auto_delete_on_expiry is not None
|
||||
else auto_delete
|
||||
)
|
||||
|
||||
if should_delete:
|
||||
# 物理删除(Usage记录会保留,因为是 SET NULL)
|
||||
db.delete(api_key)
|
||||
logger.info(f"删除过期API密钥: ID {api_key.id}, 名称 {api_key.name}, "
|
||||
f"过期时间 {api_key.expires_at}")
|
||||
else:
|
||||
# 仅禁用
|
||||
api_key.is_active = False
|
||||
api_key.updated_at = now
|
||||
logger.info(f"禁用过期API密钥: ID {api_key.id}, 名称 {api_key.name}, "
|
||||
f"过期时间 {api_key.expires_at}")
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
db.commit()
|
||||
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_stats(
|
||||
db: Session,
|
||||
key_id: str, # UUID
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""获取API密钥使用统计"""
|
||||
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
|
||||
if not api_key:
|
||||
return {}
|
||||
|
||||
query = db.query(Usage).filter(Usage.api_key_id == key_id)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(Usage.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(Usage.created_at <= end_date)
|
||||
|
||||
# 统计数据
|
||||
stats = db.query(
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost_usd"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
).filter(Usage.api_key_id == key_id)
|
||||
|
||||
if start_date:
|
||||
stats = stats.filter(Usage.created_at >= start_date)
|
||||
if end_date:
|
||||
stats = stats.filter(Usage.created_at <= end_date)
|
||||
|
||||
result = stats.first()
|
||||
|
||||
# 按天统计
|
||||
daily_stats = db.query(
|
||||
func.date(Usage.created_at).label("date"),
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost_usd"),
|
||||
).filter(Usage.api_key_id == key_id)
|
||||
|
||||
if start_date:
|
||||
daily_stats = daily_stats.filter(Usage.created_at >= start_date)
|
||||
if end_date:
|
||||
daily_stats = daily_stats.filter(Usage.created_at <= end_date)
|
||||
|
||||
daily_stats = daily_stats.group_by(func.date(Usage.created_at)).all()
|
||||
|
||||
return {
|
||||
"key_id": key_id,
|
||||
"key_name": api_key.name,
|
||||
"total_requests": result.requests or 0,
|
||||
"total_tokens": result.tokens or 0,
|
||||
"total_cost_usd": float(result.cost_usd or 0),
|
||||
"avg_response_time_ms": float(result.avg_response_time or 0),
|
||||
"daily_stats": [
|
||||
{
|
||||
"date": stat.date.isoformat() if stat.date else None,
|
||||
"requests": stat.requests,
|
||||
"tokens": stat.tokens,
|
||||
"cost_usd": float(stat.cost_usd),
|
||||
}
|
||||
for stat in daily_stats
|
||||
],
|
||||
}
|
||||
137
src/services/user/preference.py
Normal file
137
src/services/user/preference.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
用户偏好设置服务
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, User, UserPreference
|
||||
|
||||
|
||||
|
||||
class PreferenceService:
|
||||
"""用户偏好设置服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_or_create_preferences(db: Session, user_id: str) -> UserPreference: # UUID
|
||||
"""获取或创建用户偏好设置"""
|
||||
preferences = db.query(UserPreference).filter(UserPreference.user_id == user_id).first()
|
||||
|
||||
if not preferences:
|
||||
# 创建默认偏好设置
|
||||
preferences = UserPreference(
|
||||
user_id=user_id,
|
||||
theme="light",
|
||||
language="zh-CN",
|
||||
timezone="Asia/Shanghai",
|
||||
email_notifications=True,
|
||||
usage_alerts=True,
|
||||
announcement_notifications=True,
|
||||
)
|
||||
db.add(preferences)
|
||||
db.commit()
|
||||
db.refresh(preferences)
|
||||
logger.info(f"Created default preferences for user {user_id}")
|
||||
|
||||
return preferences
|
||||
|
||||
@staticmethod
|
||||
def update_preferences(
|
||||
db: Session,
|
||||
user_id: str, # UUID
|
||||
avatar_url: Optional[str] = None,
|
||||
bio: Optional[str] = None,
|
||||
default_provider_id: Optional[str] = None, # UUID
|
||||
theme: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
timezone: Optional[str] = None,
|
||||
email_notifications: Optional[bool] = None,
|
||||
usage_alerts: Optional[bool] = None,
|
||||
announcement_notifications: Optional[bool] = None,
|
||||
) -> UserPreference:
|
||||
"""更新用户偏好设置"""
|
||||
preferences = PreferenceService.get_or_create_preferences(db, user_id)
|
||||
|
||||
# 更新提供的字段
|
||||
if avatar_url is not None:
|
||||
preferences.avatar_url = avatar_url
|
||||
if bio is not None:
|
||||
preferences.bio = bio
|
||||
if default_provider_id is not None:
|
||||
# 验证提供商是否存在且活跃
|
||||
provider = (
|
||||
db.query(Provider)
|
||||
.filter(Provider.id == default_provider_id, Provider.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found or inactive")
|
||||
preferences.default_provider_id = default_provider_id
|
||||
if theme is not None:
|
||||
if theme not in ["light", "dark", "auto"]:
|
||||
raise ValueError("Invalid theme. Must be 'light', 'dark', or 'auto'")
|
||||
preferences.theme = theme
|
||||
if language is not None:
|
||||
preferences.language = language
|
||||
if timezone is not None:
|
||||
preferences.timezone = timezone
|
||||
if email_notifications is not None:
|
||||
preferences.email_notifications = email_notifications
|
||||
if usage_alerts is not None:
|
||||
preferences.usage_alerts = usage_alerts
|
||||
if announcement_notifications is not None:
|
||||
preferences.announcement_notifications = announcement_notifications
|
||||
|
||||
db.commit()
|
||||
db.refresh(preferences)
|
||||
logger.info(f"Updated preferences for user {user_id}")
|
||||
|
||||
return preferences
|
||||
|
||||
@staticmethod
|
||||
def get_user_with_preferences(db: Session, user_id: str) -> dict: # UUID
|
||||
"""获取用户信息及其偏好设置"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise NotFoundException("User not found")
|
||||
|
||||
preferences = PreferenceService.get_or_create_preferences(db, user_id)
|
||||
|
||||
# 构建返回数据
|
||||
user_data = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"is_active": user.is_active,
|
||||
"created_at": user.created_at,
|
||||
"last_login_at": user.last_login_at,
|
||||
"preferences": {
|
||||
"avatar_url": preferences.avatar_url,
|
||||
"bio": preferences.bio,
|
||||
"default_provider": (
|
||||
preferences.default_provider.name if preferences.default_provider else None
|
||||
),
|
||||
"theme": preferences.theme,
|
||||
"language": preferences.language,
|
||||
"timezone": preferences.timezone,
|
||||
"notifications": {
|
||||
"email": preferences.email_notifications,
|
||||
"usage_alerts": preferences.usage_alerts,
|
||||
"announcements": preferences.announcement_notifications,
|
||||
},
|
||||
},
|
||||
# 配额信息
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
"stats": {
|
||||
"total_cost": user.used_usd,
|
||||
"total_cost_all_time": user.total_usd,
|
||||
"api_keys_count": len(user.api_keys),
|
||||
},
|
||||
}
|
||||
|
||||
return user_data
|
||||
433
src/services/user/service.py
Normal file
433
src/services/user/service.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
用户管理服务
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.core.validators import EmailValidator, PasswordValidator, UsernameValidator
|
||||
from src.models.database import ApiKey, GlobalModel, Model, Provider, Usage, User, UserRole
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.utils.transaction_manager import retry_on_database_error, transactional
|
||||
|
||||
|
||||
|
||||
class UserService:
|
||||
"""用户管理服务"""
|
||||
|
||||
@staticmethod
|
||||
@transactional()
|
||||
@retry_on_database_error(max_retries=3)
|
||||
def create_user(
|
||||
db: Session,
|
||||
email: str,
|
||||
username: str,
|
||||
password: str,
|
||||
role: UserRole = UserRole.USER,
|
||||
quota_usd: Optional[float] = 10.0,
|
||||
) -> User:
|
||||
"""创建新用户,quota_usd 为 None 表示无限制"""
|
||||
|
||||
# 验证邮箱格式
|
||||
valid, error_msg = EmailValidator.validate(email)
|
||||
if not valid:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# 验证用户名格式
|
||||
valid, error_msg = UsernameValidator.validate(username)
|
||||
if not valid:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# 验证密码复杂度
|
||||
valid, error_msg = PasswordValidator.validate(password)
|
||||
if not valid:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if db.query(User).filter(User.email == email).first():
|
||||
raise ValueError(f"邮箱已存在: {email}")
|
||||
|
||||
# 检查用户名是否已存在
|
||||
if db.query(User).filter(User.username == username).first():
|
||||
raise ValueError(f"用户名已存在: {username}")
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
role=role,
|
||||
quota_usd=quota_usd,
|
||||
is_active=True,
|
||||
)
|
||||
user.set_password(password)
|
||||
|
||||
db.add(user)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
db.refresh(user)
|
||||
|
||||
logger.info(f"创建新用户: {email} (ID: {user.id}, 角色: {role.value})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@transactional()
|
||||
def create_user_with_api_key(
|
||||
db: Session,
|
||||
email: str,
|
||||
username: str,
|
||||
password: str,
|
||||
api_key_name: str = "默认密钥",
|
||||
role: UserRole = UserRole.USER,
|
||||
quota_usd: Optional[float] = 10.0,
|
||||
concurrent_limit: int = 5,
|
||||
) -> tuple[User, ApiKey]:
|
||||
"""
|
||||
创建用户并同时创建API密钥(原子操作)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
email: 邮箱
|
||||
username: 用户名
|
||||
password: 密码
|
||||
api_key_name: API密钥名称
|
||||
role: 用户角色
|
||||
quota_usd: USD配额,None 表示无限制
|
||||
concurrent_limit: 并发限制
|
||||
|
||||
Returns:
|
||||
tuple[User, ApiKey]: 用户对象和API密钥对象
|
||||
|
||||
Raises:
|
||||
ValueError: 当验证失败时
|
||||
"""
|
||||
# 创建用户
|
||||
user = UserService.create_user(
|
||||
db=db, email=email, username=username, password=password, role=role, quota_usd=quota_usd
|
||||
)
|
||||
|
||||
# 导入API密钥服务(避免循环导入)
|
||||
from .apikey import ApiKeyService
|
||||
|
||||
# 创建API密钥(返回值是 (api_key, plain_key))
|
||||
api_key, plain_key = ApiKeyService.create_api_key(
|
||||
db=db, user_id=user.id, name=api_key_name, concurrent_limit=concurrent_limit
|
||||
)
|
||||
|
||||
logger.info(f"创建用户和API密钥完成: {email} (用户ID: {user.id}, 密钥ID: {api_key.id})")
|
||||
|
||||
# 返回用户对象、API Key对象和明文密钥
|
||||
return user, api_key, plain_key
|
||||
|
||||
@staticmethod
|
||||
def get_user(db: Session, user_id: str) -> Optional[User]:
|
||||
"""获取用户"""
|
||||
import random
|
||||
import time
|
||||
|
||||
# 添加重试机制处理数据库并发问题
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
return user
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
# 添加随机延迟避免并发冲突
|
||||
time.sleep(random.uniform(0.01, 0.05))
|
||||
db.rollback() # 回滚事务
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||
"""通过邮箱获取用户"""
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
@staticmethod
|
||||
def list_users(
|
||||
db: Session,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
role: Optional[UserRole] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
) -> List[User]:
|
||||
"""列出用户"""
|
||||
query = db.query(User)
|
||||
|
||||
if role:
|
||||
query = query.filter(User.role == role)
|
||||
if is_active is not None:
|
||||
query = query.filter(User.is_active == is_active)
|
||||
|
||||
return query.offset(skip).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
@transactional()
|
||||
def update_user(db: Session, user_id: str, **kwargs) -> Optional[User]:
|
||||
"""更新用户信息"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# 可更新的字段
|
||||
updatable_fields = [
|
||||
"email",
|
||||
"username",
|
||||
"quota_usd",
|
||||
"is_active",
|
||||
"role",
|
||||
# 访问限制字段
|
||||
"allowed_providers",
|
||||
"allowed_endpoints",
|
||||
"allowed_models",
|
||||
]
|
||||
|
||||
# 允许设置为 None 的字段(表示无限制)
|
||||
nullable_fields = ["quota_usd", "allowed_providers", "allowed_endpoints", "allowed_models"]
|
||||
|
||||
for field, value in kwargs.items():
|
||||
if field not in updatable_fields:
|
||||
continue
|
||||
# nullable_fields 中的字段允许设置为 None
|
||||
if field in nullable_fields:
|
||||
setattr(user, field, value)
|
||||
elif value is not None:
|
||||
setattr(user, field, value)
|
||||
|
||||
# 如果提供了新密码
|
||||
if "password" in kwargs and kwargs["password"]:
|
||||
# 验证新密码复杂度
|
||||
valid, error_msg = PasswordValidator.validate(kwargs["password"])
|
||||
if not valid:
|
||||
raise ValueError(error_msg)
|
||||
user.set_password(kwargs["password"])
|
||||
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
db.refresh(user)
|
||||
|
||||
# 清除用户缓存
|
||||
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
|
||||
|
||||
logger.debug(f"更新用户信息: {user.email} (ID: {user_id})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@transactional()
|
||||
def delete_user(db: Session, user_id: str) -> bool:
|
||||
"""删除用户(硬删除)
|
||||
|
||||
删除流程:
|
||||
1. 手动删除关联的子记录(避免 SQLAlchemy ORM 与数据库 CASCADE 冲突)
|
||||
2. 删除用户记录
|
||||
3. 历史 Usage 记录保留,user_id 会被数据库设为 NULL
|
||||
4. 新用户注册时会有新的 UUID,看不到旧用户的记录
|
||||
"""
|
||||
from src.models.database import (
|
||||
AnnouncementRead,
|
||||
ApiKey,
|
||||
UserPreference,
|
||||
UserQuota,
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
return False
|
||||
|
||||
# 记录删除信息用于日志
|
||||
email = user.email
|
||||
|
||||
# 手动删除子记录,避免 SQLAlchemy 的 ORM cascade 与数据库 CASCADE 冲突
|
||||
# 这些表的数据库外键已经设置了 ON DELETE CASCADE,但 SQLAlchemy 会先尝试 UPDATE 设置为 NULL
|
||||
# 所以我们手动删除来避免这个问题
|
||||
db.query(UserPreference).filter(UserPreference.user_id == user_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
db.query(UserQuota).filter(UserQuota.user_id == user_id).delete(synchronize_session=False)
|
||||
db.query(AnnouncementRead).filter(AnnouncementRead.user_id == user_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
api_key_count = db.query(ApiKey).filter(ApiKey.user_id == user_id).count()
|
||||
db.query(ApiKey).filter(ApiKey.user_id == user_id).delete(synchronize_session=False)
|
||||
|
||||
# 现在删除用户(Usage, AuditLog, RequestAttempt 会通过数据库 SET NULL 保留)
|
||||
db.delete(user)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
|
||||
# 清除用户缓存
|
||||
asyncio.create_task(UserCacheService.invalidate_user_cache(user_id, email))
|
||||
|
||||
logger.info(f"删除用户: {email} (ID: {user_id}), 同时删除 {api_key_count} 个API密钥")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@transactional()
|
||||
def change_password(
|
||||
db: Session, user_id: str, old_password: str, new_password: str
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
更改用户密码
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
old_password: 旧密码
|
||||
new_password: 新密码
|
||||
|
||||
Returns:
|
||||
(是否成功, 消息)
|
||||
"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
return False, "用户不存在"
|
||||
|
||||
# 验证旧密码
|
||||
if not user.verify_password(old_password):
|
||||
logger.warning(f"密码更改失败 - 旧密码错误: 用户ID {user_id}")
|
||||
return False, "旧密码错误"
|
||||
|
||||
# 验证新密码复杂度
|
||||
valid, error_msg = PasswordValidator.validate(new_password)
|
||||
if not valid:
|
||||
return False, error_msg
|
||||
|
||||
# 检查新密码不能与旧密码相同
|
||||
if old_password == new_password:
|
||||
return False, "新密码不能与旧密码相同"
|
||||
|
||||
# 设置新密码
|
||||
user.set_password(new_password)
|
||||
user.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
# 清除用户缓存
|
||||
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
|
||||
|
||||
logger.info(f"密码更改成功: 用户ID {user_id}")
|
||||
return True, "密码更改成功"
|
||||
|
||||
@staticmethod
|
||||
def update_user_quota(
|
||||
db: Session,
|
||||
user_id: str,
|
||||
quota_usd: Optional[float] = None,
|
||||
) -> Optional[User]:
|
||||
"""更新用户配额"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if quota_usd is not None:
|
||||
user.quota_usd = quota_usd
|
||||
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
# 清除用户缓存
|
||||
asyncio.create_task(UserCacheService.invalidate_user_cache(user.id, user.email))
|
||||
|
||||
logger.debug(f"更新用户配额: {user.email} (USD: {quota_usd})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_usage_stats(
|
||||
db: Session,
|
||||
user_id: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""获取用户使用统计"""
|
||||
|
||||
query = db.query(Usage).filter(Usage.user_id == user_id)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(Usage.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(Usage.created_at <= end_date)
|
||||
|
||||
# 统计数据
|
||||
stats = db.query(
|
||||
func.count(Usage.id).label("total_requests"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
).filter(Usage.user_id == user_id)
|
||||
|
||||
if start_date:
|
||||
stats = stats.filter(Usage.created_at >= start_date)
|
||||
if end_date:
|
||||
stats = stats.filter(Usage.created_at <= end_date)
|
||||
|
||||
result = stats.first()
|
||||
|
||||
# 按模型分组统计
|
||||
model_stats = db.query(
|
||||
Usage.model,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("cost_usd"),
|
||||
).filter(Usage.user_id == user_id)
|
||||
|
||||
if start_date:
|
||||
model_stats = model_stats.filter(Usage.created_at >= start_date)
|
||||
if end_date:
|
||||
model_stats = model_stats.filter(Usage.created_at <= end_date)
|
||||
|
||||
model_stats = model_stats.group_by(Usage.model).all()
|
||||
|
||||
return {
|
||||
"total_requests": result.total_requests or 0,
|
||||
"total_tokens": result.total_tokens or 0,
|
||||
"total_cost_usd": float(result.total_cost_usd or 0),
|
||||
"avg_response_time_ms": float(result.avg_response_time or 0),
|
||||
"by_model": [
|
||||
{
|
||||
"model": stat.model,
|
||||
"requests": stat.requests,
|
||||
"tokens": stat.tokens,
|
||||
"cost_usd": float(stat.cost_usd),
|
||||
}
|
||||
for stat in model_stats
|
||||
],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_user_available_models(db: Session, user: User) -> List[Model]:
|
||||
"""获取用户可用的模型
|
||||
|
||||
新架构:通过 GlobalModel + Model 关联查询用户可用模型
|
||||
逻辑:用户可用提供商 → Provider 的 Model 实现 → 关联的 GlobalModel
|
||||
"""
|
||||
# 获取用户可用的提供商
|
||||
if user.role == UserRole.ADMIN:
|
||||
# 管理员可以使用所有活动提供商
|
||||
provider_ids = [
|
||||
p.id for p in db.query(Provider.id).filter(Provider.is_active == True).all()
|
||||
]
|
||||
else:
|
||||
# 普通用户使用关联的提供商
|
||||
provider_ids = [p.id for p in user.providers]
|
||||
|
||||
if not provider_ids:
|
||||
return []
|
||||
|
||||
# 查询这些提供商的所有活跃 Model(关联 GlobalModel)
|
||||
models = (
|
||||
db.query(Model)
|
||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
and_(
|
||||
Model.provider_id.in_(provider_ids),
|
||||
Model.is_active == True,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
logger.debug(f"用户 {user.email} 可用模型: {len(models)} 个 (提供商数: {len(provider_ids)})")
|
||||
|
||||
return models
|
||||
Reference in New Issue
Block a user