"""系统设置API端点。""" from dataclasses import dataclass from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import ValidationError from sqlalchemy.orm import Session from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.pipeline import ApiRequestPipeline from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error from src.database import get_db from src.models.api import SystemSettingsRequest, SystemSettingsResponse from src.models.database import ApiKey, Provider, Usage, User from src.services.system.config import SystemConfigService router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"]) pipeline = ApiRequestPipeline() @router.get("/settings") async def get_system_settings(request: Request, db: Session = Depends(get_db)): """获取系统设置(管理员)""" adapter = AdminGetSystemSettingsAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.put("/settings") async def update_system_settings(http_request: Request, db: Session = Depends(get_db)): """更新系统设置(管理员)""" adapter = AdminUpdateSystemSettingsAdapter() return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode) @router.get("/configs") async def get_all_system_configs(request: Request, db: Session = Depends(get_db)): """获取所有系统配置(管理员)""" adapter = AdminGetAllConfigsAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/configs/{key}") async def get_system_config(key: str, request: Request, db: Session = Depends(get_db)): """获取特定系统配置(管理员)""" adapter = AdminGetSystemConfigAdapter(key=key) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.put("/configs/{key}") async def set_system_config( key: str, request: Request, db: Session = Depends(get_db), ): """设置系统配置(管理员)""" adapter = AdminSetSystemConfigAdapter(key=key) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.delete("/configs/{key}") async def delete_system_config(key: str, request: Request, db: Session = Depends(get_db)): """删除系统配置(管理员)""" adapter = AdminDeleteSystemConfigAdapter(key=key) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/stats") async def get_system_stats(request: Request, db: Session = Depends(get_db)): adapter = AdminSystemStatsAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.post("/cleanup") async def trigger_cleanup(request: Request, db: Session = Depends(get_db)): """Manually trigger usage record cleanup task""" adapter = AdminTriggerCleanupAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/api-formats") async def get_api_formats(request: Request, db: Session = Depends(get_db)): """获取所有可用的API格式列表""" adapter = AdminGetApiFormatsAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/config/export") async def export_config(request: Request, db: Session = Depends(get_db)): """导出提供商和模型配置(管理员)""" adapter = AdminExportConfigAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.post("/config/import") async def import_config(request: Request, db: Session = Depends(get_db)): """导入提供商和模型配置(管理员)""" adapter = AdminImportConfigAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/users/export") async def export_users(request: Request, db: Session = Depends(get_db)): """导出用户数据(管理员)""" adapter = AdminExportUsersAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.post("/users/import") async def import_users(request: Request, db: Session = Depends(get_db)): """导入用户数据(管理员)""" adapter = AdminImportUsersAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.post("/smtp/test") async def test_smtp(request: Request, db: Session = Depends(get_db)): """测试 SMTP 连接(管理员)""" adapter = AdminTestSmtpAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) # -------- 系统设置适配器 -------- class AdminGetSystemSettingsAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] db = context.db default_provider = SystemConfigService.get_default_provider(db) default_model = SystemConfigService.get_config(db, "default_model") enable_usage_tracking = ( SystemConfigService.get_config(db, "enable_usage_tracking", "true") == "true" ) return SystemSettingsResponse( default_provider=default_provider, default_model=default_model, enable_usage_tracking=enable_usage_tracking, ) class AdminUpdateSystemSettingsAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] db = context.db payload = context.ensure_json_body() try: settings_request = SystemSettingsRequest.model_validate(payload) except ValidationError as e: errors = e.errors() if errors: raise InvalidRequestException(translate_pydantic_error(errors[0])) raise InvalidRequestException("请求数据验证失败") if settings_request.default_provider is not None: provider = ( db.query(Provider) .filter( Provider.name == settings_request.default_provider, Provider.is_active.is_(True), ) .first() ) if not provider and settings_request.default_provider != "": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"提供商 '{settings_request.default_provider}' 不存在或未启用", ) if settings_request.default_provider: SystemConfigService.set_default_provider(db, settings_request.default_provider) else: SystemConfigService.delete_config(db, "default_provider") if settings_request.default_model is not None: if settings_request.default_model: SystemConfigService.set_config(db, "default_model", settings_request.default_model) else: SystemConfigService.delete_config(db, "default_model") if settings_request.enable_usage_tracking is not None: SystemConfigService.set_config( db, "enable_usage_tracking", str(settings_request.enable_usage_tracking).lower(), ) return {"message": "系统设置更新成功"} class AdminGetAllConfigsAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] return SystemConfigService.get_all_configs(context.db) @dataclass class AdminGetSystemConfigAdapter(AdminApiAdapter): key: str async def handle(self, context): # type: ignore[override] value = SystemConfigService.get_config(context.db, self.key) if value is None: raise NotFoundException(f"配置项 '{self.key}' 不存在") return {"key": self.key, "value": value} @dataclass class AdminSetSystemConfigAdapter(AdminApiAdapter): key: str async def handle(self, context): # type: ignore[override] payload = context.ensure_json_body() config = SystemConfigService.set_config( context.db, self.key, payload.get("value"), payload.get("description"), ) return { "key": config.key, "value": config.value, "description": config.description, "updated_at": config.updated_at.isoformat(), } @dataclass class AdminDeleteSystemConfigAdapter(AdminApiAdapter): key: str async def handle(self, context): # type: ignore[override] deleted = SystemConfigService.delete_config(context.db, self.key) if not deleted: raise NotFoundException(f"配置项 '{self.key}' 不存在") return {"message": f"配置项 '{self.key}' 已删除"} class AdminSystemStatsAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] db = context.db total_users = db.query(User).count() active_users = db.query(User).filter(User.is_active.is_(True)).count() total_providers = db.query(Provider).count() active_providers = db.query(Provider).filter(Provider.is_active.is_(True)).count() total_api_keys = db.query(ApiKey).count() total_requests = db.query(Usage).count() return { "users": {"total": total_users, "active": active_users}, "providers": {"total": total_providers, "active": active_providers}, "api_keys": total_api_keys, "requests": total_requests, } class AdminTriggerCleanupAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] """手动触发清理任务""" from datetime import datetime, timedelta, timezone from sqlalchemy import func from src.services.system.cleanup_scheduler import get_cleanup_scheduler db = context.db # 获取清理前的统计信息 total_before = db.query(Usage).count() with_body_before = ( db.query(Usage) .filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None))) .count() ) with_headers_before = ( db.query(Usage) .filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None))) .count() ) # 触发清理 cleanup_scheduler = get_cleanup_scheduler() await cleanup_scheduler._perform_cleanup() # 获取清理后的统计信息 total_after = db.query(Usage).count() with_body_after = ( db.query(Usage) .filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None))) .count() ) with_headers_after = ( db.query(Usage) .filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None))) .count() ) return { "message": "清理任务执行完成", "stats": { "total_records": { "before": total_before, "after": total_after, "deleted": total_before - total_after, }, "body_fields": { "before": with_body_before, "after": with_body_after, "cleaned": with_body_before - with_body_after, }, "header_fields": { "before": with_headers_before, "after": with_headers_after, "cleaned": with_headers_before - with_headers_after, }, }, "timestamp": datetime.now(timezone.utc).isoformat(), } class AdminGetApiFormatsAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] """获取所有可用的API格式""" from src.core.api_format_metadata import API_FORMAT_DEFINITIONS from src.core.enums import APIFormat _ = context # 参数保留以符合接口规范 formats = [] for api_format in APIFormat: definition = API_FORMAT_DEFINITIONS.get(api_format) formats.append( { "value": api_format.value, "label": api_format.value, "default_path": definition.default_path if definition else "/", "aliases": list(definition.aliases) if definition else [], } ) return {"formats": formats} class AdminExportConfigAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] """导出提供商和模型配置(解密数据)""" from datetime import datetime, timezone from src.core.crypto import crypto_service from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint db = context.db # 导出 GlobalModels global_models = db.query(GlobalModel).all() global_models_data = [] for gm in global_models: global_models_data.append( { "name": gm.name, "display_name": gm.display_name, "default_price_per_request": gm.default_price_per_request, "default_tiered_pricing": gm.default_tiered_pricing, "supported_capabilities": gm.supported_capabilities, "config": gm.config, "is_active": gm.is_active, } ) # 导出 Providers 及其关联数据 providers = db.query(Provider).all() providers_data = [] for provider in providers: # 导出 Endpoints endpoints = ( db.query(ProviderEndpoint) .filter(ProviderEndpoint.provider_id == provider.id) .all() ) endpoints_data = [] for ep in endpoints: # 导出 Endpoint Keys keys = ( db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all() ) keys_data = [] for key in keys: # 解密 API Key try: decrypted_key = crypto_service.decrypt(key.api_key) except Exception: decrypted_key = "" keys_data.append( { "api_key": decrypted_key, "name": key.name, "note": key.note, "rate_multiplier": key.rate_multiplier, "internal_priority": key.internal_priority, "global_priority": key.global_priority, "max_concurrent": key.max_concurrent, "rate_limit": key.rate_limit, "daily_limit": key.daily_limit, "monthly_limit": key.monthly_limit, "allowed_models": key.allowed_models, "capabilities": key.capabilities, "is_active": key.is_active, } ) endpoints_data.append( { "api_format": ep.api_format, "base_url": ep.base_url, "headers": ep.headers, "timeout": ep.timeout, "max_retries": ep.max_retries, "max_concurrent": ep.max_concurrent, "rate_limit": ep.rate_limit, "is_active": ep.is_active, "custom_path": ep.custom_path, "config": ep.config, "keys": keys_data, } ) # 导出 Provider Models models = db.query(Model).filter(Model.provider_id == provider.id).all() models_data = [] for model in models: # 获取关联的 GlobalModel 名称 global_model = ( db.query(GlobalModel).filter(GlobalModel.id == model.global_model_id).first() ) models_data.append( { "global_model_name": global_model.name if global_model else None, "provider_model_name": model.provider_model_name, "provider_model_mappings": model.provider_model_mappings, "price_per_request": model.price_per_request, "tiered_pricing": model.tiered_pricing, "supports_vision": model.supports_vision, "supports_function_calling": model.supports_function_calling, "supports_streaming": model.supports_streaming, "supports_extended_thinking": model.supports_extended_thinking, "supports_image_generation": model.supports_image_generation, "is_active": model.is_active, "config": model.config, } ) providers_data.append( { "name": provider.name, "display_name": provider.display_name, "description": provider.description, "website": provider.website, "billing_type": provider.billing_type.value if provider.billing_type else None, "monthly_quota_usd": provider.monthly_quota_usd, "quota_reset_day": provider.quota_reset_day, "rpm_limit": provider.rpm_limit, "provider_priority": provider.provider_priority, "is_active": provider.is_active, "rate_limit": provider.rate_limit, "concurrent_limit": provider.concurrent_limit, "config": provider.config, "endpoints": endpoints_data, "models": models_data, } ) return { "version": "1.0", "exported_at": datetime.now(timezone.utc).isoformat(), "global_models": global_models_data, "providers": providers_data, } MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB class AdminImportConfigAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] """导入提供商和模型配置""" import uuid from datetime import datetime, timezone from src.core.crypto import crypto_service from src.core.enums import ProviderBillingType from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint # 检查请求体大小 if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE: raise InvalidRequestException("请求体大小不能超过 10MB") db = context.db payload = context.ensure_json_body() # 验证配置版本 version = payload.get("version") if version != "1.0": raise InvalidRequestException(f"不支持的配置版本: {version}") # 获取导入选项 merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error global_models_data = payload.get("global_models", []) providers_data = payload.get("providers", []) stats = { "global_models": {"created": 0, "updated": 0, "skipped": 0}, "providers": {"created": 0, "updated": 0, "skipped": 0}, "endpoints": {"created": 0, "updated": 0, "skipped": 0}, "keys": {"created": 0, "updated": 0, "skipped": 0}, "models": {"created": 0, "updated": 0, "skipped": 0}, "errors": [], } try: # 导入 GlobalModels global_model_map = {} # name -> id 映射 for gm_data in global_models_data: existing = ( db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first() ) if existing: global_model_map[gm_data["name"]] = existing.id if merge_mode == "skip": stats["global_models"]["skipped"] += 1 continue elif merge_mode == "error": raise InvalidRequestException( f"GlobalModel '{gm_data['name']}' 已存在" ) elif merge_mode == "overwrite": # 更新现有记录 existing.display_name = gm_data.get( "display_name", existing.display_name ) existing.default_price_per_request = gm_data.get( "default_price_per_request" ) existing.default_tiered_pricing = gm_data.get( "default_tiered_pricing", existing.default_tiered_pricing ) existing.supported_capabilities = gm_data.get( "supported_capabilities" ) existing.config = gm_data.get("config") existing.is_active = gm_data.get("is_active", True) existing.updated_at = datetime.now(timezone.utc) stats["global_models"]["updated"] += 1 else: # 创建新记录 new_gm = GlobalModel( id=str(uuid.uuid4()), name=gm_data["name"], display_name=gm_data.get("display_name", gm_data["name"]), default_price_per_request=gm_data.get("default_price_per_request"), default_tiered_pricing=gm_data.get( "default_tiered_pricing", {"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]}, ), supported_capabilities=gm_data.get("supported_capabilities"), config=gm_data.get("config"), is_active=gm_data.get("is_active", True), ) db.add(new_gm) db.flush() global_model_map[gm_data["name"]] = new_gm.id stats["global_models"]["created"] += 1 # 导入 Providers for prov_data in providers_data: existing_provider = ( db.query(Provider).filter(Provider.name == prov_data["name"]).first() ) if existing_provider: provider_id = existing_provider.id if merge_mode == "skip": stats["providers"]["skipped"] += 1 # 仍然需要处理 endpoints 和 models(如果存在) elif merge_mode == "error": raise InvalidRequestException( f"Provider '{prov_data['name']}' 已存在" ) elif merge_mode == "overwrite": # 更新现有记录 existing_provider.display_name = prov_data.get( "display_name", existing_provider.display_name ) existing_provider.description = prov_data.get("description") existing_provider.website = prov_data.get("website") if prov_data.get("billing_type"): existing_provider.billing_type = ProviderBillingType( prov_data["billing_type"] ) existing_provider.monthly_quota_usd = prov_data.get( "monthly_quota_usd" ) existing_provider.quota_reset_day = prov_data.get( "quota_reset_day", 30 ) existing_provider.rpm_limit = prov_data.get("rpm_limit") existing_provider.provider_priority = prov_data.get( "provider_priority", 100 ) existing_provider.is_active = prov_data.get("is_active", True) existing_provider.rate_limit = prov_data.get("rate_limit") existing_provider.concurrent_limit = prov_data.get( "concurrent_limit" ) existing_provider.config = prov_data.get("config") existing_provider.updated_at = datetime.now(timezone.utc) stats["providers"]["updated"] += 1 else: # 创建新 Provider billing_type = ProviderBillingType.PAY_AS_YOU_GO if prov_data.get("billing_type"): billing_type = ProviderBillingType(prov_data["billing_type"]) new_provider = Provider( id=str(uuid.uuid4()), name=prov_data["name"], display_name=prov_data.get("display_name", prov_data["name"]), description=prov_data.get("description"), website=prov_data.get("website"), billing_type=billing_type, monthly_quota_usd=prov_data.get("monthly_quota_usd"), quota_reset_day=prov_data.get("quota_reset_day", 30), rpm_limit=prov_data.get("rpm_limit"), provider_priority=prov_data.get("provider_priority", 100), is_active=prov_data.get("is_active", True), rate_limit=prov_data.get("rate_limit"), concurrent_limit=prov_data.get("concurrent_limit"), config=prov_data.get("config"), ) db.add(new_provider) db.flush() provider_id = new_provider.id stats["providers"]["created"] += 1 # 导入 Endpoints for ep_data in prov_data.get("endpoints", []): existing_ep = ( db.query(ProviderEndpoint) .filter( ProviderEndpoint.provider_id == provider_id, ProviderEndpoint.api_format == ep_data["api_format"], ) .first() ) if existing_ep: endpoint_id = existing_ep.id if merge_mode == "skip": stats["endpoints"]["skipped"] += 1 elif merge_mode == "error": raise InvalidRequestException( f"Endpoint '{ep_data['api_format']}' 已存在于 Provider '{prov_data['name']}'" ) elif merge_mode == "overwrite": existing_ep.base_url = ep_data.get( "base_url", existing_ep.base_url ) existing_ep.headers = ep_data.get("headers") existing_ep.timeout = ep_data.get("timeout", 300) existing_ep.max_retries = ep_data.get("max_retries", 3) existing_ep.max_concurrent = ep_data.get("max_concurrent") existing_ep.rate_limit = ep_data.get("rate_limit") existing_ep.is_active = ep_data.get("is_active", True) existing_ep.custom_path = ep_data.get("custom_path") existing_ep.config = ep_data.get("config") existing_ep.updated_at = datetime.now(timezone.utc) stats["endpoints"]["updated"] += 1 else: new_ep = ProviderEndpoint( id=str(uuid.uuid4()), provider_id=provider_id, api_format=ep_data["api_format"], base_url=ep_data["base_url"], headers=ep_data.get("headers"), timeout=ep_data.get("timeout", 300), max_retries=ep_data.get("max_retries", 3), max_concurrent=ep_data.get("max_concurrent"), rate_limit=ep_data.get("rate_limit"), is_active=ep_data.get("is_active", True), custom_path=ep_data.get("custom_path"), config=ep_data.get("config"), ) db.add(new_ep) db.flush() endpoint_id = new_ep.id stats["endpoints"]["created"] += 1 # 导入 Keys # 获取当前 endpoint 下所有已有的 keys,用于去重 existing_keys = ( db.query(ProviderAPIKey) .filter(ProviderAPIKey.endpoint_id == endpoint_id) .all() ) # 解密已有 keys 用于比对 existing_key_values = set() for ek in existing_keys: try: decrypted = crypto_service.decrypt(ek.api_key) existing_key_values.add(decrypted) except Exception: pass for key_data in ep_data.get("keys", []): if not key_data.get("api_key"): stats["errors"].append( f"跳过空 API Key (Endpoint: {ep_data['api_format']})" ) continue # 检查是否已存在相同的 Key(通过明文比对) if key_data["api_key"] in existing_key_values: stats["keys"]["skipped"] += 1 continue encrypted_key = crypto_service.encrypt(key_data["api_key"]) new_key = ProviderAPIKey( id=str(uuid.uuid4()), endpoint_id=endpoint_id, api_key=encrypted_key, name=key_data.get("name"), note=key_data.get("note"), rate_multiplier=key_data.get("rate_multiplier", 1.0), internal_priority=key_data.get("internal_priority", 100), global_priority=key_data.get("global_priority"), max_concurrent=key_data.get("max_concurrent"), rate_limit=key_data.get("rate_limit"), daily_limit=key_data.get("daily_limit"), monthly_limit=key_data.get("monthly_limit"), allowed_models=key_data.get("allowed_models"), capabilities=key_data.get("capabilities"), is_active=key_data.get("is_active", True), ) db.add(new_key) # 添加到已有集合,防止同一批导入中重复 existing_key_values.add(key_data["api_key"]) stats["keys"]["created"] += 1 # 导入 Models for model_data in prov_data.get("models", []): global_model_name = model_data.get("global_model_name") if not global_model_name: stats["errors"].append( f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})" ) continue global_model_id = global_model_map.get(global_model_name) if not global_model_id: # 尝试从数据库查找 existing_gm = ( db.query(GlobalModel) .filter(GlobalModel.name == global_model_name) .first() ) if existing_gm: global_model_id = existing_gm.id else: stats["errors"].append( f"GlobalModel '{global_model_name}' 不存在,跳过模型" ) continue existing_model = ( db.query(Model) .filter( Model.provider_id == provider_id, Model.provider_model_name == model_data["provider_model_name"], ) .first() ) if existing_model: if merge_mode == "skip": stats["models"]["skipped"] += 1 elif merge_mode == "error": raise InvalidRequestException( f"Model '{model_data['provider_model_name']}' 已存在于 Provider '{prov_data['name']}'" ) elif merge_mode == "overwrite": existing_model.global_model_id = global_model_id existing_model.provider_model_mappings = model_data.get( "provider_model_mappings" ) existing_model.price_per_request = model_data.get( "price_per_request" ) existing_model.tiered_pricing = model_data.get( "tiered_pricing" ) existing_model.supports_vision = model_data.get( "supports_vision" ) existing_model.supports_function_calling = model_data.get( "supports_function_calling" ) existing_model.supports_streaming = model_data.get( "supports_streaming" ) existing_model.supports_extended_thinking = model_data.get( "supports_extended_thinking" ) existing_model.supports_image_generation = model_data.get( "supports_image_generation" ) existing_model.is_active = model_data.get("is_active", True) existing_model.config = model_data.get("config") existing_model.updated_at = datetime.now(timezone.utc) stats["models"]["updated"] += 1 else: new_model = Model( id=str(uuid.uuid4()), provider_id=provider_id, global_model_id=global_model_id, provider_model_name=model_data["provider_model_name"], provider_model_mappings=model_data.get( "provider_model_mappings" ), price_per_request=model_data.get("price_per_request"), tiered_pricing=model_data.get("tiered_pricing"), supports_vision=model_data.get("supports_vision"), supports_function_calling=model_data.get( "supports_function_calling" ), supports_streaming=model_data.get("supports_streaming"), supports_extended_thinking=model_data.get( "supports_extended_thinking" ), supports_image_generation=model_data.get( "supports_image_generation" ), is_active=model_data.get("is_active", True), config=model_data.get("config"), ) db.add(new_model) stats["models"]["created"] += 1 db.commit() # 失效缓存 from src.services.cache.invalidation import get_cache_invalidation_service cache_service = get_cache_invalidation_service() cache_service.clear_all_caches() return { "message": "配置导入成功", "stats": stats, } except InvalidRequestException: db.rollback() raise except Exception as e: db.rollback() raise InvalidRequestException(f"导入失败: {str(e)}") class AdminExportUsersAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] """导出用户数据(保留加密数据,排除管理员)""" from datetime import datetime, timezone from src.core.enums import UserRole from src.models.database import ApiKey, User db = context.db # 导出 Users(排除管理员) users = db.query(User).filter( User.is_deleted.is_(False), User.role != UserRole.ADMIN ).all() users_data = [] for user in users: # 导出用户的 API Keys(保留加密数据) api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all() api_keys_data = [] for key in api_keys: api_keys_data.append( { "key_hash": key.key_hash, "key_encrypted": key.key_encrypted, "name": key.name, "is_standalone": key.is_standalone, "balance_used_usd": key.balance_used_usd, "current_balance_usd": key.current_balance_usd, "allowed_providers": key.allowed_providers, "allowed_endpoints": key.allowed_endpoints, "allowed_api_formats": key.allowed_api_formats, "allowed_models": key.allowed_models, "rate_limit": key.rate_limit, "concurrent_limit": key.concurrent_limit, "force_capabilities": key.force_capabilities, "is_active": key.is_active, "auto_delete_on_expiry": key.auto_delete_on_expiry, "total_requests": key.total_requests, "total_cost_usd": key.total_cost_usd, } ) users_data.append( { "email": user.email, "username": user.username, "password_hash": user.password_hash, "role": user.role.value if user.role else "user", "allowed_providers": user.allowed_providers, "allowed_endpoints": user.allowed_endpoints, "allowed_models": user.allowed_models, "model_capability_settings": user.model_capability_settings, "quota_usd": user.quota_usd, "used_usd": user.used_usd, "total_usd": user.total_usd, "is_active": user.is_active, "api_keys": api_keys_data, } ) return { "version": "1.0", "exported_at": datetime.now(timezone.utc).isoformat(), "users": users_data, } class AdminImportUsersAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] """导入用户数据""" import uuid from datetime import datetime, timezone from src.core.enums import UserRole from src.models.database import ApiKey, User # 检查请求体大小 if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE: raise InvalidRequestException("请求体大小不能超过 10MB") db = context.db payload = context.ensure_json_body() # 验证配置版本 version = payload.get("version") if version != "1.0": raise InvalidRequestException(f"不支持的配置版本: {version}") # 获取导入选项 merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error users_data = payload.get("users", []) stats = { "users": {"created": 0, "updated": 0, "skipped": 0}, "api_keys": {"created": 0, "skipped": 0}, "errors": [], } try: for user_data in users_data: # 跳过管理员角色的导入(不区分大小写) role_str = str(user_data.get("role", "")).lower() if role_str == "admin": stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}") stats["users"]["skipped"] += 1 continue existing_user = ( db.query(User).filter(User.email == user_data["email"]).first() ) if existing_user: user_id = existing_user.id if merge_mode == "skip": stats["users"]["skipped"] += 1 elif merge_mode == "error": raise InvalidRequestException( f"用户 '{user_data['email']}' 已存在" ) elif merge_mode == "overwrite": # 更新现有用户 existing_user.username = user_data.get( "username", existing_user.username ) if user_data.get("password_hash"): existing_user.password_hash = user_data["password_hash"] if user_data.get("role"): existing_user.role = UserRole(user_data["role"]) existing_user.allowed_providers = user_data.get("allowed_providers") existing_user.allowed_endpoints = user_data.get("allowed_endpoints") existing_user.allowed_models = user_data.get("allowed_models") existing_user.model_capability_settings = user_data.get( "model_capability_settings" ) existing_user.quota_usd = user_data.get("quota_usd") existing_user.used_usd = user_data.get("used_usd", 0.0) existing_user.total_usd = user_data.get("total_usd", 0.0) existing_user.is_active = user_data.get("is_active", True) existing_user.updated_at = datetime.now(timezone.utc) stats["users"]["updated"] += 1 else: # 创建新用户 role = UserRole.USER if user_data.get("role"): role = UserRole(user_data["role"]) new_user = User( id=str(uuid.uuid4()), email=user_data["email"], username=user_data.get("username", user_data["email"].split("@")[0]), password_hash=user_data.get("password_hash", ""), role=role, allowed_providers=user_data.get("allowed_providers"), allowed_endpoints=user_data.get("allowed_endpoints"), allowed_models=user_data.get("allowed_models"), model_capability_settings=user_data.get("model_capability_settings"), quota_usd=user_data.get("quota_usd"), used_usd=user_data.get("used_usd", 0.0), total_usd=user_data.get("total_usd", 0.0), is_active=user_data.get("is_active", True), ) db.add(new_user) db.flush() user_id = new_user.id stats["users"]["created"] += 1 # 导入 API Keys for key_data in user_data.get("api_keys", []): # 检查是否已存在相同的 key_hash if key_data.get("key_hash"): existing_key = ( db.query(ApiKey) .filter(ApiKey.key_hash == key_data["key_hash"]) .first() ) if existing_key: stats["api_keys"]["skipped"] += 1 continue new_key = ApiKey( id=str(uuid.uuid4()), user_id=user_id, key_hash=key_data.get("key_hash", ""), key_encrypted=key_data.get("key_encrypted"), name=key_data.get("name"), is_standalone=key_data.get("is_standalone", False), balance_used_usd=key_data.get("balance_used_usd", 0.0), current_balance_usd=key_data.get("current_balance_usd"), allowed_providers=key_data.get("allowed_providers"), allowed_endpoints=key_data.get("allowed_endpoints"), allowed_api_formats=key_data.get("allowed_api_formats"), allowed_models=key_data.get("allowed_models"), rate_limit=key_data.get("rate_limit", 100), concurrent_limit=key_data.get("concurrent_limit", 5), force_capabilities=key_data.get("force_capabilities"), is_active=key_data.get("is_active", True), auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False), total_requests=key_data.get("total_requests", 0), total_cost_usd=key_data.get("total_cost_usd", 0.0), ) db.add(new_key) stats["api_keys"]["created"] += 1 db.commit() return { "message": "用户数据导入成功", "stats": stats, } except InvalidRequestException: db.rollback() raise except Exception as e: db.rollback() raise InvalidRequestException(f"导入失败: {str(e)}") class AdminTestSmtpAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] """测试 SMTP 连接""" from src.services.system.config import ConfigService from src.services.verification.email_sender import EmailSenderService db = context.db payload = context.ensure_json_body() or {} # 前端可传入未保存的配置,优先使用前端值,否则回退数据库 config = { "smtp_host": payload.get("smtp_host") or ConfigService.get_config(db, "smtp_host"), "smtp_port": payload.get("smtp_port") or ConfigService.get_config(db, "smtp_port", default=587), "smtp_user": payload.get("smtp_user") or ConfigService.get_config(db, "smtp_user"), "smtp_password": payload.get("smtp_password") or ConfigService.get_config(db, "smtp_password"), "smtp_use_tls": payload.get("smtp_use_tls") if payload.get("smtp_use_tls") is not None else ConfigService.get_config(db, "smtp_use_tls", default=True), "smtp_use_ssl": payload.get("smtp_use_ssl") if payload.get("smtp_use_ssl") is not None else ConfigService.get_config(db, "smtp_use_ssl", default=False), "smtp_from_email": payload.get("smtp_from_email") or ConfigService.get_config(db, "smtp_from_email"), "smtp_from_name": payload.get("smtp_from_name") or ConfigService.get_config(db, "smtp_from_name", default="Aether"), } # 验证必要配置 missing_fields = [ field for field in ["smtp_host", "smtp_user", "smtp_password", "smtp_from_email"] if not config.get(field) ] if missing_fields: return { "success": False, "message": f"SMTP 配置不完整,请检查 {', '.join(missing_fields)}" } # 测试连接 try: success, error_msg = await EmailSenderService.test_smtp_connection( db=db, override_config=config ) if success: return { "success": True, "message": "SMTP 连接测试成功" } else: return { "success": False, "message": f"SMTP 连接测试失败: {error_msg}" } except Exception as e: return { "success": False, "message": f"SMTP 连接测试失败: {str(e)}" }