diff --git a/frontend/src/api/admin.ts b/frontend/src/api/admin.ts index 648e71d..b581e16 100644 --- a/frontend/src/api/admin.ts +++ b/frontend/src/api/admin.ts @@ -1,5 +1,158 @@ import apiClient from './client' +// 配置导出数据结构 +export interface ConfigExportData { + version: string + exported_at: string + global_models: GlobalModelExport[] + providers: ProviderExport[] +} + +// 用户导出数据结构 +export interface UsersExportData { + version: string + exported_at: string + users: UserExport[] +} + +export interface UserExport { + email: string + username: string + password_hash: string + role: string + allowed_providers?: string[] | null + allowed_endpoints?: string[] | null + allowed_models?: string[] | null + model_capability_settings?: any + quota_usd?: number | null + used_usd?: number + total_usd?: number + is_active: boolean + api_keys: UserApiKeyExport[] +} + +export interface UserApiKeyExport { + key_hash: string + key_encrypted?: string | null + name?: string | null + is_standalone: boolean + balance_used_usd?: number + current_balance_usd?: number | null + allowed_providers?: string[] | null + allowed_endpoints?: string[] | null + allowed_api_formats?: string[] | null + allowed_models?: string[] | null + rate_limit?: number + concurrent_limit?: number | null + force_capabilities?: any + is_active: boolean + auto_delete_on_expiry?: boolean + total_requests?: number + total_cost_usd?: number +} + +export interface GlobalModelExport { + name: string + display_name: string + default_price_per_request?: number | null + default_tiered_pricing: any + supported_capabilities?: string[] | null + config?: any + is_active: boolean +} + +export interface ProviderExport { + name: string + display_name: string + description?: string | null + website?: string | null + billing_type?: string | null + monthly_quota_usd?: number | null + quota_reset_day?: number + rpm_limit?: number | null + provider_priority?: number + is_active: boolean + rate_limit?: number | null + concurrent_limit?: number | null + config?: any + endpoints: EndpointExport[] + models: ModelExport[] +} + +export interface EndpointExport { + api_format: string + base_url: string + headers?: any + timeout?: number + max_retries?: number + max_concurrent?: number | null + rate_limit?: number | null + is_active: boolean + custom_path?: string | null + config?: any + keys: KeyExport[] +} + +export interface KeyExport { + api_key: string + name?: string | null + note?: string | null + rate_multiplier?: number + internal_priority?: number + global_priority?: number | null + max_concurrent?: number | null + rate_limit?: number | null + daily_limit?: number | null + monthly_limit?: number | null + allowed_models?: string[] | null + capabilities?: any + is_active: boolean +} + +export interface ModelExport { + global_model_name: string | null + provider_model_name: string + provider_model_aliases?: any + price_per_request?: number | null + tiered_pricing?: any + supports_vision?: boolean | null + supports_function_calling?: boolean | null + supports_streaming?: boolean | null + supports_extended_thinking?: boolean | null + supports_image_generation?: boolean | null + is_active: boolean + config?: any +} + +export interface ConfigImportRequest extends ConfigExportData { + merge_mode: 'skip' | 'overwrite' | 'error' +} + +export interface UsersImportRequest extends UsersExportData { + merge_mode: 'skip' | 'overwrite' | 'error' +} + +export interface UsersImportResponse { + message: string + stats: { + users: { created: number; updated: number; skipped: number } + api_keys: { created: number; skipped: number } + errors: string[] + } +} + +export interface ConfigImportResponse { + message: string + stats: { + global_models: { created: number; updated: number; skipped: number } + providers: { created: number; updated: number; skipped: number } + endpoints: { created: number; updated: number; skipped: number } + keys: { created: number; updated: number; skipped: number } + models: { created: number; updated: number; skipped: number } + errors: string[] + } +} + // API密钥管理相关接口定义 export interface AdminApiKey { id: string // UUID @@ -173,5 +326,35 @@ export const adminApi = { '/api/admin/system/api-formats' ) return response.data + }, + + // 导出配置 + async exportConfig(): Promise { + const response = await apiClient.get('/api/admin/system/config/export') + return response.data + }, + + // 导入配置 + async importConfig(data: ConfigImportRequest): Promise { + const response = await apiClient.post( + '/api/admin/system/config/import', + data + ) + return response.data + }, + + // 导出用户数据 + async exportUsers(): Promise { + const response = await apiClient.get('/api/admin/system/users/export') + return response.data + }, + + // 导入用户数据 + async importUsers(data: UsersImportRequest): Promise { + const response = await apiClient.post( + '/api/admin/system/users/import', + data + ) + return response.data } } diff --git a/frontend/src/features/models/components/GlobalModelFormDialog.vue b/frontend/src/features/models/components/GlobalModelFormDialog.vue index e32dc13..c56f359 100644 --- a/frontend/src/features/models/components/GlobalModelFormDialog.vue +++ b/frontend/src/features/models/components/GlobalModelFormDialog.vue @@ -415,7 +415,7 @@ const groupedModels = computed(() => { } // 转换为数组并排序 - let result = Array.from(groups.values()) + const result = Array.from(groups.values()) // 如果有搜索词,把提供商名称/ID匹配的排在前面 if (searchQuery.value) { diff --git a/frontend/src/views/admin/SystemSettings.vue b/frontend/src/views/admin/SystemSettings.vue index 8ddf4eb..e136d55 100644 --- a/frontend/src/views/admin/SystemSettings.vue +++ b/frontend/src/views/admin/SystemSettings.vue @@ -15,6 +15,94 @@
+ + +
+
+

