diff --git a/components/ai-elements/model-selector.tsx b/components/ai-elements/model-selector.tsx
index 9b5aff8..d5bcae2 100644
--- a/components/ai-elements/model-selector.tsx
+++ b/components/ai-elements/model-selector.tsx
@@ -1,3 +1,4 @@
+import { Cloud } from "lucide-react"
import type { ComponentProps, ReactNode } from "react"
import {
Command,
@@ -112,16 +113,23 @@ export const ModelSelectorLogo = ({
provider,
className,
...props
-}: ModelSelectorLogoProps) => (
-
-)
+}: ModelSelectorLogoProps) => {
+ // Use Lucide icon for bedrock since models.dev doesn't have a good AWS icon
+ if (provider === "amazon-bedrock") {
+ return
+ }
+
+ return (
+
+ )
+}
export type ModelSelectorLogoGroupProps = ComponentProps<"div">
diff --git a/components/chat-input.tsx b/components/chat-input.tsx
index 684d94d..fd67c84 100644
--- a/components/chat-input.tsx
+++ b/components/chat-input.tsx
@@ -14,6 +14,7 @@ import { toast } from "sonner"
import { ButtonWithTooltip } from "@/components/button-with-tooltip"
import { ErrorToast } from "@/components/error-toast"
import { HistoryDialog } from "@/components/history-dialog"
+import { ModelSelector } from "@/components/model-selector"
import { ResetWarningModal } from "@/components/reset-warning-modal"
import { SaveDialog } from "@/components/save-dialog"
import { Button } from "@/components/ui/button"
@@ -28,6 +29,7 @@ import { useDiagram } from "@/contexts/diagram-context"
import { useDictionary } from "@/hooks/use-dictionary"
import { formatMessage } from "@/lib/i18n/utils"
import { isPdfFile, isTextFile } from "@/lib/pdf-utils"
+import type { FlattenedModel } from "@/lib/types/model-config"
import { FilePreviewList } from "./file-preview-list"
const MAX_IMAGE_SIZE = 2 * 1024 * 1024 // 2MB
@@ -156,6 +158,11 @@ interface ChatInputProps {
error?: Error | null
minimalStyle?: boolean
onMinimalStyleChange?: (value: boolean) => void
+ // Model selector props
+ models?: FlattenedModel[]
+ selectedModelId?: string
+ onModelSelect?: (modelId: string | undefined) => void
+ onConfigureModels?: () => void
}
export function ChatInput({
@@ -173,6 +180,10 @@ export function ChatInput({
error = null,
minimalStyle = false,
onMinimalStyleChange = () => {},
+ models = [],
+ selectedModelId,
+ onModelSelect = () => {},
+ onConfigureModels = () => {},
}: ChatInputProps) {
const dict = useDictionary()
const {
@@ -465,6 +476,14 @@ export function ChatInput({
disabled={isDisabled}
/>
+
+
- {!isMobile && modelConfig.isLoaded && (
-
- setShowModelConfigDialog(true)
- }
- />
- )}
setShowModelConfigDialog(true)}
/>
diff --git a/components/model-config-dialog.tsx b/components/model-config-dialog.tsx
index 99aa596..b7a5ba0 100644
--- a/components/model-config-dialog.tsx
+++ b/components/model-config-dialog.tsx
@@ -1,16 +1,36 @@
"use client"
-import { Check, Eye, EyeOff, Loader2, Plus, Trash2, X } from "lucide-react"
-import { useCallback, useState } from "react"
-import { Button } from "@/components/ui/button"
import {
- Command,
- CommandEmpty,
- CommandGroup,
- CommandInput,
- CommandItem,
- CommandList,
-} from "@/components/ui/command"
+ AlertCircle,
+ Check,
+ ChevronRight,
+ Clock,
+ Cloud,
+ Eye,
+ EyeOff,
+ Key,
+ Link2,
+ Loader2,
+ Plus,
+ Server,
+ Sparkles,
+ Tag,
+ Trash2,
+ X,
+ Zap,
+} from "lucide-react"
+import { useCallback, useEffect, useRef, useState } from "react"
+import {
+ AlertDialog,
+ AlertDialogAction,
+ AlertDialogCancel,
+ AlertDialogContent,
+ AlertDialogDescription,
+ AlertDialogFooter,
+ AlertDialogHeader,
+ AlertDialogTitle,
+} from "@/components/ui/alert-dialog"
+import { Button } from "@/components/ui/button"
import {
Dialog,
DialogContent,
@@ -20,11 +40,6 @@ import {
} from "@/components/ui/dialog"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
-import {
- Popover,
- PopoverContent,
- PopoverTrigger,
-} from "@/components/ui/popover"
import { ScrollArea } from "@/components/ui/scroll-area"
import {
Select,
@@ -33,15 +48,11 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select"
-import { Switch } from "@/components/ui/switch"
import { useDictionary } from "@/hooks/use-dictionary"
import type { UseModelConfigReturn } from "@/hooks/use-model-config"
-import type {
- ModelConfig,
- ProviderConfig,
- ProviderName,
-} from "@/lib/types/model-config"
+import type { ProviderConfig, ProviderName } from "@/lib/types/model-config"
import { PROVIDER_INFO, SUGGESTED_MODELS } from "@/lib/types/model-config"
+import { cn } from "@/lib/utils"
interface ModelConfigDialogProps {
open: boolean
@@ -51,6 +62,44 @@ interface ModelConfigDialogProps {
type ValidationStatus = "idle" | "validating" | "success" | "error"
+// Map provider names to models.dev logo names
+const PROVIDER_LOGO_MAP: Record = {
+ openai: "openai",
+ anthropic: "anthropic",
+ google: "google",
+ azure: "azure",
+ bedrock: "amazon-bedrock",
+ openrouter: "openrouter",
+ deepseek: "deepseek",
+ siliconflow: "siliconflow",
+ gateway: "vercel",
+}
+
+// Provider logo component
+function ProviderLogo({
+ provider,
+ className,
+}: {
+ provider: ProviderName
+ className?: string
+}) {
+ // Use Lucide icon for bedrock since models.dev doesn't have a good AWS icon
+ if (provider === "bedrock") {
+ return
+ }
+
+ const logoName = PROVIDER_LOGO_MAP[provider] || provider
+ return (
+
+ )
+}
+
export function ModelConfigDialog({
open,
onOpenChange,
@@ -64,8 +113,13 @@ export function ModelConfigDialog({
const [validationStatus, setValidationStatus] =
useState("idle")
const [validationError, setValidationError] = useState("")
- const [modelPopoverOpen, setModelPopoverOpen] = useState(false)
- const [modelSearchValue, setModelSearchValue] = useState("")
+ const [scrollState, setScrollState] = useState({ top: false, bottom: true })
+ const [customModelInput, setCustomModelInput] = useState("")
+ const scrollRef = useRef(null)
+ const [deleteConfirmOpen, setDeleteConfirmOpen] = useState(false)
+ const [validatingModelIndex, setValidatingModelIndex] = useState<
+ number | null
+ >(null)
const {
config,
@@ -82,6 +136,26 @@ export function ModelConfigDialog({
(p) => p.id === selectedProviderId,
)
+ // Track scroll position for gradient shadows
+ useEffect(() => {
+ const scrollEl = scrollRef.current?.querySelector(
+ "[data-radix-scroll-area-viewport]",
+ ) as HTMLElement | null
+ if (!scrollEl) return
+
+ const handleScroll = () => {
+ const { scrollTop, scrollHeight, clientHeight } = scrollEl
+ setScrollState({
+ top: scrollTop > 10,
+ bottom: scrollTop < scrollHeight - clientHeight - 10,
+ })
+ }
+
+ handleScroll() // Initial check
+ scrollEl.addEventListener("scroll", handleScroll)
+ return () => scrollEl.removeEventListener("scroll", handleScroll)
+ }, [selectedProvider])
+
// Get suggested models for current provider
const suggestedModels = selectedProvider
? SUGGESTED_MODELS[selectedProvider.provider] || []
@@ -114,16 +188,6 @@ export function ModelConfigDialog({
addModel(selectedProviderId, modelId)
}
- // Handle model field updates
- const handleModelUpdate = (
- modelConfigId: string,
- field: keyof ModelConfig,
- value: string | boolean,
- ) => {
- if (!selectedProviderId) return
- updateModel(selectedProviderId, modelConfigId, { [field]: value })
- }
-
// Handle deleting a model
const handleDeleteModel = (modelConfigId: string) => {
if (!selectedProviderId) return
@@ -136,9 +200,10 @@ export function ModelConfigDialog({
deleteProvider(selectedProviderId)
setSelectedProviderId(null)
setValidationStatus("idle")
+ setDeleteConfirmOpen(false)
}
- // Validate API key
+ // Validate all models
const handleValidate = useCallback(async () => {
if (!selectedProvider || !selectedProvider.apiKey) return
@@ -152,456 +217,711 @@ export function ModelConfigDialog({
setValidationStatus("validating")
setValidationError("")
- try {
- const response = await fetch("/api/validate-model", {
- method: "POST",
- headers: { "Content-Type": "application/json" },
- body: JSON.stringify({
- provider: selectedProvider.provider,
- apiKey: selectedProvider.apiKey,
- baseUrl: selectedProvider.baseUrl,
- modelId: selectedProvider.models[0].modelId,
- }),
- })
+ let allValid = true
+ let errorCount = 0
- const data = await response.json()
+ // Validate each model
+ for (let i = 0; i < selectedProvider.models.length; i++) {
+ const model = selectedProvider.models[i]
+ setValidatingModelIndex(i)
- if (data.valid) {
- setValidationStatus("success")
- updateProvider(selectedProviderId!, { validated: true })
- } else {
- setValidationStatus("error")
- setValidationError(data.error || "Validation failed")
+ try {
+ const response = await fetch("/api/validate-model", {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify({
+ provider: selectedProvider.provider,
+ apiKey: selectedProvider.apiKey,
+ baseUrl: selectedProvider.baseUrl,
+ modelId: model.modelId,
+ }),
+ })
+
+ const data = await response.json()
+
+ if (data.valid) {
+ updateModel(selectedProviderId!, model.id, {
+ validated: true,
+ validationError: undefined,
+ })
+ } else {
+ allValid = false
+ errorCount++
+ updateModel(selectedProviderId!, model.id, {
+ validated: false,
+ validationError: data.error || "Validation failed",
+ })
+ }
+ } catch {
+ allValid = false
+ errorCount++
+ updateModel(selectedProviderId!, model.id, {
+ validated: false,
+ validationError: "Network error",
+ })
}
- } catch {
- setValidationStatus("error")
- setValidationError("Network error")
}
- }, [selectedProvider, selectedProviderId, updateProvider])
- // Get all available provider types (allow duplicates for different base URLs)
+ setValidatingModelIndex(null)
+
+ if (allValid) {
+ setValidationStatus("success")
+ updateProvider(selectedProviderId!, { validated: true })
+ } else {
+ setValidationStatus("error")
+ setValidationError(`${errorCount} model(s) failed validation`)
+ }
+ }, [selectedProvider, selectedProviderId, updateProvider, updateModel])
+
+ // Get all available provider types
const availableProviders = Object.keys(PROVIDER_INFO) as ProviderName[]
- // Get display name for provider (use custom name if set)
+ // Get display name for provider
const getProviderDisplayName = (provider: ProviderConfig) => {
return provider.name || PROVIDER_INFO[provider.provider].label
}
return (