mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
添加完整的邮箱验证注册系统,包括验证码发送、验证和限流控制: - 新增邮箱验证服务模块(email_sender, email_template, email_verification) - 更新认证API支持邮箱验证注册流程 - 添加注册对话框和验证码输入组件 - 完善IP限流器支持邮箱验证场景 - 修复前端类型定义问题,升级esbuild依赖 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1154 lines
50 KiB
Python
1154 lines
50 KiB
Python
"""系统设置API端点。"""
|
||
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||
from pydantic import ValidationError
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.api.base.admin_adapter import AdminApiAdapter
|
||
from src.api.base.pipeline import ApiRequestPipeline
|
||
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
|
||
from src.database import get_db
|
||
from src.models.api import SystemSettingsRequest, SystemSettingsResponse
|
||
from src.models.database import ApiKey, Provider, Usage, User
|
||
from src.services.system.config import SystemConfigService
|
||
|
||
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
|
||
pipeline = ApiRequestPipeline()
|
||
|
||
|
||
@router.get("/settings")
|
||
async def get_system_settings(request: Request, db: Session = Depends(get_db)):
|
||
"""获取系统设置(管理员)"""
|
||
|
||
adapter = AdminGetSystemSettingsAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.put("/settings")
|
||
async def update_system_settings(http_request: Request, db: Session = Depends(get_db)):
|
||
"""更新系统设置(管理员)"""
|
||
|
||
adapter = AdminUpdateSystemSettingsAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.get("/configs")
|
||
async def get_all_system_configs(request: Request, db: Session = Depends(get_db)):
|
||
"""获取所有系统配置(管理员)"""
|
||
|
||
adapter = AdminGetAllConfigsAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.get("/configs/{key}")
|
||
async def get_system_config(key: str, request: Request, db: Session = Depends(get_db)):
|
||
"""获取特定系统配置(管理员)"""
|
||
|
||
adapter = AdminGetSystemConfigAdapter(key=key)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.put("/configs/{key}")
|
||
async def set_system_config(
|
||
key: str,
|
||
request: Request,
|
||
db: Session = Depends(get_db),
|
||
):
|
||
"""设置系统配置(管理员)"""
|
||
|
||
adapter = AdminSetSystemConfigAdapter(key=key)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.delete("/configs/{key}")
|
||
async def delete_system_config(key: str, request: Request, db: Session = Depends(get_db)):
|
||
"""删除系统配置(管理员)"""
|
||
|
||
adapter = AdminDeleteSystemConfigAdapter(key=key)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.get("/stats")
|
||
async def get_system_stats(request: Request, db: Session = Depends(get_db)):
|
||
adapter = AdminSystemStatsAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.post("/cleanup")
|
||
async def trigger_cleanup(request: Request, db: Session = Depends(get_db)):
|
||
"""Manually trigger usage record cleanup task"""
|
||
adapter = AdminTriggerCleanupAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.get("/api-formats")
|
||
async def get_api_formats(request: Request, db: Session = Depends(get_db)):
|
||
"""获取所有可用的API格式列表"""
|
||
adapter = AdminGetApiFormatsAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.get("/config/export")
|
||
async def export_config(request: Request, db: Session = Depends(get_db)):
|
||
"""导出提供商和模型配置(管理员)"""
|
||
adapter = AdminExportConfigAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.post("/config/import")
|
||
async def import_config(request: Request, db: Session = Depends(get_db)):
|
||
"""导入提供商和模型配置(管理员)"""
|
||
adapter = AdminImportConfigAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.get("/users/export")
|
||
async def export_users(request: Request, db: Session = Depends(get_db)):
|
||
"""导出用户数据(管理员)"""
|
||
adapter = AdminExportUsersAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.post("/users/import")
|
||
async def import_users(request: Request, db: Session = Depends(get_db)):
|
||
"""导入用户数据(管理员)"""
|
||
adapter = AdminImportUsersAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.post("/smtp/test")
|
||
async def test_smtp(request: Request, db: Session = Depends(get_db)):
|
||
"""测试 SMTP 连接(管理员)"""
|
||
adapter = AdminTestSmtpAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
# -------- 系统设置适配器 --------
|
||
|
||
|
||
class AdminGetSystemSettingsAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
default_provider = SystemConfigService.get_default_provider(db)
|
||
default_model = SystemConfigService.get_config(db, "default_model")
|
||
enable_usage_tracking = (
|
||
SystemConfigService.get_config(db, "enable_usage_tracking", "true") == "true"
|
||
)
|
||
|
||
return SystemSettingsResponse(
|
||
default_provider=default_provider,
|
||
default_model=default_model,
|
||
enable_usage_tracking=enable_usage_tracking,
|
||
)
|
||
|
||
|
||
class AdminUpdateSystemSettingsAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
payload = context.ensure_json_body()
|
||
try:
|
||
settings_request = SystemSettingsRequest.model_validate(payload)
|
||
except ValidationError as e:
|
||
errors = e.errors()
|
||
if errors:
|
||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||
raise InvalidRequestException("请求数据验证失败")
|
||
|
||
if settings_request.default_provider is not None:
|
||
provider = (
|
||
db.query(Provider)
|
||
.filter(
|
||
Provider.name == settings_request.default_provider,
|
||
Provider.is_active.is_(True),
|
||
)
|
||
.first()
|
||
)
|
||
|
||
if not provider and settings_request.default_provider != "":
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=f"提供商 '{settings_request.default_provider}' 不存在或未启用",
|
||
)
|
||
|
||
if settings_request.default_provider:
|
||
SystemConfigService.set_default_provider(db, settings_request.default_provider)
|
||
else:
|
||
SystemConfigService.delete_config(db, "default_provider")
|
||
|
||
if settings_request.default_model is not None:
|
||
if settings_request.default_model:
|
||
SystemConfigService.set_config(db, "default_model", settings_request.default_model)
|
||
else:
|
||
SystemConfigService.delete_config(db, "default_model")
|
||
|
||
if settings_request.enable_usage_tracking is not None:
|
||
SystemConfigService.set_config(
|
||
db,
|
||
"enable_usage_tracking",
|
||
str(settings_request.enable_usage_tracking).lower(),
|
||
)
|
||
|
||
return {"message": "系统设置更新成功"}
|
||
|
||
|
||
class AdminGetAllConfigsAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
return SystemConfigService.get_all_configs(context.db)
|
||
|
||
|
||
@dataclass
|
||
class AdminGetSystemConfigAdapter(AdminApiAdapter):
|
||
key: str
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
value = SystemConfigService.get_config(context.db, self.key)
|
||
if value is None:
|
||
raise NotFoundException(f"配置项 '{self.key}' 不存在")
|
||
return {"key": self.key, "value": value}
|
||
|
||
|
||
@dataclass
|
||
class AdminSetSystemConfigAdapter(AdminApiAdapter):
|
||
key: str
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
payload = context.ensure_json_body()
|
||
config = SystemConfigService.set_config(
|
||
context.db,
|
||
self.key,
|
||
payload.get("value"),
|
||
payload.get("description"),
|
||
)
|
||
|
||
return {
|
||
"key": config.key,
|
||
"value": config.value,
|
||
"description": config.description,
|
||
"updated_at": config.updated_at.isoformat(),
|
||
}
|
||
|
||
|
||
@dataclass
|
||
class AdminDeleteSystemConfigAdapter(AdminApiAdapter):
|
||
key: str
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
deleted = SystemConfigService.delete_config(context.db, self.key)
|
||
if not deleted:
|
||
raise NotFoundException(f"配置项 '{self.key}' 不存在")
|
||
return {"message": f"配置项 '{self.key}' 已删除"}
|
||
|
||
|
||
class AdminSystemStatsAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
total_users = db.query(User).count()
|
||
active_users = db.query(User).filter(User.is_active.is_(True)).count()
|
||
total_providers = db.query(Provider).count()
|
||
active_providers = db.query(Provider).filter(Provider.is_active.is_(True)).count()
|
||
total_api_keys = db.query(ApiKey).count()
|
||
total_requests = db.query(Usage).count()
|
||
|
||
return {
|
||
"users": {"total": total_users, "active": active_users},
|
||
"providers": {"total": total_providers, "active": active_providers},
|
||
"api_keys": total_api_keys,
|
||
"requests": total_requests,
|
||
}
|
||
|
||
|
||
class AdminTriggerCleanupAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
"""手动触发清理任务"""
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from sqlalchemy import func
|
||
|
||
from src.services.system.cleanup_scheduler import get_cleanup_scheduler
|
||
|
||
db = context.db
|
||
|
||
# 获取清理前的统计信息
|
||
total_before = db.query(Usage).count()
|
||
with_body_before = (
|
||
db.query(Usage)
|
||
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
|
||
.count()
|
||
)
|
||
with_headers_before = (
|
||
db.query(Usage)
|
||
.filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
|
||
.count()
|
||
)
|
||
|
||
# 触发清理
|
||
cleanup_scheduler = get_cleanup_scheduler()
|
||
await cleanup_scheduler._perform_cleanup()
|
||
|
||
# 获取清理后的统计信息
|
||
total_after = db.query(Usage).count()
|
||
with_body_after = (
|
||
db.query(Usage)
|
||
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
|
||
.count()
|
||
)
|
||
with_headers_after = (
|
||
db.query(Usage)
|
||
.filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
|
||
.count()
|
||
)
|
||
|
||
return {
|
||
"message": "清理任务执行完成",
|
||
"stats": {
|
||
"total_records": {
|
||
"before": total_before,
|
||
"after": total_after,
|
||
"deleted": total_before - total_after,
|
||
},
|
||
"body_fields": {
|
||
"before": with_body_before,
|
||
"after": with_body_after,
|
||
"cleaned": with_body_before - with_body_after,
|
||
},
|
||
"header_fields": {
|
||
"before": with_headers_before,
|
||
"after": with_headers_after,
|
||
"cleaned": with_headers_before - with_headers_after,
|
||
},
|
||
},
|
||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||
}
|
||
|
||
|
||
class AdminGetApiFormatsAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
"""获取所有可用的API格式"""
|
||
from src.core.api_format_metadata import API_FORMAT_DEFINITIONS
|
||
from src.core.enums import APIFormat
|
||
|
||
_ = context # 参数保留以符合接口规范
|
||
|
||
formats = []
|
||
for api_format in APIFormat:
|
||
definition = API_FORMAT_DEFINITIONS.get(api_format)
|
||
formats.append(
|
||
{
|
||
"value": api_format.value,
|
||
"label": api_format.value,
|
||
"default_path": definition.default_path if definition else "/",
|
||
"aliases": list(definition.aliases) if definition else [],
|
||
}
|
||
)
|
||
|
||
return {"formats": formats}
|
||
|
||
|
||
class AdminExportConfigAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
"""导出提供商和模型配置(解密数据)"""
|
||
from datetime import datetime, timezone
|
||
|
||
from src.core.crypto import crypto_service
|
||
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
|
||
|
||
db = context.db
|
||
|
||
# 导出 GlobalModels
|
||
global_models = db.query(GlobalModel).all()
|
||
global_models_data = []
|
||
for gm in global_models:
|
||
global_models_data.append(
|
||
{
|
||
"name": gm.name,
|
||
"display_name": gm.display_name,
|
||
"default_price_per_request": gm.default_price_per_request,
|
||
"default_tiered_pricing": gm.default_tiered_pricing,
|
||
"supported_capabilities": gm.supported_capabilities,
|
||
"config": gm.config,
|
||
"is_active": gm.is_active,
|
||
}
|
||
)
|
||
|
||
# 导出 Providers 及其关联数据
|
||
providers = db.query(Provider).all()
|
||
providers_data = []
|
||
for provider in providers:
|
||
# 导出 Endpoints
|
||
endpoints = (
|
||
db.query(ProviderEndpoint)
|
||
.filter(ProviderEndpoint.provider_id == provider.id)
|
||
.all()
|
||
)
|
||
endpoints_data = []
|
||
for ep in endpoints:
|
||
# 导出 Endpoint Keys
|
||
keys = (
|
||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
|
||
)
|
||
keys_data = []
|
||
for key in keys:
|
||
# 解密 API Key
|
||
try:
|
||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||
except Exception:
|
||
decrypted_key = ""
|
||
|
||
keys_data.append(
|
||
{
|
||
"api_key": decrypted_key,
|
||
"name": key.name,
|
||
"note": key.note,
|
||
"rate_multiplier": key.rate_multiplier,
|
||
"internal_priority": key.internal_priority,
|
||
"global_priority": key.global_priority,
|
||
"max_concurrent": key.max_concurrent,
|
||
"rate_limit": key.rate_limit,
|
||
"daily_limit": key.daily_limit,
|
||
"monthly_limit": key.monthly_limit,
|
||
"allowed_models": key.allowed_models,
|
||
"capabilities": key.capabilities,
|
||
"is_active": key.is_active,
|
||
}
|
||
)
|
||
|
||
endpoints_data.append(
|
||
{
|
||
"api_format": ep.api_format,
|
||
"base_url": ep.base_url,
|
||
"headers": ep.headers,
|
||
"timeout": ep.timeout,
|
||
"max_retries": ep.max_retries,
|
||
"max_concurrent": ep.max_concurrent,
|
||
"rate_limit": ep.rate_limit,
|
||
"is_active": ep.is_active,
|
||
"custom_path": ep.custom_path,
|
||
"config": ep.config,
|
||
"keys": keys_data,
|
||
}
|
||
)
|
||
|
||
# 导出 Provider Models
|
||
models = db.query(Model).filter(Model.provider_id == provider.id).all()
|
||
models_data = []
|
||
for model in models:
|
||
# 获取关联的 GlobalModel 名称
|
||
global_model = (
|
||
db.query(GlobalModel).filter(GlobalModel.id == model.global_model_id).first()
|
||
)
|
||
models_data.append(
|
||
{
|
||
"global_model_name": global_model.name if global_model else None,
|
||
"provider_model_name": model.provider_model_name,
|
||
"provider_model_mappings": model.provider_model_mappings,
|
||
"price_per_request": model.price_per_request,
|
||
"tiered_pricing": model.tiered_pricing,
|
||
"supports_vision": model.supports_vision,
|
||
"supports_function_calling": model.supports_function_calling,
|
||
"supports_streaming": model.supports_streaming,
|
||
"supports_extended_thinking": model.supports_extended_thinking,
|
||
"supports_image_generation": model.supports_image_generation,
|
||
"is_active": model.is_active,
|
||
"config": model.config,
|
||
}
|
||
)
|
||
|
||
providers_data.append(
|
||
{
|
||
"name": provider.name,
|
||
"display_name": provider.display_name,
|
||
"description": provider.description,
|
||
"website": provider.website,
|
||
"billing_type": provider.billing_type.value if provider.billing_type else None,
|
||
"monthly_quota_usd": provider.monthly_quota_usd,
|
||
"quota_reset_day": provider.quota_reset_day,
|
||
"rpm_limit": provider.rpm_limit,
|
||
"provider_priority": provider.provider_priority,
|
||
"is_active": provider.is_active,
|
||
"rate_limit": provider.rate_limit,
|
||
"concurrent_limit": provider.concurrent_limit,
|
||
"config": provider.config,
|
||
"endpoints": endpoints_data,
|
||
"models": models_data,
|
||
}
|
||
)
|
||
|
||
return {
|
||
"version": "1.0",
|
||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||
"global_models": global_models_data,
|
||
"providers": providers_data,
|
||
}
|
||
|
||
|
||
MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB
|
||
|
||
|
||
class AdminImportConfigAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
"""导入提供商和模型配置"""
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
|
||
from src.core.crypto import crypto_service
|
||
from src.core.enums import ProviderBillingType
|
||
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
|
||
|
||
# 检查请求体大小
|
||
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
|
||
raise InvalidRequestException("请求体大小不能超过 10MB")
|
||
|
||
db = context.db
|
||
payload = context.ensure_json_body()
|
||
|
||
# 验证配置版本
|
||
version = payload.get("version")
|
||
if version != "1.0":
|
||
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
||
|
||
# 获取导入选项
|
||
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
|
||
global_models_data = payload.get("global_models", [])
|
||
providers_data = payload.get("providers", [])
|
||
|
||
stats = {
|
||
"global_models": {"created": 0, "updated": 0, "skipped": 0},
|
||
"providers": {"created": 0, "updated": 0, "skipped": 0},
|
||
"endpoints": {"created": 0, "updated": 0, "skipped": 0},
|
||
"keys": {"created": 0, "updated": 0, "skipped": 0},
|
||
"models": {"created": 0, "updated": 0, "skipped": 0},
|
||
"errors": [],
|
||
}
|
||
|
||
try:
|
||
# 导入 GlobalModels
|
||
global_model_map = {} # name -> id 映射
|
||
for gm_data in global_models_data:
|
||
existing = (
|
||
db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first()
|
||
)
|
||
|
||
if existing:
|
||
global_model_map[gm_data["name"]] = existing.id
|
||
if merge_mode == "skip":
|
||
stats["global_models"]["skipped"] += 1
|
||
continue
|
||
elif merge_mode == "error":
|
||
raise InvalidRequestException(
|
||
f"GlobalModel '{gm_data['name']}' 已存在"
|
||
)
|
||
elif merge_mode == "overwrite":
|
||
# 更新现有记录
|
||
existing.display_name = gm_data.get(
|
||
"display_name", existing.display_name
|
||
)
|
||
existing.default_price_per_request = gm_data.get(
|
||
"default_price_per_request"
|
||
)
|
||
existing.default_tiered_pricing = gm_data.get(
|
||
"default_tiered_pricing", existing.default_tiered_pricing
|
||
)
|
||
existing.supported_capabilities = gm_data.get(
|
||
"supported_capabilities"
|
||
)
|
||
existing.config = gm_data.get("config")
|
||
existing.is_active = gm_data.get("is_active", True)
|
||
existing.updated_at = datetime.now(timezone.utc)
|
||
stats["global_models"]["updated"] += 1
|
||
else:
|
||
# 创建新记录
|
||
new_gm = GlobalModel(
|
||
id=str(uuid.uuid4()),
|
||
name=gm_data["name"],
|
||
display_name=gm_data.get("display_name", gm_data["name"]),
|
||
default_price_per_request=gm_data.get("default_price_per_request"),
|
||
default_tiered_pricing=gm_data.get(
|
||
"default_tiered_pricing",
|
||
{"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]},
|
||
),
|
||
supported_capabilities=gm_data.get("supported_capabilities"),
|
||
config=gm_data.get("config"),
|
||
is_active=gm_data.get("is_active", True),
|
||
)
|
||
db.add(new_gm)
|
||
db.flush()
|
||
global_model_map[gm_data["name"]] = new_gm.id
|
||
stats["global_models"]["created"] += 1
|
||
|
||
# 导入 Providers
|
||
for prov_data in providers_data:
|
||
existing_provider = (
|
||
db.query(Provider).filter(Provider.name == prov_data["name"]).first()
|
||
)
|
||
|
||
if existing_provider:
|
||
provider_id = existing_provider.id
|
||
if merge_mode == "skip":
|
||
stats["providers"]["skipped"] += 1
|
||
# 仍然需要处理 endpoints 和 models(如果存在)
|
||
elif merge_mode == "error":
|
||
raise InvalidRequestException(
|
||
f"Provider '{prov_data['name']}' 已存在"
|
||
)
|
||
elif merge_mode == "overwrite":
|
||
# 更新现有记录
|
||
existing_provider.display_name = prov_data.get(
|
||
"display_name", existing_provider.display_name
|
||
)
|
||
existing_provider.description = prov_data.get("description")
|
||
existing_provider.website = prov_data.get("website")
|
||
if prov_data.get("billing_type"):
|
||
existing_provider.billing_type = ProviderBillingType(
|
||
prov_data["billing_type"]
|
||
)
|
||
existing_provider.monthly_quota_usd = prov_data.get(
|
||
"monthly_quota_usd"
|
||
)
|
||
existing_provider.quota_reset_day = prov_data.get(
|
||
"quota_reset_day", 30
|
||
)
|
||
existing_provider.rpm_limit = prov_data.get("rpm_limit")
|
||
existing_provider.provider_priority = prov_data.get(
|
||
"provider_priority", 100
|
||
)
|
||
existing_provider.is_active = prov_data.get("is_active", True)
|
||
existing_provider.rate_limit = prov_data.get("rate_limit")
|
||
existing_provider.concurrent_limit = prov_data.get(
|
||
"concurrent_limit"
|
||
)
|
||
existing_provider.config = prov_data.get("config")
|
||
existing_provider.updated_at = datetime.now(timezone.utc)
|
||
stats["providers"]["updated"] += 1
|
||
else:
|
||
# 创建新 Provider
|
||
billing_type = ProviderBillingType.PAY_AS_YOU_GO
|
||
if prov_data.get("billing_type"):
|
||
billing_type = ProviderBillingType(prov_data["billing_type"])
|
||
|
||
new_provider = Provider(
|
||
id=str(uuid.uuid4()),
|
||
name=prov_data["name"],
|
||
display_name=prov_data.get("display_name", prov_data["name"]),
|
||
description=prov_data.get("description"),
|
||
website=prov_data.get("website"),
|
||
billing_type=billing_type,
|
||
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
|
||
quota_reset_day=prov_data.get("quota_reset_day", 30),
|
||
rpm_limit=prov_data.get("rpm_limit"),
|
||
provider_priority=prov_data.get("provider_priority", 100),
|
||
is_active=prov_data.get("is_active", True),
|
||
rate_limit=prov_data.get("rate_limit"),
|
||
concurrent_limit=prov_data.get("concurrent_limit"),
|
||
config=prov_data.get("config"),
|
||
)
|
||
db.add(new_provider)
|
||
db.flush()
|
||
provider_id = new_provider.id
|
||
stats["providers"]["created"] += 1
|
||
|
||
# 导入 Endpoints
|
||
for ep_data in prov_data.get("endpoints", []):
|
||
existing_ep = (
|
||
db.query(ProviderEndpoint)
|
||
.filter(
|
||
ProviderEndpoint.provider_id == provider_id,
|
||
ProviderEndpoint.api_format == ep_data["api_format"],
|
||
)
|
||
.first()
|
||
)
|
||
|
||
if existing_ep:
|
||
endpoint_id = existing_ep.id
|
||
if merge_mode == "skip":
|
||
stats["endpoints"]["skipped"] += 1
|
||
elif merge_mode == "error":
|
||
raise InvalidRequestException(
|
||
f"Endpoint '{ep_data['api_format']}' 已存在于 Provider '{prov_data['name']}'"
|
||
)
|
||
elif merge_mode == "overwrite":
|
||
existing_ep.base_url = ep_data.get(
|
||
"base_url", existing_ep.base_url
|
||
)
|
||
existing_ep.headers = ep_data.get("headers")
|
||
existing_ep.timeout = ep_data.get("timeout", 300)
|
||
existing_ep.max_retries = ep_data.get("max_retries", 3)
|
||
existing_ep.max_concurrent = ep_data.get("max_concurrent")
|
||
existing_ep.rate_limit = ep_data.get("rate_limit")
|
||
existing_ep.is_active = ep_data.get("is_active", True)
|
||
existing_ep.custom_path = ep_data.get("custom_path")
|
||
existing_ep.config = ep_data.get("config")
|
||
existing_ep.updated_at = datetime.now(timezone.utc)
|
||
stats["endpoints"]["updated"] += 1
|
||
else:
|
||
new_ep = ProviderEndpoint(
|
||
id=str(uuid.uuid4()),
|
||
provider_id=provider_id,
|
||
api_format=ep_data["api_format"],
|
||
base_url=ep_data["base_url"],
|
||
headers=ep_data.get("headers"),
|
||
timeout=ep_data.get("timeout", 300),
|
||
max_retries=ep_data.get("max_retries", 3),
|
||
max_concurrent=ep_data.get("max_concurrent"),
|
||
rate_limit=ep_data.get("rate_limit"),
|
||
is_active=ep_data.get("is_active", True),
|
||
custom_path=ep_data.get("custom_path"),
|
||
config=ep_data.get("config"),
|
||
)
|
||
db.add(new_ep)
|
||
db.flush()
|
||
endpoint_id = new_ep.id
|
||
stats["endpoints"]["created"] += 1
|
||
|
||
# 导入 Keys
|
||
# 获取当前 endpoint 下所有已有的 keys,用于去重
|
||
existing_keys = (
|
||
db.query(ProviderAPIKey)
|
||
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
|
||
.all()
|
||
)
|
||
# 解密已有 keys 用于比对
|
||
existing_key_values = set()
|
||
for ek in existing_keys:
|
||
try:
|
||
decrypted = crypto_service.decrypt(ek.api_key)
|
||
existing_key_values.add(decrypted)
|
||
except Exception:
|
||
pass
|
||
|
||
for key_data in ep_data.get("keys", []):
|
||
if not key_data.get("api_key"):
|
||
stats["errors"].append(
|
||
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
|
||
)
|
||
continue
|
||
|
||
# 检查是否已存在相同的 Key(通过明文比对)
|
||
if key_data["api_key"] in existing_key_values:
|
||
stats["keys"]["skipped"] += 1
|
||
continue
|
||
|
||
encrypted_key = crypto_service.encrypt(key_data["api_key"])
|
||
|
||
new_key = ProviderAPIKey(
|
||
id=str(uuid.uuid4()),
|
||
endpoint_id=endpoint_id,
|
||
api_key=encrypted_key,
|
||
name=key_data.get("name"),
|
||
note=key_data.get("note"),
|
||
rate_multiplier=key_data.get("rate_multiplier", 1.0),
|
||
internal_priority=key_data.get("internal_priority", 100),
|
||
global_priority=key_data.get("global_priority"),
|
||
max_concurrent=key_data.get("max_concurrent"),
|
||
rate_limit=key_data.get("rate_limit"),
|
||
daily_limit=key_data.get("daily_limit"),
|
||
monthly_limit=key_data.get("monthly_limit"),
|
||
allowed_models=key_data.get("allowed_models"),
|
||
capabilities=key_data.get("capabilities"),
|
||
is_active=key_data.get("is_active", True),
|
||
)
|
||
db.add(new_key)
|
||
# 添加到已有集合,防止同一批导入中重复
|
||
existing_key_values.add(key_data["api_key"])
|
||
stats["keys"]["created"] += 1
|
||
|
||
# 导入 Models
|
||
for model_data in prov_data.get("models", []):
|
||
global_model_name = model_data.get("global_model_name")
|
||
if not global_model_name:
|
||
stats["errors"].append(
|
||
f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})"
|
||
)
|
||
continue
|
||
|
||
global_model_id = global_model_map.get(global_model_name)
|
||
if not global_model_id:
|
||
# 尝试从数据库查找
|
||
existing_gm = (
|
||
db.query(GlobalModel)
|
||
.filter(GlobalModel.name == global_model_name)
|
||
.first()
|
||
)
|
||
if existing_gm:
|
||
global_model_id = existing_gm.id
|
||
else:
|
||
stats["errors"].append(
|
||
f"GlobalModel '{global_model_name}' 不存在,跳过模型"
|
||
)
|
||
continue
|
||
|
||
existing_model = (
|
||
db.query(Model)
|
||
.filter(
|
||
Model.provider_id == provider_id,
|
||
Model.provider_model_name == model_data["provider_model_name"],
|
||
)
|
||
.first()
|
||
)
|
||
|
||
if existing_model:
|
||
if merge_mode == "skip":
|
||
stats["models"]["skipped"] += 1
|
||
elif merge_mode == "error":
|
||
raise InvalidRequestException(
|
||
f"Model '{model_data['provider_model_name']}' 已存在于 Provider '{prov_data['name']}'"
|
||
)
|
||
elif merge_mode == "overwrite":
|
||
existing_model.global_model_id = global_model_id
|
||
existing_model.provider_model_mappings = model_data.get(
|
||
"provider_model_mappings"
|
||
)
|
||
existing_model.price_per_request = model_data.get(
|
||
"price_per_request"
|
||
)
|
||
existing_model.tiered_pricing = model_data.get(
|
||
"tiered_pricing"
|
||
)
|
||
existing_model.supports_vision = model_data.get(
|
||
"supports_vision"
|
||
)
|
||
existing_model.supports_function_calling = model_data.get(
|
||
"supports_function_calling"
|
||
)
|
||
existing_model.supports_streaming = model_data.get(
|
||
"supports_streaming"
|
||
)
|
||
existing_model.supports_extended_thinking = model_data.get(
|
||
"supports_extended_thinking"
|
||
)
|
||
existing_model.supports_image_generation = model_data.get(
|
||
"supports_image_generation"
|
||
)
|
||
existing_model.is_active = model_data.get("is_active", True)
|
||
existing_model.config = model_data.get("config")
|
||
existing_model.updated_at = datetime.now(timezone.utc)
|
||
stats["models"]["updated"] += 1
|
||
else:
|
||
new_model = Model(
|
||
id=str(uuid.uuid4()),
|
||
provider_id=provider_id,
|
||
global_model_id=global_model_id,
|
||
provider_model_name=model_data["provider_model_name"],
|
||
provider_model_mappings=model_data.get(
|
||
"provider_model_mappings"
|
||
),
|
||
price_per_request=model_data.get("price_per_request"),
|
||
tiered_pricing=model_data.get("tiered_pricing"),
|
||
supports_vision=model_data.get("supports_vision"),
|
||
supports_function_calling=model_data.get(
|
||
"supports_function_calling"
|
||
),
|
||
supports_streaming=model_data.get("supports_streaming"),
|
||
supports_extended_thinking=model_data.get(
|
||
"supports_extended_thinking"
|
||
),
|
||
supports_image_generation=model_data.get(
|
||
"supports_image_generation"
|
||
),
|
||
is_active=model_data.get("is_active", True),
|
||
config=model_data.get("config"),
|
||
)
|
||
db.add(new_model)
|
||
stats["models"]["created"] += 1
|
||
|
||
db.commit()
|
||
|
||
# 失效缓存
|
||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||
|
||
cache_service = get_cache_invalidation_service()
|
||
cache_service.clear_all_caches()
|
||
|
||
return {
|
||
"message": "配置导入成功",
|
||
"stats": stats,
|
||
}
|
||
|
||
except InvalidRequestException:
|
||
db.rollback()
|
||
raise
|
||
except Exception as e:
|
||
db.rollback()
|
||
raise InvalidRequestException(f"导入失败: {str(e)}")
|
||
|
||
|
||
class AdminExportUsersAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
"""导出用户数据(保留加密数据,排除管理员)"""
|
||
from datetime import datetime, timezone
|
||
|
||
from src.core.enums import UserRole
|
||
from src.models.database import ApiKey, User
|
||
|
||
db = context.db
|
||
|
||
# 导出 Users(排除管理员)
|
||
users = db.query(User).filter(
|
||
User.is_deleted.is_(False),
|
||
User.role != UserRole.ADMIN
|
||
).all()
|
||
users_data = []
|
||
for user in users:
|
||
# 导出用户的 API Keys(保留加密数据)
|
||
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
|
||
api_keys_data = []
|
||
for key in api_keys:
|
||
api_keys_data.append(
|
||
{
|
||
"key_hash": key.key_hash,
|
||
"key_encrypted": key.key_encrypted,
|
||
"name": key.name,
|
||
"is_standalone": key.is_standalone,
|
||
"balance_used_usd": key.balance_used_usd,
|
||
"current_balance_usd": key.current_balance_usd,
|
||
"allowed_providers": key.allowed_providers,
|
||
"allowed_endpoints": key.allowed_endpoints,
|
||
"allowed_api_formats": key.allowed_api_formats,
|
||
"allowed_models": key.allowed_models,
|
||
"rate_limit": key.rate_limit,
|
||
"concurrent_limit": key.concurrent_limit,
|
||
"force_capabilities": key.force_capabilities,
|
||
"is_active": key.is_active,
|
||
"auto_delete_on_expiry": key.auto_delete_on_expiry,
|
||
"total_requests": key.total_requests,
|
||
"total_cost_usd": key.total_cost_usd,
|
||
}
|
||
)
|
||
|
||
users_data.append(
|
||
{
|
||
"email": user.email,
|
||
"username": user.username,
|
||
"password_hash": user.password_hash,
|
||
"role": user.role.value if user.role else "user",
|
||
"allowed_providers": user.allowed_providers,
|
||
"allowed_endpoints": user.allowed_endpoints,
|
||
"allowed_models": user.allowed_models,
|
||
"model_capability_settings": user.model_capability_settings,
|
||
"quota_usd": user.quota_usd,
|
||
"used_usd": user.used_usd,
|
||
"total_usd": user.total_usd,
|
||
"is_active": user.is_active,
|
||
"api_keys": api_keys_data,
|
||
}
|
||
)
|
||
|
||
return {
|
||
"version": "1.0",
|
||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||
"users": users_data,
|
||
}
|
||
|
||
|
||
class AdminImportUsersAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
"""导入用户数据"""
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
|
||
from src.core.enums import UserRole
|
||
from src.models.database import ApiKey, User
|
||
|
||
# 检查请求体大小
|
||
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
|
||
raise InvalidRequestException("请求体大小不能超过 10MB")
|
||
|
||
db = context.db
|
||
payload = context.ensure_json_body()
|
||
|
||
# 验证配置版本
|
||
version = payload.get("version")
|
||
if version != "1.0":
|
||
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
||
|
||
# 获取导入选项
|
||
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
|
||
users_data = payload.get("users", [])
|
||
|
||
stats = {
|
||
"users": {"created": 0, "updated": 0, "skipped": 0},
|
||
"api_keys": {"created": 0, "skipped": 0},
|
||
"errors": [],
|
||
}
|
||
|
||
try:
|
||
for user_data in users_data:
|
||
# 跳过管理员角色的导入(不区分大小写)
|
||
role_str = str(user_data.get("role", "")).lower()
|
||
if role_str == "admin":
|
||
stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}")
|
||
stats["users"]["skipped"] += 1
|
||
continue
|
||
|
||
existing_user = (
|
||
db.query(User).filter(User.email == user_data["email"]).first()
|
||
)
|
||
|
||
if existing_user:
|
||
user_id = existing_user.id
|
||
if merge_mode == "skip":
|
||
stats["users"]["skipped"] += 1
|
||
elif merge_mode == "error":
|
||
raise InvalidRequestException(
|
||
f"用户 '{user_data['email']}' 已存在"
|
||
)
|
||
elif merge_mode == "overwrite":
|
||
# 更新现有用户
|
||
existing_user.username = user_data.get(
|
||
"username", existing_user.username
|
||
)
|
||
if user_data.get("password_hash"):
|
||
existing_user.password_hash = user_data["password_hash"]
|
||
if user_data.get("role"):
|
||
existing_user.role = UserRole(user_data["role"])
|
||
existing_user.allowed_providers = user_data.get("allowed_providers")
|
||
existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
|
||
existing_user.allowed_models = user_data.get("allowed_models")
|
||
existing_user.model_capability_settings = user_data.get(
|
||
"model_capability_settings"
|
||
)
|
||
existing_user.quota_usd = user_data.get("quota_usd")
|
||
existing_user.used_usd = user_data.get("used_usd", 0.0)
|
||
existing_user.total_usd = user_data.get("total_usd", 0.0)
|
||
existing_user.is_active = user_data.get("is_active", True)
|
||
existing_user.updated_at = datetime.now(timezone.utc)
|
||
stats["users"]["updated"] += 1
|
||
else:
|
||
# 创建新用户
|
||
role = UserRole.USER
|
||
if user_data.get("role"):
|
||
role = UserRole(user_data["role"])
|
||
|
||
new_user = User(
|
||
id=str(uuid.uuid4()),
|
||
email=user_data["email"],
|
||
username=user_data.get("username", user_data["email"].split("@")[0]),
|
||
password_hash=user_data.get("password_hash", ""),
|
||
role=role,
|
||
allowed_providers=user_data.get("allowed_providers"),
|
||
allowed_endpoints=user_data.get("allowed_endpoints"),
|
||
allowed_models=user_data.get("allowed_models"),
|
||
model_capability_settings=user_data.get("model_capability_settings"),
|
||
quota_usd=user_data.get("quota_usd"),
|
||
used_usd=user_data.get("used_usd", 0.0),
|
||
total_usd=user_data.get("total_usd", 0.0),
|
||
is_active=user_data.get("is_active", True),
|
||
)
|
||
db.add(new_user)
|
||
db.flush()
|
||
user_id = new_user.id
|
||
stats["users"]["created"] += 1
|
||
|
||
# 导入 API Keys
|
||
for key_data in user_data.get("api_keys", []):
|
||
# 检查是否已存在相同的 key_hash
|
||
if key_data.get("key_hash"):
|
||
existing_key = (
|
||
db.query(ApiKey)
|
||
.filter(ApiKey.key_hash == key_data["key_hash"])
|
||
.first()
|
||
)
|
||
if existing_key:
|
||
stats["api_keys"]["skipped"] += 1
|
||
continue
|
||
|
||
new_key = ApiKey(
|
||
id=str(uuid.uuid4()),
|
||
user_id=user_id,
|
||
key_hash=key_data.get("key_hash", ""),
|
||
key_encrypted=key_data.get("key_encrypted"),
|
||
name=key_data.get("name"),
|
||
is_standalone=key_data.get("is_standalone", False),
|
||
balance_used_usd=key_data.get("balance_used_usd", 0.0),
|
||
current_balance_usd=key_data.get("current_balance_usd"),
|
||
allowed_providers=key_data.get("allowed_providers"),
|
||
allowed_endpoints=key_data.get("allowed_endpoints"),
|
||
allowed_api_formats=key_data.get("allowed_api_formats"),
|
||
allowed_models=key_data.get("allowed_models"),
|
||
rate_limit=key_data.get("rate_limit", 100),
|
||
concurrent_limit=key_data.get("concurrent_limit", 5),
|
||
force_capabilities=key_data.get("force_capabilities"),
|
||
is_active=key_data.get("is_active", True),
|
||
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
|
||
total_requests=key_data.get("total_requests", 0),
|
||
total_cost_usd=key_data.get("total_cost_usd", 0.0),
|
||
)
|
||
db.add(new_key)
|
||
stats["api_keys"]["created"] += 1
|
||
|
||
db.commit()
|
||
|
||
return {
|
||
"message": "用户数据导入成功",
|
||
"stats": stats,
|
||
}
|
||
|
||
except InvalidRequestException:
|
||
db.rollback()
|
||
raise
|
||
except Exception as e:
|
||
db.rollback()
|
||
raise InvalidRequestException(f"导入失败: {str(e)}")
|
||
|
||
|
||
class AdminTestSmtpAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
"""测试 SMTP 连接"""
|
||
from src.services.system.config import ConfigService
|
||
from src.services.verification.email_sender import EmailSenderService
|
||
|
||
db = context.db
|
||
payload = context.ensure_json_body() or {}
|
||
|
||
# 前端可传入未保存的配置,优先使用前端值,否则回退数据库
|
||
config = {
|
||
"smtp_host": payload.get("smtp_host") or ConfigService.get_config(db, "smtp_host"),
|
||
"smtp_port": payload.get("smtp_port") or ConfigService.get_config(db, "smtp_port", default=587),
|
||
"smtp_user": payload.get("smtp_user") or ConfigService.get_config(db, "smtp_user"),
|
||
"smtp_password": payload.get("smtp_password") or ConfigService.get_config(db, "smtp_password"),
|
||
"smtp_use_tls": payload.get("smtp_use_tls")
|
||
if payload.get("smtp_use_tls") is not None
|
||
else ConfigService.get_config(db, "smtp_use_tls", default=True),
|
||
"smtp_use_ssl": payload.get("smtp_use_ssl")
|
||
if payload.get("smtp_use_ssl") is not None
|
||
else ConfigService.get_config(db, "smtp_use_ssl", default=False),
|
||
"smtp_from_email": payload.get("smtp_from_email")
|
||
or ConfigService.get_config(db, "smtp_from_email"),
|
||
"smtp_from_name": payload.get("smtp_from_name")
|
||
or ConfigService.get_config(db, "smtp_from_name", default="Aether"),
|
||
}
|
||
|
||
# 验证必要配置
|
||
missing_fields = [
|
||
field for field in ["smtp_host", "smtp_user", "smtp_password", "smtp_from_email"] if not config.get(field)
|
||
]
|
||
if missing_fields:
|
||
return {
|
||
"success": False,
|
||
"message": f"SMTP 配置不完整,请检查 {', '.join(missing_fields)}"
|
||
}
|
||
|
||
# 测试连接
|
||
try:
|
||
success, error_msg = await EmailSenderService.test_smtp_connection(
|
||
db=db, override_config=config
|
||
)
|
||
|
||
if success:
|
||
return {
|
||
"success": True,
|
||
"message": "SMTP 连接测试成功"
|
||
}
|
||
else:
|
||
return {
|
||
"success": False,
|
||
"message": f"SMTP 连接测试失败: {error_msg}"
|
||
}
|
||
except Exception as e:
|
||
return {
|
||
"success": False,
|
||
"message": f"SMTP 连接测试失败: {str(e)}"
|
||
}
|