"use client"

import {
    AlertTriangle,
    Bot,
    Check,
    ChevronDown,
    Server,
    Settings2,
} from "lucide-react"
import { useEffect, useMemo, useRef, useState } from "react"
import {
    ModelSelectorContent,
    ModelSelectorEmpty,
    ModelSelectorGroup,
    ModelSelectorInput,
    ModelSelectorItem,
    ModelSelectorList,
    ModelSelectorLogo,
    ModelSelectorName,
    ModelSelector as ModelSelectorRoot,
    ModelSelectorSeparator,
    ModelSelectorTrigger,
} from "@/components/ai-elements/model-selector"
import { ButtonWithTooltip } from "@/components/button-with-tooltip"
import { useDictionary } from "@/hooks/use-dictionary"
import type { FlattenedModel } from "@/lib/types/model-config"
import { cn } from "@/lib/utils"

/** Props accepted by {@link ModelSelector}. */
interface ModelSelectorProps {
    /** Every configured model; filtered by `validated` unless `showUnvalidatedModels` is set. */
    models: FlattenedModel[]
    /** `id` of the currently selected model, or `undefined` when none is selected. */
    selectedModelId: string | undefined
    /** Called with the chosen model id, or with `undefined` for the server-default entry. */
    onSelect: (modelId: string | undefined) => void
    /** Called when the "configure models" entry is chosen. */
    onConfigure: () => void
    /** Defaults to `false`. NOTE(review): presumably disables the trigger — confirm in the JSX. */
    disabled?: boolean
    /** When `true`, models whose `validated` flag is not `true` are listed as well. Defaults to `false`. */
    showUnvalidatedModels?: boolean
}

// Map our provider names to models.dev logo names
const PROVIDER_LOGO_MAP: Record<string, string> = {
    openai: "openai",
    anthropic: "anthropic",
    google: "google",
    azure: "azure",
    bedrock: "amazon-bedrock",
    openrouter: "openrouter",
    deepseek: "deepseek",
    siliconflow: "siliconflow",
    sglang: "openai", // SGLang is OpenAI-compatible, use OpenAI logo
    gateway: "vercel",
    edgeone: "tencent-cloud",
    doubao: "bytedance",
}

// Container widths (px) that drive the trigger label's visibility. The gap
// between the two values is deliberate hysteresis: the label hides at or below
// the lower bound and only comes back at or above the upper bound, so it does
// not flicker while the container is resized around a single cut-off.
const LABEL_HIDE_THRESHOLD_PX = 240
const LABEL_SHOW_THRESHOLD_PX = 260

/**
 * Groups models by their `providerLabel`, so two entries for the same provider
 * with different labels end up in separate groups.
 *
 * @param models - Models to group; input order is preserved inside each group.
 * @returns A map keyed by `providerLabel`. Each value carries the `provider`
 *   of the first model seen for that label plus all models sharing the label.
 *   Map iteration order follows first appearance of each label.
 */
function groupModelsByProvider(
    models: FlattenedModel[],
): Map<string, { provider: string; models: FlattenedModel[] }> {
    const groups = new Map<
        string,
        { provider: string; models: FlattenedModel[] }
    >()
    for (const model of models) {
        const key = model.providerLabel
        const existing = groups.get(key)
        if (existing) {
            existing.models.push(model)
        } else {
            groups.set(key, { provider: model.provider, models: [model] })
        }
    }
    return groups
}

/**
 * Model picker: lists the configured models grouped by provider, together with
 * a "server default" entry and a "configure models" entry.
 *
 * The visible label is shown or hidden depending on the measured width of the
 * wrapper's parent element (see the resize effect below).
 */
