Files
Aether/src/api/admin/system.py
RWDai 6bd8cdb9cf feat: 实现邮箱验证注册功能
添加完整的邮箱验证注册系统,包括验证码发送、验证和限流控制:
  - 新增邮箱验证服务模块(email_sender, email_template, email_verification)
  - 更新认证API支持邮箱验证注册流程
  - 添加注册对话框和验证码输入组件
  - 完善IP限流器支持邮箱验证场景
  - 修复前端类型定义问题,升级esbuild依赖

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2025-12-30 17:15:48 +08:00

1154 lines
50 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""系统设置API端点。"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
from src.database import get_db
from src.models.api import SystemSettingsRequest, SystemSettingsResponse
from src.models.database import ApiKey, Provider, Usage, User
from src.services.system.config import SystemConfigService
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
pipeline = ApiRequestPipeline()
@router.get("/settings")
async def get_system_settings(request: Request, db: Session = Depends(get_db)):
"""获取系统设置(管理员)"""
adapter = AdminGetSystemSettingsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/settings")
async def update_system_settings(http_request: Request, db: Session = Depends(get_db)):
"""更新系统设置(管理员)"""
adapter = AdminUpdateSystemSettingsAdapter()
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
@router.get("/configs")
async def get_all_system_configs(request: Request, db: Session = Depends(get_db)):
"""获取所有系统配置(管理员)"""
adapter = AdminGetAllConfigsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/configs/{key}")
async def get_system_config(key: str, request: Request, db: Session = Depends(get_db)):
"""获取特定系统配置(管理员)"""
adapter = AdminGetSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/configs/{key}")
async def set_system_config(
key: str,
request: Request,
db: Session = Depends(get_db),
):
"""设置系统配置(管理员)"""
adapter = AdminSetSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/configs/{key}")
async def delete_system_config(key: str, request: Request, db: Session = Depends(get_db)):
"""删除系统配置(管理员)"""
adapter = AdminDeleteSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats")
async def get_system_stats(request: Request, db: Session = Depends(get_db)):
adapter = AdminSystemStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/cleanup")
async def trigger_cleanup(request: Request, db: Session = Depends(get_db)):
"""Manually trigger usage record cleanup task"""
adapter = AdminTriggerCleanupAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/api-formats")
async def get_api_formats(request: Request, db: Session = Depends(get_db)):
"""获取所有可用的API格式列表"""
adapter = AdminGetApiFormatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/config/export")
async def export_config(request: Request, db: Session = Depends(get_db)):
"""导出提供商和模型配置(管理员)"""
adapter = AdminExportConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/config/import")
async def import_config(request: Request, db: Session = Depends(get_db)):
"""导入提供商和模型配置(管理员)"""
adapter = AdminImportConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/users/export")
async def export_users(request: Request, db: Session = Depends(get_db)):
"""导出用户数据(管理员)"""
adapter = AdminExportUsersAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/users/import")
async def import_users(request: Request, db: Session = Depends(get_db)):
"""导入用户数据(管理员)"""
adapter = AdminImportUsersAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/smtp/test")
async def test_smtp(request: Request, db: Session = Depends(get_db)):
"""测试 SMTP 连接(管理员)"""
adapter = AdminTestSmtpAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- 系统设置适配器 --------
class AdminGetSystemSettingsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
default_provider = SystemConfigService.get_default_provider(db)
default_model = SystemConfigService.get_config(db, "default_model")
enable_usage_tracking = (
SystemConfigService.get_config(db, "enable_usage_tracking", "true") == "true"
)
return SystemSettingsResponse(
default_provider=default_provider,
default_model=default_model,
enable_usage_tracking=enable_usage_tracking,
)
class AdminUpdateSystemSettingsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
settings_request = SystemSettingsRequest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
if settings_request.default_provider is not None:
provider = (
db.query(Provider)
.filter(
Provider.name == settings_request.default_provider,
Provider.is_active.is_(True),
)
.first()
)
if not provider and settings_request.default_provider != "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"提供商 '{settings_request.default_provider}' 不存在或未启用",
)
if settings_request.default_provider:
SystemConfigService.set_default_provider(db, settings_request.default_provider)
else:
SystemConfigService.delete_config(db, "default_provider")
if settings_request.default_model is not None:
if settings_request.default_model:
SystemConfigService.set_config(db, "default_model", settings_request.default_model)
else:
SystemConfigService.delete_config(db, "default_model")
if settings_request.enable_usage_tracking is not None:
SystemConfigService.set_config(
db,
"enable_usage_tracking",
str(settings_request.enable_usage_tracking).lower(),
)
return {"message": "系统设置更新成功"}
class AdminGetAllConfigsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
return SystemConfigService.get_all_configs(context.db)
@dataclass
class AdminGetSystemConfigAdapter(AdminApiAdapter):
key: str
async def handle(self, context): # type: ignore[override]
value = SystemConfigService.get_config(context.db, self.key)
if value is None:
raise NotFoundException(f"配置项 '{self.key}' 不存在")
return {"key": self.key, "value": value}
@dataclass
class AdminSetSystemConfigAdapter(AdminApiAdapter):
key: str
async def handle(self, context): # type: ignore[override]
payload = context.ensure_json_body()
config = SystemConfigService.set_config(
context.db,
self.key,
payload.get("value"),
payload.get("description"),
)
return {
"key": config.key,
"value": config.value,
"description": config.description,
"updated_at": config.updated_at.isoformat(),
}
@dataclass
class AdminDeleteSystemConfigAdapter(AdminApiAdapter):
key: str
async def handle(self, context): # type: ignore[override]
deleted = SystemConfigService.delete_config(context.db, self.key)
if not deleted:
raise NotFoundException(f"配置项 '{self.key}' 不存在")
return {"message": f"配置项 '{self.key}' 已删除"}
class AdminSystemStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
total_users = db.query(User).count()
active_users = db.query(User).filter(User.is_active.is_(True)).count()
total_providers = db.query(Provider).count()
active_providers = db.query(Provider).filter(Provider.is_active.is_(True)).count()
total_api_keys = db.query(ApiKey).count()
total_requests = db.query(Usage).count()
return {
"users": {"total": total_users, "active": active_users},
"providers": {"total": total_providers, "active": active_providers},
"api_keys": total_api_keys,
"requests": total_requests,
}
class AdminTriggerCleanupAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""手动触发清理任务"""
from datetime import datetime, timedelta, timezone
from sqlalchemy import func
from src.services.system.cleanup_scheduler import get_cleanup_scheduler
db = context.db
# 获取清理前的统计信息
total_before = db.query(Usage).count()
with_body_before = (
db.query(Usage)
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
.count()
)
with_headers_before = (
db.query(Usage)
.filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
.count()
)
# 触发清理
cleanup_scheduler = get_cleanup_scheduler()
await cleanup_scheduler._perform_cleanup()
# 获取清理后的统计信息
total_after = db.query(Usage).count()
with_body_after = (
db.query(Usage)
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
.count()
)
with_headers_after = (
db.query(Usage)
.filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
.count()
)
return {
"message": "清理任务执行完成",
"stats": {
"total_records": {
"before": total_before,
"after": total_after,
"deleted": total_before - total_after,
},
"body_fields": {
"before": with_body_before,
"after": with_body_after,
"cleaned": with_body_before - with_body_after,
},
"header_fields": {
"before": with_headers_before,
"after": with_headers_after,
"cleaned": with_headers_before - with_headers_after,
},
},
"timestamp": datetime.now(timezone.utc).isoformat(),
}
class AdminGetApiFormatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""获取所有可用的API格式"""
from src.core.api_format_metadata import API_FORMAT_DEFINITIONS
from src.core.enums import APIFormat
_ = context # 参数保留以符合接口规范
formats = []
for api_format in APIFormat:
definition = API_FORMAT_DEFINITIONS.get(api_format)
formats.append(
{
"value": api_format.value,
"label": api_format.value,
"default_path": definition.default_path if definition else "/",
"aliases": list(definition.aliases) if definition else [],
}
)
return {"formats": formats}
class AdminExportConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导出提供商和模型配置(解密数据)"""
from datetime import datetime, timezone
from src.core.crypto import crypto_service
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
db = context.db
# 导出 GlobalModels
global_models = db.query(GlobalModel).all()
global_models_data = []
for gm in global_models:
global_models_data.append(
{
"name": gm.name,
"display_name": gm.display_name,
"default_price_per_request": gm.default_price_per_request,
"default_tiered_pricing": gm.default_tiered_pricing,
"supported_capabilities": gm.supported_capabilities,
"config": gm.config,
"is_active": gm.is_active,
}
)
# 导出 Providers 及其关联数据
providers = db.query(Provider).all()
providers_data = []
for provider in providers:
# 导出 Endpoints
endpoints = (
db.query(ProviderEndpoint)
.filter(ProviderEndpoint.provider_id == provider.id)
.all()
)
endpoints_data = []
for ep in endpoints:
# 导出 Endpoint Keys
keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
)
keys_data = []
for key in keys:
# 解密 API Key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception:
decrypted_key = ""
keys_data.append(
{
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"rate_multiplier": key.rate_multiplier,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"max_concurrent": key.max_concurrent,
"rate_limit": key.rate_limit,
"daily_limit": key.daily_limit,
"monthly_limit": key.monthly_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"is_active": key.is_active,
}
)
endpoints_data.append(
{
"api_format": ep.api_format,
"base_url": ep.base_url,
"headers": ep.headers,
"timeout": ep.timeout,
"max_retries": ep.max_retries,
"max_concurrent": ep.max_concurrent,
"rate_limit": ep.rate_limit,
"is_active": ep.is_active,
"custom_path": ep.custom_path,
"config": ep.config,
"keys": keys_data,
}
)
# 导出 Provider Models
models = db.query(Model).filter(Model.provider_id == provider.id).all()
models_data = []
for model in models:
# 获取关联的 GlobalModel 名称
global_model = (
db.query(GlobalModel).filter(GlobalModel.id == model.global_model_id).first()
)
models_data.append(
{
"global_model_name": global_model.name if global_model else None,
"provider_model_name": model.provider_model_name,
"provider_model_mappings": model.provider_model_mappings,
"price_per_request": model.price_per_request,
"tiered_pricing": model.tiered_pricing,
"supports_vision": model.supports_vision,
"supports_function_calling": model.supports_function_calling,
"supports_streaming": model.supports_streaming,
"supports_extended_thinking": model.supports_extended_thinking,
"supports_image_generation": model.supports_image_generation,
"is_active": model.is_active,
"config": model.config,
}
)
providers_data.append(
{
"name": provider.name,
"display_name": provider.display_name,
"description": provider.description,
"website": provider.website,
"billing_type": provider.billing_type.value if provider.billing_type else None,
"monthly_quota_usd": provider.monthly_quota_usd,
"quota_reset_day": provider.quota_reset_day,
"rpm_limit": provider.rpm_limit,
"provider_priority": provider.provider_priority,
"is_active": provider.is_active,
"rate_limit": provider.rate_limit,
"concurrent_limit": provider.concurrent_limit,
"config": provider.config,
"endpoints": endpoints_data,
"models": models_data,
}
)
return {
"version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"global_models": global_models_data,
"providers": providers_data,
}
MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB
class AdminImportConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导入提供商和模型配置"""
import uuid
from datetime import datetime, timezone
from src.core.crypto import crypto_service
from src.core.enums import ProviderBillingType
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
# 检查请求体大小
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
raise InvalidRequestException("请求体大小不能超过 10MB")
db = context.db
payload = context.ensure_json_body()
# 验证配置版本
version = payload.get("version")
if version != "1.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
# 获取导入选项
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
global_models_data = payload.get("global_models", [])
providers_data = payload.get("providers", [])
stats = {
"global_models": {"created": 0, "updated": 0, "skipped": 0},
"providers": {"created": 0, "updated": 0, "skipped": 0},
"endpoints": {"created": 0, "updated": 0, "skipped": 0},
"keys": {"created": 0, "updated": 0, "skipped": 0},
"models": {"created": 0, "updated": 0, "skipped": 0},
"errors": [],
}
try:
# 导入 GlobalModels
global_model_map = {} # name -> id 映射
for gm_data in global_models_data:
existing = (
db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first()
)
if existing:
global_model_map[gm_data["name"]] = existing.id
if merge_mode == "skip":
stats["global_models"]["skipped"] += 1
continue
elif merge_mode == "error":
raise InvalidRequestException(
f"GlobalModel '{gm_data['name']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有记录
existing.display_name = gm_data.get(
"display_name", existing.display_name
)
existing.default_price_per_request = gm_data.get(
"default_price_per_request"
)
existing.default_tiered_pricing = gm_data.get(
"default_tiered_pricing", existing.default_tiered_pricing
)
existing.supported_capabilities = gm_data.get(
"supported_capabilities"
)
existing.config = gm_data.get("config")
existing.is_active = gm_data.get("is_active", True)
existing.updated_at = datetime.now(timezone.utc)
stats["global_models"]["updated"] += 1
else:
# 创建新记录
new_gm = GlobalModel(
id=str(uuid.uuid4()),
name=gm_data["name"],
display_name=gm_data.get("display_name", gm_data["name"]),
default_price_per_request=gm_data.get("default_price_per_request"),
default_tiered_pricing=gm_data.get(
"default_tiered_pricing",
{"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]},
),
supported_capabilities=gm_data.get("supported_capabilities"),
config=gm_data.get("config"),
is_active=gm_data.get("is_active", True),
)
db.add(new_gm)
db.flush()
global_model_map[gm_data["name"]] = new_gm.id
stats["global_models"]["created"] += 1
# 导入 Providers
for prov_data in providers_data:
existing_provider = (
db.query(Provider).filter(Provider.name == prov_data["name"]).first()
)
if existing_provider:
provider_id = existing_provider.id
if merge_mode == "skip":
stats["providers"]["skipped"] += 1
# 仍然需要处理 endpoints 和 models如果存在
elif merge_mode == "error":
raise InvalidRequestException(
f"Provider '{prov_data['name']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有记录
existing_provider.display_name = prov_data.get(
"display_name", existing_provider.display_name
)
existing_provider.description = prov_data.get("description")
existing_provider.website = prov_data.get("website")
if prov_data.get("billing_type"):
existing_provider.billing_type = ProviderBillingType(
prov_data["billing_type"]
)
existing_provider.monthly_quota_usd = prov_data.get(
"monthly_quota_usd"
)
existing_provider.quota_reset_day = prov_data.get(
"quota_reset_day", 30
)
existing_provider.rpm_limit = prov_data.get("rpm_limit")
existing_provider.provider_priority = prov_data.get(
"provider_priority", 100
)
existing_provider.is_active = prov_data.get("is_active", True)
existing_provider.rate_limit = prov_data.get("rate_limit")
existing_provider.concurrent_limit = prov_data.get(
"concurrent_limit"
)
existing_provider.config = prov_data.get("config")
existing_provider.updated_at = datetime.now(timezone.utc)
stats["providers"]["updated"] += 1
else:
# 创建新 Provider
billing_type = ProviderBillingType.PAY_AS_YOU_GO
if prov_data.get("billing_type"):
billing_type = ProviderBillingType(prov_data["billing_type"])
new_provider = Provider(
id=str(uuid.uuid4()),
name=prov_data["name"],
display_name=prov_data.get("display_name", prov_data["name"]),
description=prov_data.get("description"),
website=prov_data.get("website"),
billing_type=billing_type,
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
quota_reset_day=prov_data.get("quota_reset_day", 30),
rpm_limit=prov_data.get("rpm_limit"),
provider_priority=prov_data.get("provider_priority", 100),
is_active=prov_data.get("is_active", True),
rate_limit=prov_data.get("rate_limit"),
concurrent_limit=prov_data.get("concurrent_limit"),
config=prov_data.get("config"),
)
db.add(new_provider)
db.flush()
provider_id = new_provider.id
stats["providers"]["created"] += 1
# 导入 Endpoints
for ep_data in prov_data.get("endpoints", []):
existing_ep = (
db.query(ProviderEndpoint)
.filter(
ProviderEndpoint.provider_id == provider_id,
ProviderEndpoint.api_format == ep_data["api_format"],
)
.first()
)
if existing_ep:
endpoint_id = existing_ep.id
if merge_mode == "skip":
stats["endpoints"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"Endpoint '{ep_data['api_format']}' 已存在于 Provider '{prov_data['name']}'"
)
elif merge_mode == "overwrite":
existing_ep.base_url = ep_data.get(
"base_url", existing_ep.base_url
)
existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 3)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True)
existing_ep.custom_path = ep_data.get("custom_path")
existing_ep.config = ep_data.get("config")
existing_ep.updated_at = datetime.now(timezone.utc)
stats["endpoints"]["updated"] += 1
else:
new_ep = ProviderEndpoint(
id=str(uuid.uuid4()),
provider_id=provider_id,
api_format=ep_data["api_format"],
base_url=ep_data["base_url"],
headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 3),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True),
custom_path=ep_data.get("custom_path"),
config=ep_data.get("config"),
)
db.add(new_ep)
db.flush()
endpoint_id = new_ep.id
stats["endpoints"]["created"] += 1
# 导入 Keys
# 获取当前 endpoint 下所有已有的 keys用于去重
existing_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
.all()
)
# 解密已有 keys 用于比对
existing_key_values = set()
for ek in existing_keys:
try:
decrypted = crypto_service.decrypt(ek.api_key)
existing_key_values.add(decrypted)
except Exception:
pass
for key_data in ep_data.get("keys", []):
if not key_data.get("api_key"):
stats["errors"].append(
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
)
continue
# 检查是否已存在相同的 Key通过明文比对
if key_data["api_key"] in existing_key_values:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(key_data["api_key"])
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=endpoint_id,
api_key=encrypted_key,
name=key_data.get("name"),
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
internal_priority=key_data.get("internal_priority", 100),
global_priority=key_data.get("global_priority"),
max_concurrent=key_data.get("max_concurrent"),
rate_limit=key_data.get("rate_limit"),
daily_limit=key_data.get("daily_limit"),
monthly_limit=key_data.get("monthly_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
is_active=key_data.get("is_active", True),
)
db.add(new_key)
# 添加到已有集合,防止同一批导入中重复
existing_key_values.add(key_data["api_key"])
stats["keys"]["created"] += 1
# 导入 Models
for model_data in prov_data.get("models", []):
global_model_name = model_data.get("global_model_name")
if not global_model_name:
stats["errors"].append(
f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})"
)
continue
global_model_id = global_model_map.get(global_model_name)
if not global_model_id:
# 尝试从数据库查找
existing_gm = (
db.query(GlobalModel)
.filter(GlobalModel.name == global_model_name)
.first()
)
if existing_gm:
global_model_id = existing_gm.id
else:
stats["errors"].append(
f"GlobalModel '{global_model_name}' 不存在,跳过模型"
)
continue
existing_model = (
db.query(Model)
.filter(
Model.provider_id == provider_id,
Model.provider_model_name == model_data["provider_model_name"],
)
.first()
)
if existing_model:
if merge_mode == "skip":
stats["models"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"Model '{model_data['provider_model_name']}' 已存在于 Provider '{prov_data['name']}'"
)
elif merge_mode == "overwrite":
existing_model.global_model_id = global_model_id
existing_model.provider_model_mappings = model_data.get(
"provider_model_mappings"
)
existing_model.price_per_request = model_data.get(
"price_per_request"
)
existing_model.tiered_pricing = model_data.get(
"tiered_pricing"
)
existing_model.supports_vision = model_data.get(
"supports_vision"
)
existing_model.supports_function_calling = model_data.get(
"supports_function_calling"
)
existing_model.supports_streaming = model_data.get(
"supports_streaming"
)
existing_model.supports_extended_thinking = model_data.get(
"supports_extended_thinking"
)
existing_model.supports_image_generation = model_data.get(
"supports_image_generation"
)
existing_model.is_active = model_data.get("is_active", True)
existing_model.config = model_data.get("config")
existing_model.updated_at = datetime.now(timezone.utc)
stats["models"]["updated"] += 1
else:
new_model = Model(
id=str(uuid.uuid4()),
provider_id=provider_id,
global_model_id=global_model_id,
provider_model_name=model_data["provider_model_name"],
provider_model_mappings=model_data.get(
"provider_model_mappings"
),
price_per_request=model_data.get("price_per_request"),
tiered_pricing=model_data.get("tiered_pricing"),
supports_vision=model_data.get("supports_vision"),
supports_function_calling=model_data.get(
"supports_function_calling"
),
supports_streaming=model_data.get("supports_streaming"),
supports_extended_thinking=model_data.get(
"supports_extended_thinking"
),
supports_image_generation=model_data.get(
"supports_image_generation"
),
is_active=model_data.get("is_active", True),
config=model_data.get("config"),
)
db.add(new_model)
stats["models"]["created"] += 1
db.commit()
# 失效缓存
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.clear_all_caches()
return {
"message": "配置导入成功",
"stats": stats,
}
except InvalidRequestException:
db.rollback()
raise
except Exception as e:
db.rollback()
raise InvalidRequestException(f"导入失败: {str(e)}")
class AdminExportUsersAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导出用户数据(保留加密数据,排除管理员)"""
from datetime import datetime, timezone
from src.core.enums import UserRole
from src.models.database import ApiKey, User
db = context.db
# 导出 Users排除管理员
users = db.query(User).filter(
User.is_deleted.is_(False),
User.role != UserRole.ADMIN
).all()
users_data = []
for user in users:
# 导出用户的 API Keys保留加密数据
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
api_keys_data = []
for key in api_keys:
api_keys_data.append(
{
"key_hash": key.key_hash,
"key_encrypted": key.key_encrypted,
"name": key.name,
"is_standalone": key.is_standalone,
"balance_used_usd": key.balance_used_usd,
"current_balance_usd": key.current_balance_usd,
"allowed_providers": key.allowed_providers,
"allowed_endpoints": key.allowed_endpoints,
"allowed_api_formats": key.allowed_api_formats,
"allowed_models": key.allowed_models,
"rate_limit": key.rate_limit,
"concurrent_limit": key.concurrent_limit,
"force_capabilities": key.force_capabilities,
"is_active": key.is_active,
"auto_delete_on_expiry": key.auto_delete_on_expiry,
"total_requests": key.total_requests,
"total_cost_usd": key.total_cost_usd,
}
)
users_data.append(
{
"email": user.email,
"username": user.username,
"password_hash": user.password_hash,
"role": user.role.value if user.role else "user",
"allowed_providers": user.allowed_providers,
"allowed_endpoints": user.allowed_endpoints,
"allowed_models": user.allowed_models,
"model_capability_settings": user.model_capability_settings,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": user.total_usd,
"is_active": user.is_active,
"api_keys": api_keys_data,
}
)
return {
"version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"users": users_data,
}
class AdminImportUsersAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导入用户数据"""
import uuid
from datetime import datetime, timezone
from src.core.enums import UserRole
from src.models.database import ApiKey, User
# 检查请求体大小
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
raise InvalidRequestException("请求体大小不能超过 10MB")
db = context.db
payload = context.ensure_json_body()
# 验证配置版本
version = payload.get("version")
if version != "1.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
# 获取导入选项
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
users_data = payload.get("users", [])
stats = {
"users": {"created": 0, "updated": 0, "skipped": 0},
"api_keys": {"created": 0, "skipped": 0},
"errors": [],
}
try:
for user_data in users_data:
# 跳过管理员角色的导入(不区分大小写)
role_str = str(user_data.get("role", "")).lower()
if role_str == "admin":
stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}")
stats["users"]["skipped"] += 1
continue
existing_user = (
db.query(User).filter(User.email == user_data["email"]).first()
)
if existing_user:
user_id = existing_user.id
if merge_mode == "skip":
stats["users"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"用户 '{user_data['email']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有用户
existing_user.username = user_data.get(
"username", existing_user.username
)
if user_data.get("password_hash"):
existing_user.password_hash = user_data["password_hash"]
if user_data.get("role"):
existing_user.role = UserRole(user_data["role"])
existing_user.allowed_providers = user_data.get("allowed_providers")
existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
existing_user.allowed_models = user_data.get("allowed_models")
existing_user.model_capability_settings = user_data.get(
"model_capability_settings"
)
existing_user.quota_usd = user_data.get("quota_usd")
existing_user.used_usd = user_data.get("used_usd", 0.0)
existing_user.total_usd = user_data.get("total_usd", 0.0)
existing_user.is_active = user_data.get("is_active", True)
existing_user.updated_at = datetime.now(timezone.utc)
stats["users"]["updated"] += 1
else:
# 创建新用户
role = UserRole.USER
if user_data.get("role"):
role = UserRole(user_data["role"])
new_user = User(
id=str(uuid.uuid4()),
email=user_data["email"],
username=user_data.get("username", user_data["email"].split("@")[0]),
password_hash=user_data.get("password_hash", ""),
role=role,
allowed_providers=user_data.get("allowed_providers"),
allowed_endpoints=user_data.get("allowed_endpoints"),
allowed_models=user_data.get("allowed_models"),
model_capability_settings=user_data.get("model_capability_settings"),
quota_usd=user_data.get("quota_usd"),
used_usd=user_data.get("used_usd", 0.0),
total_usd=user_data.get("total_usd", 0.0),
is_active=user_data.get("is_active", True),
)
db.add(new_user)
db.flush()
user_id = new_user.id
stats["users"]["created"] += 1
# 导入 API Keys
for key_data in user_data.get("api_keys", []):
# 检查是否已存在相同的 key_hash
if key_data.get("key_hash"):
existing_key = (
db.query(ApiKey)
.filter(ApiKey.key_hash == key_data["key_hash"])
.first()
)
if existing_key:
stats["api_keys"]["skipped"] += 1
continue
new_key = ApiKey(
id=str(uuid.uuid4()),
user_id=user_id,
key_hash=key_data.get("key_hash", ""),
key_encrypted=key_data.get("key_encrypted"),
name=key_data.get("name"),
is_standalone=key_data.get("is_standalone", False),
balance_used_usd=key_data.get("balance_used_usd", 0.0),
current_balance_usd=key_data.get("current_balance_usd"),
allowed_providers=key_data.get("allowed_providers"),
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit", 100),
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
total_requests=key_data.get("total_requests", 0),
total_cost_usd=key_data.get("total_cost_usd", 0.0),
)
db.add(new_key)
stats["api_keys"]["created"] += 1
db.commit()
return {
"message": "用户数据导入成功",
"stats": stats,
}
except InvalidRequestException:
db.rollback()
raise
except Exception as e:
db.rollback()
raise InvalidRequestException(f"导入失败: {str(e)}")
class AdminTestSmtpAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""测试 SMTP 连接"""
from src.services.system.config import ConfigService
from src.services.verification.email_sender import EmailSenderService
db = context.db
payload = context.ensure_json_body() or {}
# 前端可传入未保存的配置,优先使用前端值,否则回退数据库
config = {
"smtp_host": payload.get("smtp_host") or ConfigService.get_config(db, "smtp_host"),
"smtp_port": payload.get("smtp_port") or ConfigService.get_config(db, "smtp_port", default=587),
"smtp_user": payload.get("smtp_user") or ConfigService.get_config(db, "smtp_user"),
"smtp_password": payload.get("smtp_password") or ConfigService.get_config(db, "smtp_password"),
"smtp_use_tls": payload.get("smtp_use_tls")
if payload.get("smtp_use_tls") is not None
else ConfigService.get_config(db, "smtp_use_tls", default=True),
"smtp_use_ssl": payload.get("smtp_use_ssl")
if payload.get("smtp_use_ssl") is not None
else ConfigService.get_config(db, "smtp_use_ssl", default=False),
"smtp_from_email": payload.get("smtp_from_email")
or ConfigService.get_config(db, "smtp_from_email"),
"smtp_from_name": payload.get("smtp_from_name")
or ConfigService.get_config(db, "smtp_from_name", default="Aether"),
}
# 验证必要配置
missing_fields = [
field for field in ["smtp_host", "smtp_user", "smtp_password", "smtp_from_email"] if not config.get(field)
]
if missing_fields:
return {
"success": False,
"message": f"SMTP 配置不完整,请检查 {', '.join(missing_fields)}"
}
# 测试连接
try:
success, error_msg = await EmailSenderService.test_smtp_connection(
db=db, override_config=config
)
if success:
return {
"success": True,
"message": "SMTP 连接测试成功"
}
else:
return {
"success": False,
"message": f"SMTP 连接测试失败: {error_msg}"
}
except Exception as e:
return {
"success": False,
"message": f"SMTP 连接测试失败: {str(e)}"
}