diff --git a/README.md b/README.md
index 7d675fa..d222199 100644
--- a/README.md
+++ b/README.md
@@ -60,11 +60,11 @@
 python generate_keys.py  # Generate keys and fill them into .env
 
 # 3. Deploy
 docker-compose up -d
 
-# 4. Update
-docker-compose pull && docker-compose up -d
-
-# 5. Database migration - run after updating
+# 4. On first deployment, initialize the database
 ./migrate.sh
+
+# 5. Update
+docker-compose pull && docker-compose up -d && ./migrate.sh
 ```
 
 ### Docker Compose (build the image locally)
diff --git a/frontend/src/api/models-dev.ts b/frontend/src/api/models-dev.ts
index 594df17..c66377c 100644
--- a/frontend/src/api/models-dev.ts
+++ b/frontend/src/api/models-dev.ts
@@ -50,6 +50,7 @@ export interface ModelsDevProvider {
   name: string
   doc?: string
   models: Record
+  official?: boolean // whether this is an official (first-party) provider
 }
 
 export type ModelsDevData = Record<string, ModelsDevProvider>
@@ -68,7 +69,12 @@ export interface ModelsDevModelItem {
   supportsVision?: boolean
   supportsToolCall?: boolean
   supportsReasoning?: boolean
+  supportsStructuredOutput?: boolean
+  supportsTemperature?: boolean
+  supportsAttachment?: boolean
+  openWeights?: boolean
   deprecated?: boolean
+  official?: boolean // whether the model comes from an official provider
   // extra fields used for display_metadata
   knowledgeCutoff?: string
   releaseDate?: string
@@ -84,13 +90,21 @@ interface CacheData {
 // in-memory cache
 let memoryCache: CacheData | null = null
 
+function hasOfficialFlag(data: ModelsDevData): boolean {
+  return Object.values(data).some(provider => typeof provider?.official === 'boolean')
+}
+
 /**
  * Fetch models.dev data (with caching)
  */
 export async function getModelsDevData(): Promise<ModelsDevData> {
   // 1. check the in-memory cache
   if (memoryCache && Date.now() - memoryCache.timestamp < CACHE_DURATION) {
-    return memoryCache.data
+    // backwards compatibility: drop old caches without the official field and force one refresh
+    if (hasOfficialFlag(memoryCache.data)) {
+      return memoryCache.data
+    }
+    memoryCache = null
   }
 
   // 2. check the localStorage cache
@@ -99,8 +113,12 @@ export async function getModelsDevData(): Promise<ModelsDevData> {
     if (cached) {
       const cacheData: CacheData = JSON.parse(cached)
       if (Date.now() - cacheData.timestamp < CACHE_DURATION) {
-        memoryCache = cacheData
-        return cacheData.data
+        // backwards compatibility: drop old caches without the official field and force one refresh
+        if (hasOfficialFlag(cacheData.data)) {
+          memoryCache = cacheData
+          return cacheData.data
+        }
+        localStorage.removeItem(CACHE_KEY)
       }
     }
   } catch {
@@ -126,52 +144,75 @@ export async function getModelsDevData(): Promise<ModelsDevData> {
   return data
 }
 
+// cache of the flattened model list (avoids repeated conversion)
+let modelsListCache: ModelsDevModelItem[] | null = null
+let modelsListCacheTimestamp: number | null = null
+
 /**
  * Get the flattened model list
+ * The data is loaded only once; official/all filtering is controlled by the parameter
  */
-export async function getModelsDevList(): Promise<ModelsDevModelItem[]> {
+export async function getModelsDevList(officialOnly: boolean = true): Promise<ModelsDevModelItem[]> {
   const data = await getModelsDevData()
-  const items: ModelsDevModelItem[] = []
+  const currentTimestamp = memoryCache?.timestamp ?? 0
 
-  for (const [providerId, provider] of Object.entries(data)) {
-    if (!provider.models) continue
+  // rebuild once when the cache is empty or the data has been refreshed
+  if (!modelsListCache || modelsListCacheTimestamp !== currentTimestamp) {
+    const items: ModelsDevModelItem[] = []
 
-    for (const [modelId, model] of Object.entries(provider.models)) {
-      items.push({
-        providerId,
-        providerName: provider.name,
-        modelId,
-        modelName: model.name || modelId,
-        family: model.family,
-        inputPrice: model.cost?.input,
-        outputPrice: model.cost?.output,
-        contextLimit: model.limit?.context,
-        outputLimit: model.limit?.output,
-        supportsVision: model.input?.includes('image'),
-        supportsToolCall: model.tool_call,
-        supportsReasoning: model.reasoning,
-        deprecated: model.deprecated,
-        // fields used for display_metadata
-        knowledgeCutoff: model.knowledge,
-        releaseDate: model.release_date,
-        inputModalities: model.input,
-        outputModalities: model.output,
-      })
+    for (const [providerId, provider] of Object.entries(data)) {
+      if (!provider.models) continue
+
+      for (const [modelId, model] of Object.entries(provider.models)) {
+        items.push({
+          providerId,
+          providerName: provider.name,
+          modelId,
+          modelName: model.name || modelId,
+          family: model.family,
+          inputPrice: model.cost?.input,
+          outputPrice: model.cost?.output,
+          contextLimit: model.limit?.context,
+          outputLimit: model.limit?.output,
+          supportsVision: model.input?.includes('image'),
+          supportsToolCall: model.tool_call,
+          supportsReasoning: model.reasoning,
+          supportsStructuredOutput: model.structured_output,
+          supportsTemperature: model.temperature,
+          supportsAttachment: model.attachment,
+          openWeights: model.open_weights,
+          deprecated: model.deprecated,
+          official: provider.official,
+          // fields used for display_metadata
+          knowledgeCutoff: model.knowledge,
+          releaseDate: model.release_date,
+          inputModalities: model.input,
+          outputModalities: model.output,
+        })
+      }
     }
+
+    // sort by provider name, then by model name
+    items.sort((a, b) => {
+      const providerCompare = a.providerName.localeCompare(b.providerName)
+      if (providerCompare !== 0) return providerCompare
+      return a.modelName.localeCompare(b.modelName)
+    })
+
+    modelsListCache = items
+    modelsListCacheTimestamp = currentTimestamp
   }
 
-  // sort by provider name, then by model name
-  items.sort((a, b) => {
-    const providerCompare = a.providerName.localeCompare(b.providerName)
-    if (providerCompare !== 0) return providerCompare
-    return a.modelName.localeCompare(b.modelName)
-  })
-
-  return items
+  // filter according to the parameter
+  if (officialOnly) {
+    return modelsListCache.filter(m => m.official)
+  }
+  return modelsListCache
 }
 
 /**
  * Search models
+ * Search includes all providers (including third-party ones)
  */
 export async function searchModelsDevModels(
   query: string,
@@ -180,7 +221,8 @@
     excludeDeprecated?: boolean
   }
 ): Promise<ModelsDevModelItem[]> {
-  const allModels = await getModelsDevList()
+  // include all providers when searching
+  const allModels = await getModelsDevList(false)
   const { limit = 50, excludeDeprecated = true } = options || {}
 
   const queryLower = query.toLowerCase()
@@ -236,6 +278,8 @@ export function getProviderLogoUrl(providerId: string): string {
  */
 export function clearModelsDevCache(): void {
   memoryCache = null
+  modelsListCache = null
+  modelsListCacheTimestamp = null
   try {
     localStorage.removeItem(CACHE_KEY)
   } catch {
diff --git a/frontend/src/features/models/components/GlobalModelFormDialog.vue b/frontend/src/features/models/components/GlobalModelFormDialog.vue
index 55c19df..e32dc13 100644
--- a/frontend/src/features/models/components/GlobalModelFormDialog.vue
+++ b/frontend/src/features/models/components/GlobalModelFormDialog.vue
@@ -14,7 +14,7 @@
[template markup for this hunk was lost in extraction]
@@ -22,13 +22,13 @@
[template markup lost in extraction; surviving fragment: {{ group.providerName }}]
@@ -90,7 +90,7 @@
[template markup for the remaining template hunks was lost in extraction]
@@ -329,10 +369,18 @@ const tieredPricingEditorRef = ref
 // model list state
 const loading = ref(false)
 const searchQuery = ref('')
-const allModels = ref<ModelsDevModelItem[]>([])
+const allModelsCache = ref<ModelsDevModelItem[]>([]) // all models (cached)
 const selectedModel = ref(null)
 const expandedProvider = ref(null)
 
+// currently displayed models: use the full list while searching, otherwise only official providers
+const allModels = computed(() => {
+  if (searchQuery.value) {
+    return allModelsCache.value
+  }
+  return allModelsCache.value.filter(m => m.official)
+})
+
 // models grouped by provider
 interface ProviderGroup {
   providerId: string
@@ -440,10 +488,11 @@ const availableCapabilities = ref([])
 
 // load the model list
 async function loadModels() {
-  if (allModels.value.length > 0) return
+  if (allModelsCache.value.length > 0) return
   loading.value = true
   try {
-    allModels.value = await getModelsDevList()
+    // load the full model list only once; filtering happens in the computed property
+    allModelsCache.value = await getModelsDevList(false)
   } catch (err) {
     log.error('Failed to load models:', err)
   } finally {
@@ -498,6 +547,10 @@ function selectModel(model: ModelsDevModelItem) {
   if (model.supportsVision) config.vision = true
   if (model.supportsToolCall) config.function_calling = true
   if (model.supportsReasoning) config.extended_thinking = true
+  if (model.supportsStructuredOutput) config.structured_output = true
+  if (model.supportsTemperature !== false) config.temperature = model.supportsTemperature
+  if (model.supportsAttachment) config.attachment = true
+  if (model.openWeights) config.open_weights = true
   if (model.contextLimit) config.context_limit = model.contextLimit
   if (model.outputLimit) config.output_limit = model.outputLimit
   if (model.knowledgeCutoff) config.knowledge_cutoff = model.knowledgeCutoff
diff --git a/frontend/src/style.css b/frontend/src/style.css
index 13b5256..5ac9c7e 100644
--- a/frontend/src/style.css
+++ b/frontend/src/style.css
@@ -1169,4 +1169,26 @@ body[theme-mode='dark'] .literary-annotation {
   .scrollbar-hide::-webkit-scrollbar {
     display: none;
   }
-}
\ No newline at end of file
+
+  .scrollbar-thin {
+    scrollbar-width: thin;
+    scrollbar-color: hsl(var(--border)) transparent;
+  }
+
+  .scrollbar-thin::-webkit-scrollbar {
+    width: 6px;
+  }
+
+  .scrollbar-thin::-webkit-scrollbar-track {
+    background: transparent;
+  }
+
+  .scrollbar-thin::-webkit-scrollbar-thumb {
+    background-color: hsl(var(--border));
+    border-radius: 3px;
+  }
+
+  .scrollbar-thin::-webkit-scrollbar-thumb:hover {
+    background-color: hsl(var(--muted-foreground) / 0.5);
+  }
+}
diff --git a/src/api/admin/models/external.py b/src/api/admin/models/external.py
index 1e24093..ecdffdd 100644
--- a/src/api/admin/models/external.py
+++ b/src/api/admin/models/external.py
@@ -20,6 +20,27 @@ router = APIRouter()
 CACHE_KEY = "aether:external:models_dev"
 CACHE_TTL = 15 * 60  # 15 minutes
 
+# Official / first-party providers; the frontend can use this to filter out third-party resellers
+OFFICIAL_PROVIDERS = {
+    "anthropic",  # Claude (official)
+    "openai",  # OpenAI (official)
+    "google",  # Gemini (official)
+    "google-vertex",  # Google Vertex AI
+    "azure",  # Azure OpenAI
+    "amazon-bedrock",  # AWS Bedrock
+    "xai",  # Grok (official)
+    "meta",  # Llama (official)
+    "deepseek",  # DeepSeek (official)
+    "mistral",  # Mistral (official)
+    "cohere",  # Cohere (official)
+    "zhipuai",  # Zhipu AI (official)
+    "alibaba",  # Alibaba Cloud (Tongyi Qianwen)
+    "minimax",  # MiniMax (official)
+    "moonshot",  # Moonshot AI (Kimi)
+    "baichuan",  # Baichuan Intelligence
+    "ai21",  # AI21 Labs
+}
+
 
 async def _get_cached_data() -> Optional[dict[str, Any]]:
     """Fetch cached data from Redis"""
@@ -47,15 +68,40 @@
         logger.warning(f"Failed to write models.dev cache: {e}")
 
 
+def _mark_official_providers(data: dict[str, Any]) -> dict[str, Any]:
+    """Mark each provider as official or not"""
+    result = {}
+    for provider_id, provider_data in data.items():
+        result[provider_id] = {
+            **provider_data,
+            "official": provider_id in OFFICIAL_PROVIDERS,
+        }
+    return result
+
+
 @router.get("/external")
 async def get_external_models(_: User = Depends(require_admin)) -> JSONResponse:
     """
     Fetch model data from models.dev (proxied request, avoids CORS issues)
     Data is cached for 15 minutes (in Redis, shared across workers)
+    Each provider is marked with an official field that the frontend can filter on
     """
     # Check the cache
     cached = await _get_cached_data()
     if cached is not None:
+        # Backwards compatibility: if the official field is missing, add it and write the cache back
+        try:
+            needs_mark = False
+            for provider_data in cached.values():
+                if not isinstance(provider_data, dict) or "official" not in provider_data:
+                    needs_mark = True
+                    break
+            if needs_mark:
+                marked_cached = _mark_official_providers(cached)
+                await _set_cached_data(marked_cached)
+                return JSONResponse(content=marked_cached)
+        except Exception as e:
+            logger.warning(f"Failed to process cached models.dev data, returning the original cache: {e}")
         return JSONResponse(content=cached)
 
     # Fetch data from models.dev
@@ -65,10 +111,13 @@ async def get_external_models(_: User = Depends(require_admin)) -> JSONResponse:
             response.raise_for_status()
             data = response.json()
 
-            # Write to the cache
-            await _set_cached_data(data)
+            # Mark official providers
+            marked_data = _mark_official_providers(data)
 
-            return JSONResponse(content=data)
+            # Write to the cache
+            await _set_cached_data(marked_data)
+
+            return JSONResponse(content=marked_data)
     except httpx.TimeoutException:
         raise HTTPException(status_code=504, detail="Request to models.dev timed out")
     except httpx.HTTPStatusError as e:
         raise HTTPException(
         )
     except Exception as e:
         raise HTTPException(status_code=502, detail=f"Failed to fetch external model data: {str(e)}")
+
+
+@router.delete("/external/cache")
+async def clear_external_models_cache(_: User = Depends(require_admin)) -> dict:
+    """Clear the models.dev cache"""
+    redis = await get_redis_client()
+    if redis is None:
+        return {"cleared": False, "message": "Redis is not enabled"}
+    try:
+        await redis.delete(CACHE_KEY)
+        return {"cleared": True}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
diff --git a/src/api/base/models_service.py b/src/api/base/models_service.py
index 93c00c9..d4aeb67 100644
--- a/src/api/base/models_service.py
+++ b/src/api/base/models_service.py
@@ -65,6 +65,21 @@ class ModelInfo:
     created_at: Optional[str]  # ISO format
     created_timestamp: int  # Unix timestamp
     provider_name: str
+    # capability configuration
+    streaming: bool = True
+    vision: bool = False
+    function_calling: bool = False
+    extended_thinking: bool = False
+    image_generation: bool = False
+    structured_output: bool = False
+    # spec parameters
+    context_limit: Optional[int] = None
+    output_limit: Optional[int] = None
+    # metadata
+    family: Optional[str] = None
+    knowledge_cutoff: Optional[str] = None
+    input_modalities: Optional[list[str]] = None
+    output_modalities: Optional[list[str]] = None
 
 
 def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
@@ -181,13 +196,19 @@ def _extract_model_info(model: Any) -> ModelInfo:
     global_model = model.global_model
     model_id: str = global_model.name if global_model else model.provider_model_name
     display_name: str = global_model.display_name if global_model else model.provider_model_name
-    description: Optional[str] = global_model.description if global_model else None
     created_at: Optional[str] = (
         model.created_at.strftime("%Y-%m-%dT%H:%M:%SZ") if model.created_at else None
     )
     created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
     provider_name: str = model.provider.name if model.provider else "unknown"
 
+    # Extract configuration from GlobalModel.config
+    config: dict = {}
+    description: Optional[str] = None
+    if global_model:
+        config = global_model.config or {}
+        description = config.get("description")
+
     return ModelInfo(
         id=model_id,
         display_name=display_name,
@@ -195,6 +216,21 @@ def _extract_model_info(model: Any) -> ModelInfo:
         created_at=created_at,
         created_timestamp=created_timestamp,
         provider_name=provider_name,
+        # capability configuration
+        streaming=config.get("streaming", True),
+        vision=config.get("vision", False),
+        function_calling=config.get("function_calling", False),
+        extended_thinking=config.get("extended_thinking", False),
+        image_generation=config.get("image_generation", False),
+        structured_output=config.get("structured_output", False),
+        # spec parameters
+        context_limit=config.get("context_limit"),
+        output_limit=config.get("output_limit"),
+        # metadata
+        family=config.get("family"),
+        knowledge_cutoff=config.get("knowledge_cutoff"),
+        input_modalities=config.get("input_modalities"),
+        output_modalities=config.get("output_modalities"),
     )
diff --git a/src/api/public/models.py b/src/api/public/models.py
index cfffc47..829559d 100644
--- a/src/api/public/models.py
+++ b/src/api/public/models.py
@@ -251,8 +251,8 @@ def _build_gemini_list_response(
             "version": "001",
             "displayName": m.display_name,
             "description": m.description or f"Model {m.id}",
-            "inputTokenLimit": 128000,
-            "outputTokenLimit": 8192,
+            "inputTokenLimit": m.context_limit if m.context_limit is not None else 128000,
+            "outputTokenLimit": m.output_limit if m.output_limit is not None else 8192,
             "supportedGenerationMethods": ["generateContent", "countTokens"],
             "temperature": 1.0,
             "maxTemperature": 2.0,
@@ -297,8 +297,8 @@ def _build_gemini_model_response(model_info: ModelInfo) -> dict:
         "version": "001",
         "displayName": model_info.display_name,
         "description": model_info.description or f"Model {model_info.id}",
-        "inputTokenLimit": 128000,
-        "outputTokenLimit": 8192,
+        "inputTokenLimit": model_info.context_limit if model_info.context_limit is not None else 128000,
+        "outputTokenLimit": model_info.output_limit if model_info.output_limit is not None else 8192,
         "supportedGenerationMethods": ["generateContent", "countTokens"],
         "temperature": 1.0,
         "maxTemperature": 2.0,
diff --git a/src/database/database.py b/src/database/database.py
index 7397d46..6a07e9d 100644
--- a/src/database/database.py
+++ b/src/database/database.py
@@ -273,16 +273,17 @@ def get_db_url() -> str:
 
 
 def init_db():
-    """Initialize the database"""
+    """Initialize the database
+
+    Note: the table schema is managed by Alembic; run ./migrate.sh when deploying
+    """
     logger.info("Initializing database...")
 
     # Make sure the engine has been created
-    engine = _ensure_engine()
+    _ensure_engine()
 
-    # Create all tables
-    Base.metadata.create_all(bind=engine)
-
-    # Database tables are created automatically via SQLAlchemy
+    # The table schema is managed by Alembic migrations
+    # Run ./migrate.sh on first deployment and after updates
 
     db = _SessionLocal()
     try:
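
Taken together, the models-dev.ts changes above let callers switch between the official-only view and the full list through a single parameter while sharing one converted, sorted list. A minimal consumption sketch; the `@/api/models-dev` import alias and the wrapping `demo` function are assumptions for illustration, not part of this diff:

```ts
import {
  getModelsDevList,
  searchModelsDevModels,
  clearModelsDevCache,
} from '@/api/models-dev' // assumed alias for frontend/src/api/models-dev.ts

async function demo() {
  // Default: only models whose provider the backend marked as official
  const officialModels = await getModelsDevList()

  // Pass false for the full flattened list; both calls reuse the same
  // converted-and-sorted cache, so the conversion work happens once per fetch
  const allModels = await getModelsDevList(false)

  // Search always looks at the full list, so third-party providers stay findable
  const matches = await searchModelsDevModels('gpt', { limit: 20 })

  // Clearing the cache drops the raw data and the flattened list,
  // forcing a refetch and rebuild on the next call
  clearModelsDevCache()

  return { officialModels, allModels, matches }
}
```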
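On the backend side, the new cache-clear route pairs with the proxied fetch: deleting the Redis entry forces the next request to refetch from models.dev and re-mark providers. A hedged client sketch follows; only the `/external` and `/external/cache` path segments come from the routes above, while the `adminBase` prefix and bearer-token auth are assumptions about how the admin API is exposed:

```ts
interface ExternalProvider {
  name: string
  official?: boolean
  models?: Record<string, unknown>
}

async function refreshExternalModels(
  adminBase: string, // assumed admin models router prefix (not shown in this diff)
  token: string
): Promise<Record<string, ExternalProvider>> {
  // Drop the server-side Redis cache first
  await fetch(`${adminBase}/external/cache`, {
    method: 'DELETE',
    headers: { Authorization: `Bearer ${token}` },
  })

  // Refetch: every provider in the response should now carry the official flag
  const res = await fetch(`${adminBase}/external`, {
    headers: { Authorization: `Bearer ${token}` },
  })
  if (!res.ok) throw new Error(`Failed to load external models: ${res.status}`)
  return (await res.json()) as Record<string, ExternalProvider>
}
```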