export function ModelSelector({
    models,
    selectedModelId,
    onSelect,
    onConfigure,
    disabled = false,
    showUnvalidatedModels = false,
}: ModelSelectorProps) {
    const dict = useDictionary()
    const [open, setOpen] = useState(false)

    // Filter models based on showUnvalidatedModels setting
    const displayModels = useMemo(() => {
        if (showUnvalidatedModels) {
            return models
        }
        return models.filter((m) => m.validated === true)
    }, [models, showUnvalidatedModels])

    const groupedModels = useMemo(
        () => groupModelsByProvider(displayModels),
        [displayModels],
    )

    // Find selected model for display. This searches the unfiltered list, so a
    // selected model is still resolved even when the validation filter hides it.
    const selectedModel = useMemo(
        () => models.find((m) => m.id === selectedModelId),
        [models, selectedModelId],
    )

    // Dispatches a picked list value: the two sentinel values map to
    // onConfigure() and onSelect(undefined); anything else is a model id.
    // The popover is closed in every case.
    const handleSelect = (value: string) => {
        if (value === "__configure__") {
            onConfigure()
        } else if (value === "__server_default__") {
            onSelect(undefined)
        } else {
            onSelect(value)
        }
        setOpen(false)
    }

    const tooltipContent = selectedModel
        ? `${selectedModel.modelId} ${dict.modelConfig.clickToChange}`
        : `${dict.modelConfig.usingServerDefault} ${dict.modelConfig.clickToChange}`

    // NOTE(review): element type assumed to be a <div> wrapper — confirm
    // against the element this ref is attached to.
    const wrapperRef = useRef<HTMLDivElement | null>(null)
    const [showLabel, setShowLabel] = useState(true)

    useEffect(() => {
        const el = wrapperRef.current
        if (!el) return

        // Measure the parent when there is one; fall back to the wrapper itself.
        const target = el.parentElement ?? el

        const ro = new ResizeObserver((entries) => {
            for (const entry of entries) {
                const width = entry.contentRect.width
                setShowLabel((prev) => {
                    // if currently showing and width dropped below hide threshold -> hide
                    if (prev && width <= LABEL_HIDE_THRESHOLD_PX) return false
                    // if currently hidden and width rose above show threshold -> show
                    if (!prev && width >= LABEL_SHOW_THRESHOLD_PX) return true
                    // otherwise keep previous state (hysteresis)
                    return prev
                })
            }
        })
        ro.observe(target)

        // Seed the state from a synchronous measurement taken on mount.
        const initialWidth = target.getBoundingClientRect().width
        setShowLabel(initialWidth >= LABEL_SHOW_THRESHOLD_PX)

        return () => ro.disconnect()
    }, [])

    // NOTE(review): the element tags of the JSX tree below are not visible in
    // this view of the file (only expressions, comments and one attribute
    // fragment survive). The surviving tokens are kept in their original order
    // and must be re-wrapped in their elements before this compiles.
    return (
        {/* show/hide visible label based on measured width */}
        {showLabel ? (
            {selectedModel ? selectedModel.modelId : dict.modelConfig.default}
        ) : (
            // Keep an sr-only label for screen readers when hidden
            {selectedModel ? selectedModel.modelId : dict.modelConfig.default}
        )}
        {displayModels.length === 0 && models.length > 0
            ? dict.modelConfig.noVerifiedModels
            : dict.modelConfig.noModelsFound}
        {/* Server Default Option */}
        {dict.modelConfig.serverDefault}
        {/* Configured Models by Provider */}
        {Array.from(groupedModels.entries()).map(
            ([
                providerLabel,
                { provider, models: providerModels },
            ]) => (
                {providerModels.map((model) => (
                    handleSelect(model.id) }
                    className="cursor-pointer" >
                    {model.modelId}
                    {model.validated !== true && ( )}
                ))}
            ),
        )}
        {/* Configure Option */}
        {dict.modelConfig.configureModels}
        {/* Info text */}
        {showUnvalidatedModels
            ? dict.modelConfig.allModelsShown
            : dict.modelConfig.onlyVerifiedShown}
    )
}