+ 导出当前所有提供商、端点、API Key 和模型配置到 JSON 文件 +

+ +
+
+

+ 从 JSON 文件导入配置,支持跳过、覆盖或报错三种冲突处理模式 +

+
+ + +
+
+
+
+ + + +
+
+

+ 导出所有普通用户及其 API Keys 到 JSON 文件 +

+ +
+
+

+ 从 JSON 文件导入用户数据(需相同 ENCRYPTION_KEY) +

+
+ + +
+
+
+
+
+ + + + + + 导入配置 + + 选择冲突处理模式并确认导入 + + + +
+
+

+ 配置预览 +

+
    +
  • 全局模型: {{ importPreview.global_models?.length || 0 }} 个
  • +
  • 提供商: {{ importPreview.providers?.length || 0 }} 个
  • +
  • + 端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }} 个 +
  • +
  • + API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }} 个 +
  • +
+
+ +
+ + +

+ + + +

+
+ +
+

+ 注意:相同的 API Keys 会自动跳过,不会创建重复记录。 +

+
+
+ + + + + +
+
+ + + + + + 导入完成 + + +
+
+
+

+ 全局模型 +

+

+ 创建: {{ importResult.stats.global_models.created }}, + 更新: {{ importResult.stats.global_models.updated }}, + 跳过: {{ importResult.stats.global_models.skipped }} +

+
+
+

+ 提供商 +

+

+ 创建: {{ importResult.stats.providers.created }}, + 更新: {{ importResult.stats.providers.updated }}, + 跳过: {{ importResult.stats.providers.skipped }} +

+
+
+

+ 端点 +

+

+ 创建: {{ importResult.stats.endpoints.created }}, + 更新: {{ importResult.stats.endpoints.updated }}, + 跳过: {{ importResult.stats.endpoints.skipped }} +

+
+
+

+ API Keys +

+

+ 创建: {{ importResult.stats.keys.created }}, + 跳过: {{ importResult.stats.keys.skipped }} +

+
+
+

+ 模型配置 +

+

+ 创建: {{ importResult.stats.models.created }}, + 更新: {{ importResult.stats.models.updated }}, + 跳过: {{ importResult.stats.models.skipped }} +

+
+
+ +
+

+ 警告信息 +

+
    +
  • + {{ err }} +
  • +
+
+
+ + + + +
+
+ + + + + + 导入用户数据 + + 选择冲突处理模式并确认导入 + + + +
+
+

+ 数据预览 +

+
    +
  • 用户: {{ importUsersPreview.users?.length || 0 }} 个
  • +
  • + API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个 +
  • +
+
+ +
+ + +

+ + + +

+
+ +
+

+ 注意:用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作。 +

+
+
+ + + + + +
+
+ + + + + + 用户数据导入完成 + + +
+
+
+

+ 用户 +

+

+ 创建: {{ importUsersResult.stats.users.created }}, + 更新: {{ importUsersResult.stats.users.updated }}, + 跳过: {{ importUsersResult.stats.users.skipped }} +

+
+
+

+ API Keys +

+

+ 创建: {{ importUsersResult.stats.api_keys.created }}, + 跳过: {{ importUsersResult.stats.api_keys.skipped }} +

+
+
+ +
+

+ 警告信息 +

+
    +
  • + {{ err }} +
  • +
+
+
+ + + + +
+
diff --git a/src/api/admin/system.py b/src/api/admin/system.py index d005b10..c2be36a 100644 --- a/src/api/admin/system.py +++ b/src/api/admin/system.py @@ -91,6 +91,34 @@ async def get_api_formats(request: Request, db: Session = Depends(get_db)): return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) +@router.get("/config/export") +async def export_config(request: Request, db: Session = Depends(get_db)): + """导出提供商和模型配置(管理员)""" + adapter = AdminExportConfigAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +@router.post("/config/import") +async def import_config(request: Request, db: Session = Depends(get_db)): + """导入提供商和模型配置(管理员)""" + adapter = AdminImportConfigAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +@router.get("/users/export") +async def export_users(request: Request, db: Session = Depends(get_db)): + """导出用户数据(管理员)""" + adapter = AdminExportUsersAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +@router.post("/users/import") +async def import_users(request: Request, db: Session = Depends(get_db)): + """导入用户数据(管理员)""" + adapter = AdminImportUsersAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + # -------- 系统设置适配器 -------- @@ -310,3 +338,749 @@ class AdminGetApiFormatsAdapter(AdminApiAdapter): ) return {"formats": formats} + + +class AdminExportConfigAdapter(AdminApiAdapter): + async def handle(self, context): # type: ignore[override] + """导出提供商和模型配置(解密数据)""" + from datetime import datetime, timezone + + from src.core.crypto import crypto_service + from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint + + db = context.db + + # 导出 GlobalModels + global_models = db.query(GlobalModel).all() + global_models_data = [] + for gm in global_models: + global_models_data.append( + { + "name": gm.name, + "display_name": gm.display_name, + "default_price_per_request": gm.default_price_per_request, + "default_tiered_pricing": gm.default_tiered_pricing, + "supported_capabilities": gm.supported_capabilities, + "config": gm.config, + "is_active": gm.is_active, + } + ) + + # 导出 Providers 及其关联数据 + providers = db.query(Provider).all() + providers_data = [] + for provider in providers: + # 导出 Endpoints + endpoints = ( + db.query(ProviderEndpoint) + .filter(ProviderEndpoint.provider_id == provider.id) + .all() + ) + endpoints_data = [] + for ep in endpoints: + # 导出 Endpoint Keys + keys = ( + db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all() + ) + keys_data = [] + for key in keys: + # 解密 API Key + try: + decrypted_key = crypto_service.decrypt(key.api_key) + except Exception: + decrypted_key = "" + + keys_data.append( + { + "api_key": decrypted_key, + "name": key.name, + "note": key.note, + "rate_multiplier": key.rate_multiplier, + "internal_priority": key.internal_priority, + "global_priority": key.global_priority, + "max_concurrent": key.max_concurrent, + "rate_limit": key.rate_limit, + "daily_limit": key.daily_limit, + "monthly_limit": key.monthly_limit, + "allowed_models": key.allowed_models, + "capabilities": key.capabilities, + "is_active": key.is_active, + } + ) + + endpoints_data.append( + { + "api_format": ep.api_format, + "base_url": ep.base_url, + "headers": ep.headers, + "timeout": ep.timeout, + "max_retries": ep.max_retries, + "max_concurrent": ep.max_concurrent, + "rate_limit": ep.rate_limit, + "is_active": ep.is_active, + "custom_path": ep.custom_path, + "config": ep.config, + "keys": keys_data, + } + ) + + # 导出 Provider Models + models = db.query(Model).filter(Model.provider_id == provider.id).all() + models_data = [] + for model in models: + # 获取关联的 GlobalModel 名称 + global_model = ( + db.query(GlobalModel).filter(GlobalModel.id == model.global_model_id).first() + ) + models_data.append( + { + "global_model_name": global_model.name if global_model else None, + "provider_model_name": model.provider_model_name, + "provider_model_aliases": model.provider_model_aliases, + "price_per_request": model.price_per_request, + "tiered_pricing": model.tiered_pricing, + "supports_vision": model.supports_vision, + "supports_function_calling": model.supports_function_calling, + "supports_streaming": model.supports_streaming, + "supports_extended_thinking": model.supports_extended_thinking, + "supports_image_generation": model.supports_image_generation, + "is_active": model.is_active, + "config": model.config, + } + ) + + providers_data.append( + { + "name": provider.name, + "display_name": provider.display_name, + "description": provider.description, + "website": provider.website, + "billing_type": provider.billing_type.value if provider.billing_type else None, + "monthly_quota_usd": provider.monthly_quota_usd, + "quota_reset_day": provider.quota_reset_day, + "rpm_limit": provider.rpm_limit, + "provider_priority": provider.provider_priority, + "is_active": provider.is_active, + "rate_limit": provider.rate_limit, + "concurrent_limit": provider.concurrent_limit, + "config": provider.config, + "endpoints": endpoints_data, + "models": models_data, + } + ) + + return { + "version": "1.0", + "exported_at": datetime.now(timezone.utc).isoformat(), + "global_models": global_models_data, + "providers": providers_data, + } + + +MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB + + +class AdminImportConfigAdapter(AdminApiAdapter): + async def handle(self, context): # type: ignore[override] + """导入提供商和模型配置""" + import uuid + from datetime import datetime, timezone + + from src.core.crypto import crypto_service + from src.core.enums import ProviderBillingType + from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint + + # 检查请求体大小 + if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE: + raise InvalidRequestException("请求体大小不能超过 10MB") + + db = context.db + payload = context.ensure_json_body() + + # 验证配置版本 + version = payload.get("version") + if version != "1.0": + raise InvalidRequestException(f"不支持的配置版本: {version}") + + # 获取导入选项 + merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error + global_models_data = payload.get("global_models", []) + providers_data = payload.get("providers", []) + + stats = { + "global_models": {"created": 0, "updated": 0, "skipped": 0}, + "providers": {"created": 0, "updated": 0, "skipped": 0}, + "endpoints": {"created": 0, "updated": 0, "skipped": 0}, + "keys": {"created": 0, "updated": 0, "skipped": 0}, + "models": {"created": 0, "updated": 0, "skipped": 0}, + "errors": [], + } + + try: + # 导入 GlobalModels + global_model_map = {} # name -> id 映射 + for gm_data in global_models_data: + existing = ( + db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first() + ) + + if existing: + global_model_map[gm_data["name"]] = existing.id + if merge_mode == "skip": + stats["global_models"]["skipped"] += 1 + continue + elif merge_mode == "error": + raise InvalidRequestException( + f"GlobalModel '{gm_data['name']}' 已存在" + ) + elif merge_mode == "overwrite": + # 更新现有记录 + existing.display_name = gm_data.get( + "display_name", existing.display_name + ) + existing.default_price_per_request = gm_data.get( + "default_price_per_request" + ) + existing.default_tiered_pricing = gm_data.get( + "default_tiered_pricing", existing.default_tiered_pricing + ) + existing.supported_capabilities = gm_data.get( + "supported_capabilities" + ) + existing.config = gm_data.get("config") + existing.is_active = gm_data.get("is_active", True) + existing.updated_at = datetime.now(timezone.utc) + stats["global_models"]["updated"] += 1 + else: + # 创建新记录 + new_gm = GlobalModel( + id=str(uuid.uuid4()), + name=gm_data["name"], + display_name=gm_data.get("display_name", gm_data["name"]), + default_price_per_request=gm_data.get("default_price_per_request"), + default_tiered_pricing=gm_data.get( + "default_tiered_pricing", + {"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]}, + ), + supported_capabilities=gm_data.get("supported_capabilities"), + config=gm_data.get("config"), + is_active=gm_data.get("is_active", True), + ) + db.add(new_gm) + db.flush() + global_model_map[gm_data["name"]] = new_gm.id + stats["global_models"]["created"] += 1 + + # 导入 Providers + for prov_data in providers_data: + existing_provider = ( + db.query(Provider).filter(Provider.name == prov_data["name"]).first() + ) + + if existing_provider: + provider_id = existing_provider.id + if merge_mode == "skip": + stats["providers"]["skipped"] += 1 + # 仍然需要处理 endpoints 和 models(如果存在) + elif merge_mode == "error": + raise InvalidRequestException( + f"Provider '{prov_data['name']}' 已存在" + ) + elif merge_mode == "overwrite": + # 更新现有记录 + existing_provider.display_name = prov_data.get( + "display_name", existing_provider.display_name + ) + existing_provider.description = prov_data.get("description") + existing_provider.website = prov_data.get("website") + if prov_data.get("billing_type"): + existing_provider.billing_type = ProviderBillingType( + prov_data["billing_type"] + ) + existing_provider.monthly_quota_usd = prov_data.get( + "monthly_quota_usd" + ) + existing_provider.quota_reset_day = prov_data.get( + "quota_reset_day", 30 + ) + existing_provider.rpm_limit = prov_data.get("rpm_limit") + existing_provider.provider_priority = prov_data.get( + "provider_priority", 100 + ) + existing_provider.is_active = prov_data.get("is_active", True) + existing_provider.rate_limit = prov_data.get("rate_limit") + existing_provider.concurrent_limit = prov_data.get( + "concurrent_limit" + ) + existing_provider.config = prov_data.get("config") + existing_provider.updated_at = datetime.now(timezone.utc) + stats["providers"]["updated"] += 1 + else: + # 创建新 Provider + billing_type = ProviderBillingType.PAY_AS_YOU_GO + if prov_data.get("billing_type"): + billing_type = ProviderBillingType(prov_data["billing_type"]) + + new_provider = Provider( + id=str(uuid.uuid4()), + name=prov_data["name"], + display_name=prov_data.get("display_name", prov_data["name"]), + description=prov_data.get("description"), + website=prov_data.get("website"), + billing_type=billing_type, + monthly_quota_usd=prov_data.get("monthly_quota_usd"), + quota_reset_day=prov_data.get("quota_reset_day", 30), + rpm_limit=prov_data.get("rpm_limit"), + provider_priority=prov_data.get("provider_priority", 100), + is_active=prov_data.get("is_active", True), + rate_limit=prov_data.get("rate_limit"), + concurrent_limit=prov_data.get("concurrent_limit"), + config=prov_data.get("config"), + ) + db.add(new_provider) + db.flush() + provider_id = new_provider.id + stats["providers"]["created"] += 1 + + # 导入 Endpoints + for ep_data in prov_data.get("endpoints", []): + existing_ep = ( + db.query(ProviderEndpoint) + .filter( + ProviderEndpoint.provider_id == provider_id, + ProviderEndpoint.api_format == ep_data["api_format"], + ) + .first() + ) + + if existing_ep: + endpoint_id = existing_ep.id + if merge_mode == "skip": + stats["endpoints"]["skipped"] += 1 + elif merge_mode == "error": + raise InvalidRequestException( + f"Endpoint '{ep_data['api_format']}' 已存在于 Provider '{prov_data['name']}'" + ) + elif merge_mode == "overwrite": + existing_ep.base_url = ep_data.get( + "base_url", existing_ep.base_url + ) + existing_ep.headers = ep_data.get("headers") + existing_ep.timeout = ep_data.get("timeout", 300) + existing_ep.max_retries = ep_data.get("max_retries", 3) + existing_ep.max_concurrent = ep_data.get("max_concurrent") + existing_ep.rate_limit = ep_data.get("rate_limit") + existing_ep.is_active = ep_data.get("is_active", True) + existing_ep.custom_path = ep_data.get("custom_path") + existing_ep.config = ep_data.get("config") + existing_ep.updated_at = datetime.now(timezone.utc) + stats["endpoints"]["updated"] += 1 + else: + new_ep = ProviderEndpoint( + id=str(uuid.uuid4()), + provider_id=provider_id, + api_format=ep_data["api_format"], + base_url=ep_data["base_url"], + headers=ep_data.get("headers"), + timeout=ep_data.get("timeout", 300), + max_retries=ep_data.get("max_retries", 3), + max_concurrent=ep_data.get("max_concurrent"), + rate_limit=ep_data.get("rate_limit"), + is_active=ep_data.get("is_active", True), + custom_path=ep_data.get("custom_path"), + config=ep_data.get("config"), + ) + db.add(new_ep) + db.flush() + endpoint_id = new_ep.id + stats["endpoints"]["created"] += 1 + + # 导入 Keys + # 获取当前 endpoint 下所有已有的 keys,用于去重 + existing_keys = ( + db.query(ProviderAPIKey) + .filter(ProviderAPIKey.endpoint_id == endpoint_id) + .all() + ) + # 解密已有 keys 用于比对 + existing_key_values = set() + for ek in existing_keys: + try: + decrypted = crypto_service.decrypt(ek.api_key) + existing_key_values.add(decrypted) + except Exception: + pass + + for key_data in ep_data.get("keys", []): + if not key_data.get("api_key"): + stats["errors"].append( + f"跳过空 API Key (Endpoint: {ep_data['api_format']})" + ) + continue + + # 检查是否已存在相同的 Key(通过明文比对) + if key_data["api_key"] in existing_key_values: + stats["keys"]["skipped"] += 1 + continue + + encrypted_key = crypto_service.encrypt(key_data["api_key"]) + + new_key = ProviderAPIKey( + id=str(uuid.uuid4()), + endpoint_id=endpoint_id, + api_key=encrypted_key, + name=key_data.get("name"), + note=key_data.get("note"), + rate_multiplier=key_data.get("rate_multiplier", 1.0), + internal_priority=key_data.get("internal_priority", 100), + global_priority=key_data.get("global_priority"), + max_concurrent=key_data.get("max_concurrent"), + rate_limit=key_data.get("rate_limit"), + daily_limit=key_data.get("daily_limit"), + monthly_limit=key_data.get("monthly_limit"), + allowed_models=key_data.get("allowed_models"), + capabilities=key_data.get("capabilities"), + is_active=key_data.get("is_active", True), + ) + db.add(new_key) + # 添加到已有集合,防止同一批导入中重复 + existing_key_values.add(key_data["api_key"]) + stats["keys"]["created"] += 1 + + # 导入 Models + for model_data in prov_data.get("models", []): + global_model_name = model_data.get("global_model_name") + if not global_model_name: + stats["errors"].append( + f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})" + ) + continue + + global_model_id = global_model_map.get(global_model_name) + if not global_model_id: + # 尝试从数据库查找 + existing_gm = ( + db.query(GlobalModel) + .filter(GlobalModel.name == global_model_name) + .first() + ) + if existing_gm: + global_model_id = existing_gm.id + else: + stats["errors"].append( + f"GlobalModel '{global_model_name}' 不存在,跳过模型" + ) + continue + + existing_model = ( + db.query(Model) + .filter( + Model.provider_id == provider_id, + Model.provider_model_name == model_data["provider_model_name"], + ) + .first() + ) + + if existing_model: + if merge_mode == "skip": + stats["models"]["skipped"] += 1 + elif merge_mode == "error": + raise InvalidRequestException( + f"Model '{model_data['provider_model_name']}' 已存在于 Provider '{prov_data['name']}'" + ) + elif merge_mode == "overwrite": + existing_model.global_model_id = global_model_id + existing_model.provider_model_aliases = model_data.get( + "provider_model_aliases" + ) + existing_model.price_per_request = model_data.get( + "price_per_request" + ) + existing_model.tiered_pricing = model_data.get( + "tiered_pricing" + ) + existing_model.supports_vision = model_data.get( + "supports_vision" + ) + existing_model.supports_function_calling = model_data.get( + "supports_function_calling" + ) + existing_model.supports_streaming = model_data.get( + "supports_streaming" + ) + existing_model.supports_extended_thinking = model_data.get( + "supports_extended_thinking" + ) + existing_model.supports_image_generation = model_data.get( + "supports_image_generation" + ) + existing_model.is_active = model_data.get("is_active", True) + existing_model.config = model_data.get("config") + existing_model.updated_at = datetime.now(timezone.utc) + stats["models"]["updated"] += 1 + else: + new_model = Model( + id=str(uuid.uuid4()), + provider_id=provider_id, + global_model_id=global_model_id, + provider_model_name=model_data["provider_model_name"], + provider_model_aliases=model_data.get( + "provider_model_aliases" + ), + price_per_request=model_data.get("price_per_request"), + tiered_pricing=model_data.get("tiered_pricing"), + supports_vision=model_data.get("supports_vision"), + supports_function_calling=model_data.get( + "supports_function_calling" + ), + supports_streaming=model_data.get("supports_streaming"), + supports_extended_thinking=model_data.get( + "supports_extended_thinking" + ), + supports_image_generation=model_data.get( + "supports_image_generation" + ), + is_active=model_data.get("is_active", True), + config=model_data.get("config"), + ) + db.add(new_model) + stats["models"]["created"] += 1 + + db.commit() + + # 失效缓存 + from src.services.cache.invalidation import get_cache_invalidation_service + + cache_service = get_cache_invalidation_service() + cache_service.invalidate_all() + + return { + "message": "配置导入成功", + "stats": stats, + } + + except InvalidRequestException: + db.rollback() + raise + except Exception as e: + db.rollback() + raise InvalidRequestException(f"导入失败: {str(e)}") + + +class AdminExportUsersAdapter(AdminApiAdapter): + async def handle(self, context): # type: ignore[override] + """导出用户数据(保留加密数据,排除管理员)""" + from datetime import datetime, timezone + + from src.core.enums import UserRole + from src.models.database import ApiKey, User + + db = context.db + + # 导出 Users(排除管理员) + users = db.query(User).filter( + User.is_deleted.is_(False), + User.role != UserRole.ADMIN + ).all() + users_data = [] + for user in users: + # 导出用户的 API Keys(保留加密数据) + api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all() + api_keys_data = [] + for key in api_keys: + api_keys_data.append( + { + "key_hash": key.key_hash, + "key_encrypted": key.key_encrypted, + "name": key.name, + "is_standalone": key.is_standalone, + "balance_used_usd": key.balance_used_usd, + "current_balance_usd": key.current_balance_usd, + "allowed_providers": key.allowed_providers, + "allowed_endpoints": key.allowed_endpoints, + "allowed_api_formats": key.allowed_api_formats, + "allowed_models": key.allowed_models, + "rate_limit": key.rate_limit, + "concurrent_limit": key.concurrent_limit, + "force_capabilities": key.force_capabilities, + "is_active": key.is_active, + "auto_delete_on_expiry": key.auto_delete_on_expiry, + "total_requests": key.total_requests, + "total_cost_usd": key.total_cost_usd, + } + ) + + users_data.append( + { + "email": user.email, + "username": user.username, + "password_hash": user.password_hash, + "role": user.role.value if user.role else "user", + "allowed_providers": user.allowed_providers, + "allowed_endpoints": user.allowed_endpoints, + "allowed_models": user.allowed_models, + "model_capability_settings": user.model_capability_settings, + "quota_usd": user.quota_usd, + "used_usd": user.used_usd, + "total_usd": user.total_usd, + "is_active": user.is_active, + "api_keys": api_keys_data, + } + ) + + return { + "version": "1.0", + "exported_at": datetime.now(timezone.utc).isoformat(), + "users": users_data, + } + + +class AdminImportUsersAdapter(AdminApiAdapter): + async def handle(self, context): # type: ignore[override] + """导入用户数据""" + import uuid + from datetime import datetime, timezone + + from src.core.enums import UserRole + from src.models.database import ApiKey, User + + # 检查请求体大小 + if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE: + raise InvalidRequestException("请求体大小不能超过 10MB") + + db = context.db + payload = context.ensure_json_body() + + # 验证配置版本 + version = payload.get("version") + if version != "1.0": + raise InvalidRequestException(f"不支持的配置版本: {version}") + + # 获取导入选项 + merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error + users_data = payload.get("users", []) + + stats = { + "users": {"created": 0, "updated": 0, "skipped": 0}, + "api_keys": {"created": 0, "skipped": 0}, + "errors": [], + } + + try: + for user_data in users_data: + # 跳过管理员角色的导入(不区分大小写) + role_str = str(user_data.get("role", "")).lower() + if role_str == "admin": + stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}") + stats["users"]["skipped"] += 1 + continue + + existing_user = ( + db.query(User).filter(User.email == user_data["email"]).first() + ) + + if existing_user: + user_id = existing_user.id + if merge_mode == "skip": + stats["users"]["skipped"] += 1 + elif merge_mode == "error": + raise InvalidRequestException( + f"用户 '{user_data['email']}' 已存在" + ) + elif merge_mode == "overwrite": + # 更新现有用户 + existing_user.username = user_data.get( + "username", existing_user.username + ) + if user_data.get("password_hash"): + existing_user.password_hash = user_data["password_hash"] + if user_data.get("role"): + existing_user.role = UserRole(user_data["role"]) + existing_user.allowed_providers = user_data.get("allowed_providers") + existing_user.allowed_endpoints = user_data.get("allowed_endpoints") + existing_user.allowed_models = user_data.get("allowed_models") + existing_user.model_capability_settings = user_data.get( + "model_capability_settings" + ) + existing_user.quota_usd = user_data.get("quota_usd") + existing_user.used_usd = user_data.get("used_usd", 0.0) + existing_user.total_usd = user_data.get("total_usd", 0.0) + existing_user.is_active = user_data.get("is_active", True) + existing_user.updated_at = datetime.now(timezone.utc) + stats["users"]["updated"] += 1 + else: + # 创建新用户 + role = UserRole.USER + if user_data.get("role"): + role = UserRole(user_data["role"]) + + new_user = User( + id=str(uuid.uuid4()), + email=user_data["email"], + username=user_data.get("username", user_data["email"].split("@")[0]), + password_hash=user_data.get("password_hash", ""), + role=role, + allowed_providers=user_data.get("allowed_providers"), + allowed_endpoints=user_data.get("allowed_endpoints"), + allowed_models=user_data.get("allowed_models"), + model_capability_settings=user_data.get("model_capability_settings"), + quota_usd=user_data.get("quota_usd"), + used_usd=user_data.get("used_usd", 0.0), + total_usd=user_data.get("total_usd", 0.0), + is_active=user_data.get("is_active", True), + ) + db.add(new_user) + db.flush() + user_id = new_user.id + stats["users"]["created"] += 1 + + # 导入 API Keys + for key_data in user_data.get("api_keys", []): + # 检查是否已存在相同的 key_hash + if key_data.get("key_hash"): + existing_key = ( + db.query(ApiKey) + .filter(ApiKey.key_hash == key_data["key_hash"]) + .first() + ) + if existing_key: + stats["api_keys"]["skipped"] += 1 + continue + + new_key = ApiKey( + id=str(uuid.uuid4()), + user_id=user_id, + key_hash=key_data.get("key_hash", ""), + key_encrypted=key_data.get("key_encrypted"), + name=key_data.get("name"), + is_standalone=key_data.get("is_standalone", False), + balance_used_usd=key_data.get("balance_used_usd", 0.0), + current_balance_usd=key_data.get("current_balance_usd"), + allowed_providers=key_data.get("allowed_providers"), + allowed_endpoints=key_data.get("allowed_endpoints"), + allowed_api_formats=key_data.get("allowed_api_formats"), + allowed_models=key_data.get("allowed_models"), + rate_limit=key_data.get("rate_limit", 100), + concurrent_limit=key_data.get("concurrent_limit", 5), + force_capabilities=key_data.get("force_capabilities"), + is_active=key_data.get("is_active", True), + auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False), + total_requests=key_data.get("total_requests", 0), + total_cost_usd=key_data.get("total_cost_usd", 0.0), + ) + db.add(new_key) + stats["api_keys"]["created"] += 1 + + db.commit() + + return { + "message": "用户数据导入成功", + "stats": stats, + } + + except InvalidRequestException: + db.rollback() + raise + except Exception as e: + db.rollback() + raise InvalidRequestException(f"导入失败: {str(e)}") diff --git a/src/models/pydantic_models.py b/src/models/pydantic_models.py index e5e73ba..c632715 100644 --- a/src/models/pydantic_models.py +++ b/src/models/pydantic_models.py @@ -238,8 +238,8 @@ class GlobalModelResponse(BaseModel): # 按次计费配置 default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用") # 阶梯计费配置 - default_tiered_pricing: TieredPricingConfig = Field( - ..., description="阶梯计费配置" + default_tiered_pricing: Optional[TieredPricingConfig] = Field( + default=None, description="阶梯计费配置" ) # Key 能力配置 - 模型支持的能力列表 supported_capabilities: Optional[List[str]] = Field(