diff --git a/app/api/chat/route.ts b/app/api/chat/route.ts index 5c44683..b9024e4 100644 --- a/app/api/chat/route.ts +++ b/app/api/chat/route.ts @@ -214,6 +214,11 @@ async function handleChatRequest(req: Request): Promise { baseUrl: req.headers.get("x-ai-base-url"), apiKey: req.headers.get("x-ai-api-key"), modelId: req.headers.get("x-ai-model"), + // AWS Bedrock credentials + awsAccessKeyId: req.headers.get("x-aws-access-key-id"), + awsSecretAccessKey: req.headers.get("x-aws-secret-access-key"), + awsRegion: req.headers.get("x-aws-region"), + awsSessionToken: req.headers.get("x-aws-session-token"), } // Read minimal style preference from header diff --git a/app/api/validate-model/route.ts b/app/api/validate-model/route.ts new file mode 100644 index 0000000..7beeaa4 --- /dev/null +++ b/app/api/validate-model/route.ts @@ -0,0 +1,213 @@ +import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock" +import { createAnthropic } from "@ai-sdk/anthropic" +import { createDeepSeek, deepseek } from "@ai-sdk/deepseek" +import { createGateway } from "@ai-sdk/gateway" +import { createGoogleGenerativeAI } from "@ai-sdk/google" +import { createOpenAI } from "@ai-sdk/openai" +import { createOpenRouter } from "@openrouter/ai-sdk-provider" +import { generateText } from "ai" +import { NextResponse } from "next/server" +import { createOllama } from "ollama-ai-provider-v2" + +export const runtime = "nodejs" + +interface ValidateRequest { + provider: string + apiKey: string + baseUrl?: string + modelId: string + // AWS Bedrock specific + awsAccessKeyId?: string + awsSecretAccessKey?: string + awsRegion?: string +} + +export async function POST(req: Request) { + try { + const body: ValidateRequest = await req.json() + const { + provider, + apiKey, + baseUrl, + modelId, + awsAccessKeyId, + awsSecretAccessKey, + awsRegion, + } = body + + if (!provider || !modelId) { + return NextResponse.json( + { valid: false, error: "Provider and model ID are required" }, + { status: 400 }, + ) + } + + // Validate credentials based on provider + if (provider === "bedrock") { + if (!awsAccessKeyId || !awsSecretAccessKey || !awsRegion) { + return NextResponse.json( + { + valid: false, + error: "AWS credentials (Access Key ID, Secret Access Key, Region) are required", + }, + { status: 400 }, + ) + } + } else if (provider !== "ollama" && !apiKey) { + return NextResponse.json( + { valid: false, error: "API key is required" }, + { status: 400 }, + ) + } + + let model: any + + switch (provider) { + case "openai": { + const openai = createOpenAI({ + apiKey, + ...(baseUrl && { baseURL: baseUrl }), + }) + model = openai.chat(modelId) + break + } + + case "anthropic": { + const anthropic = createAnthropic({ + apiKey, + baseURL: baseUrl || "https://api.anthropic.com/v1", + }) + model = anthropic(modelId) + break + } + + case "google": { + const google = createGoogleGenerativeAI({ + apiKey, + ...(baseUrl && { baseURL: baseUrl }), + }) + model = google(modelId) + break + } + + case "azure": { + const azure = createOpenAI({ + apiKey, + baseURL: baseUrl, + }) + model = azure.chat(modelId) + break + } + + case "bedrock": { + const bedrock = createAmazonBedrock({ + accessKeyId: awsAccessKeyId, + secretAccessKey: awsSecretAccessKey, + region: awsRegion, + }) + model = bedrock(modelId) + break + } + + case "openrouter": { + const openrouter = createOpenRouter({ + apiKey, + ...(baseUrl && { baseURL: baseUrl }), + }) + model = openrouter(modelId) + break + } + + case "deepseek": { + if (baseUrl || apiKey) { + const ds = createDeepSeek({ + apiKey, + ...(baseUrl && { baseURL: baseUrl }), + }) + model = ds(modelId) + } else { + model = deepseek(modelId) + } + break + } + + case "siliconflow": { + const sf = createOpenAI({ + apiKey, + baseURL: baseUrl || "https://api.siliconflow.com/v1", + }) + model = sf.chat(modelId) + break + } + + case "ollama": { + const ollama = createOllama({ + baseURL: baseUrl || "http://localhost:11434", + }) + model = ollama(modelId) + break + } + + case "gateway": { + const gw = createGateway({ + apiKey, + ...(baseUrl && { baseURL: baseUrl }), + }) + model = gw(modelId) + break + } + + default: + return NextResponse.json( + { valid: false, error: `Unknown provider: ${provider}` }, + { status: 400 }, + ) + } + + // Make a minimal test request + const startTime = Date.now() + await generateText({ + model, + prompt: "Say 'OK'", + maxOutputTokens: 20, + }) + const responseTime = Date.now() - startTime + + return NextResponse.json({ + valid: true, + responseTime, + }) + } catch (error) { + console.error("[validate-model] Error:", error) + + let errorMessage = "Validation failed" + if (error instanceof Error) { + // Extract meaningful error message + if ( + error.message.includes("401") || + error.message.includes("Unauthorized") + ) { + errorMessage = "Invalid API key" + } else if ( + error.message.includes("404") || + error.message.includes("not found") + ) { + errorMessage = "Model not found" + } else if ( + error.message.includes("429") || + error.message.includes("rate limit") + ) { + errorMessage = "Rate limited - try again later" + } else if (error.message.includes("ECONNREFUSED")) { + errorMessage = "Cannot connect to server" + } else { + errorMessage = error.message.slice(0, 100) + } + } + + return NextResponse.json( + { valid: false, error: errorMessage }, + { status: 200 }, // Return 200 so client can read error message + ) + } +} diff --git a/components/ai-elements/model-selector.tsx b/components/ai-elements/model-selector.tsx new file mode 100644 index 0000000..d5bcae2 --- /dev/null +++ b/components/ai-elements/model-selector.tsx @@ -0,0 +1,156 @@ +import { Cloud } from "lucide-react" +import type { ComponentProps, ReactNode } from "react" +import { + Command, + CommandDialog, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, + CommandSeparator, + CommandShortcut, +} from "@/components/ui/command" +import { + Dialog, + DialogContent, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog" +import { cn } from "@/lib/utils" + +export type ModelSelectorProps = ComponentProps + +export const ModelSelector = (props: ModelSelectorProps) => ( + +) + +export type ModelSelectorTriggerProps = ComponentProps + +export const ModelSelectorTrigger = (props: ModelSelectorTriggerProps) => ( + +) + +export type ModelSelectorContentProps = ComponentProps & { + title?: ReactNode +} + +export const ModelSelectorContent = ({ + className, + children, + title = "Model Selector", + ...props +}: ModelSelectorContentProps) => ( + + {title} + + {children} + + +) + +export type ModelSelectorDialogProps = ComponentProps + +export const ModelSelectorDialog = (props: ModelSelectorDialogProps) => ( + +) + +export type ModelSelectorInputProps = ComponentProps + +export const ModelSelectorInput = ({ + className, + ...props +}: ModelSelectorInputProps) => ( + +) + +export type ModelSelectorListProps = ComponentProps + +export const ModelSelectorList = (props: ModelSelectorListProps) => ( + +) + +export type ModelSelectorEmptyProps = ComponentProps + +export const ModelSelectorEmpty = (props: ModelSelectorEmptyProps) => ( + +) + +export type ModelSelectorGroupProps = ComponentProps + +export const ModelSelectorGroup = (props: ModelSelectorGroupProps) => ( + +) + +export type ModelSelectorItemProps = ComponentProps + +export const ModelSelectorItem = (props: ModelSelectorItemProps) => ( + +) + +export type ModelSelectorShortcutProps = ComponentProps + +export const ModelSelectorShortcut = (props: ModelSelectorShortcutProps) => ( + +) + +export type ModelSelectorSeparatorProps = ComponentProps< + typeof CommandSeparator +> + +export const ModelSelectorSeparator = (props: ModelSelectorSeparatorProps) => ( + +) + +export type ModelSelectorLogoProps = Omit< + ComponentProps<"img">, + "src" | "alt" +> & { + provider: string +} + +export const ModelSelectorLogo = ({ + provider, + className, + ...props +}: ModelSelectorLogoProps) => { + // Use Lucide icon for bedrock since models.dev doesn't have a good AWS icon + if (provider === "amazon-bedrock") { + return + } + + return ( + {`${provider} + ) +} + +export type ModelSelectorLogoGroupProps = ComponentProps<"div"> + +export const ModelSelectorLogoGroup = ({ + className, + ...props +}: ModelSelectorLogoGroupProps) => ( +
img]:rounded-full [&>img]:bg-background [&>img]:p-px [&>img]:ring-1 dark:[&>img]:bg-foreground", + className, + )} + {...props} + /> +) + +export type ModelSelectorNameProps = ComponentProps<"span"> + +export const ModelSelectorName = ({ + className, + ...props +}: ModelSelectorNameProps) => ( + +) diff --git a/components/chat-input.tsx b/components/chat-input.tsx index 684d94d..fd67c84 100644 --- a/components/chat-input.tsx +++ b/components/chat-input.tsx @@ -14,6 +14,7 @@ import { toast } from "sonner" import { ButtonWithTooltip } from "@/components/button-with-tooltip" import { ErrorToast } from "@/components/error-toast" import { HistoryDialog } from "@/components/history-dialog" +import { ModelSelector } from "@/components/model-selector" import { ResetWarningModal } from "@/components/reset-warning-modal" import { SaveDialog } from "@/components/save-dialog" import { Button } from "@/components/ui/button" @@ -28,6 +29,7 @@ import { useDiagram } from "@/contexts/diagram-context" import { useDictionary } from "@/hooks/use-dictionary" import { formatMessage } from "@/lib/i18n/utils" import { isPdfFile, isTextFile } from "@/lib/pdf-utils" +import type { FlattenedModel } from "@/lib/types/model-config" import { FilePreviewList } from "./file-preview-list" const MAX_IMAGE_SIZE = 2 * 1024 * 1024 // 2MB @@ -156,6 +158,11 @@ interface ChatInputProps { error?: Error | null minimalStyle?: boolean onMinimalStyleChange?: (value: boolean) => void + // Model selector props + models?: FlattenedModel[] + selectedModelId?: string + onModelSelect?: (modelId: string | undefined) => void + onConfigureModels?: () => void } export function ChatInput({ @@ -173,6 +180,10 @@ export function ChatInput({ error = null, minimalStyle = false, onMinimalStyleChange = () => {}, + models = [], + selectedModelId, + onModelSelect = () => {}, + onConfigureModels = () => {}, }: ChatInputProps) { const dict = useDictionary() const { @@ -465,6 +476,14 @@ export function ChatInput({ disabled={isDisabled} /> + +
+ ) +} + +export function ModelConfigDialog({ + open, + onOpenChange, + modelConfig, +}: ModelConfigDialogProps) { + const dict = useDictionary() + const [selectedProviderId, setSelectedProviderId] = useState( + null, + ) + const [showApiKey, setShowApiKey] = useState(false) + const [validationStatus, setValidationStatus] = + useState("idle") + const [validationError, setValidationError] = useState("") + const [scrollState, setScrollState] = useState({ top: false, bottom: true }) + const [customModelInput, setCustomModelInput] = useState("") + const scrollRef = useRef(null) + const validationResetTimeoutRef = useRef | null>(null) + const [deleteConfirmOpen, setDeleteConfirmOpen] = useState(false) + const [deleteConfirmText, setDeleteConfirmText] = useState("") + const [validatingModelIndex, setValidatingModelIndex] = useState< + number | null + >(null) + const [duplicateError, setDuplicateError] = useState("") + const [editError, setEditError] = useState<{ + modelId: string + message: string + } | null>(null) + + const { + config, + addProvider, + updateProvider, + deleteProvider, + addModel, + updateModel, + deleteModel, + } = modelConfig + + // Get selected provider + const selectedProvider = config.providers.find( + (p) => p.id === selectedProviderId, + ) + + // Track scroll position for gradient shadows + useEffect(() => { + const scrollEl = scrollRef.current?.querySelector( + "[data-radix-scroll-area-viewport]", + ) as HTMLElement | null + if (!scrollEl) return + + const handleScroll = () => { + const { scrollTop, scrollHeight, clientHeight } = scrollEl + setScrollState({ + top: scrollTop > 10, + bottom: scrollTop < scrollHeight - clientHeight - 10, + }) + } + + handleScroll() // Initial check + scrollEl.addEventListener("scroll", handleScroll) + return () => scrollEl.removeEventListener("scroll", handleScroll) + }, [selectedProvider]) + + // Cleanup validation reset timeout on unmount + useEffect(() => { + return () => { + if (validationResetTimeoutRef.current) { + clearTimeout(validationResetTimeoutRef.current) + } + } + }, []) + + // Get suggested models for current provider + const suggestedModels = selectedProvider + ? SUGGESTED_MODELS[selectedProvider.provider] || [] + : [] + + // Filter out already-added models from suggestions + const existingModelIds = + selectedProvider?.models.map((m) => m.modelId) || [] + const availableSuggestions = suggestedModels.filter( + (modelId) => !existingModelIds.includes(modelId), + ) + + // Handle adding a new provider + const handleAddProvider = (providerType: ProviderName) => { + const newProvider = addProvider(providerType) + setSelectedProviderId(newProvider.id) + setValidationStatus("idle") + } + + // Handle provider field updates + const handleProviderUpdate = ( + field: keyof ProviderConfig, + value: string | boolean, + ) => { + if (!selectedProviderId) return + updateProvider(selectedProviderId, { [field]: value }) + // Reset validation when credentials change + const credentialFields = [ + "apiKey", + "baseUrl", + "awsAccessKeyId", + "awsSecretAccessKey", + "awsRegion", + ] + if (credentialFields.includes(field)) { + setValidationStatus("idle") + updateProvider(selectedProviderId, { validated: false }) + } + } + + // Handle adding a model to current provider + // Returns true if model was added successfully, false otherwise + const handleAddModel = (modelId: string): boolean => { + if (!selectedProviderId || !selectedProvider) return false + // Prevent duplicate model IDs + if (existingModelIds.includes(modelId)) { + setDuplicateError(`Model "${modelId}" already exists`) + return false + } + setDuplicateError("") + addModel(selectedProviderId, modelId) + return true + } + + // Handle deleting a model + const handleDeleteModel = (modelConfigId: string) => { + if (!selectedProviderId) return + deleteModel(selectedProviderId, modelConfigId) + } + + // Handle deleting the provider + const handleDeleteProvider = () => { + if (!selectedProviderId) return + deleteProvider(selectedProviderId) + setSelectedProviderId(null) + setValidationStatus("idle") + setDeleteConfirmOpen(false) + } + + // Validate all models + const handleValidate = useCallback(async () => { + if (!selectedProvider) return + + // Check credentials based on provider type + const isBedrock = selectedProvider.provider === "bedrock" + if (isBedrock) { + if ( + !selectedProvider.awsAccessKeyId || + !selectedProvider.awsSecretAccessKey || + !selectedProvider.awsRegion + ) { + return + } + } else if (!selectedProvider.apiKey) { + return + } + + // Need at least one model to validate + if (selectedProvider.models.length === 0) { + setValidationError("Add at least one model to validate") + setValidationStatus("error") + return + } + + setValidationStatus("validating") + setValidationError("") + + let allValid = true + let errorCount = 0 + + // Validate each model + for (let i = 0; i < selectedProvider.models.length; i++) { + const model = selectedProvider.models[i] + setValidatingModelIndex(i) + + try { + const response = await fetch("/api/validate-model", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + provider: selectedProvider.provider, + apiKey: selectedProvider.apiKey, + baseUrl: selectedProvider.baseUrl, + modelId: model.modelId, + // AWS Bedrock credentials + awsAccessKeyId: selectedProvider.awsAccessKeyId, + awsSecretAccessKey: selectedProvider.awsSecretAccessKey, + awsRegion: selectedProvider.awsRegion, + }), + }) + + const data = await response.json() + + if (data.valid) { + updateModel(selectedProviderId!, model.id, { + validated: true, + validationError: undefined, + }) + } else { + allValid = false + errorCount++ + updateModel(selectedProviderId!, model.id, { + validated: false, + validationError: data.error || "Validation failed", + }) + } + } catch { + allValid = false + errorCount++ + updateModel(selectedProviderId!, model.id, { + validated: false, + validationError: "Network error", + }) + } + } + + setValidatingModelIndex(null) + + if (allValid) { + setValidationStatus("success") + updateProvider(selectedProviderId!, { validated: true }) + // Reset to idle after showing success briefly (with cleanup) + if (validationResetTimeoutRef.current) { + clearTimeout(validationResetTimeoutRef.current) + } + validationResetTimeoutRef.current = setTimeout(() => { + setValidationStatus("idle") + validationResetTimeoutRef.current = null + }, 1500) + } else { + setValidationStatus("error") + setValidationError(`${errorCount} model(s) failed validation`) + } + }, [selectedProvider, selectedProviderId, updateProvider, updateModel]) + + // Get all available provider types + const availableProviders = Object.keys(PROVIDER_INFO) as ProviderName[] + + // Get display name for provider + const getProviderDisplayName = (provider: ProviderConfig) => { + return provider.name || PROVIDER_INFO[provider.provider].label + } + + return ( + + + + +
+ +
+ {dict.modelConfig?.title || "AI Model Configuration"} +
+ + {dict.modelConfig?.description || + "Configure multiple AI providers and models for your workspace"} + +
+ +
+ {/* Provider List (Left Sidebar) */} +
+
+ + Providers + +
+ + +
+ {config.providers.length === 0 ? ( +
+
+ +
+

+ Add a provider to get started +

+
+ ) : ( +
+ {config.providers.map((provider) => ( + + ))} +
+ )} +
+
+ + {/* Add Provider */} +
+ +
+
+ + {/* Provider Details (Right Panel) */} +
+ {selectedProvider ? ( + <> + {/* Top gradient shadow */} +
+ {/* Bottom gradient shadow */} +
+ +
+ {/* Provider Header */} +
+
+ +
+
+

+ { + PROVIDER_INFO[ + selectedProvider + .provider + ].label + } +

+

+ {selectedProvider.models + .length === 0 + ? "No models configured" + : `${selectedProvider.models.length} model${selectedProvider.models.length > 1 ? "s" : ""} configured`} +

+
+ {selectedProvider.validated && ( +
+ + + Verified + +
+ )} +
+ + {/* Configuration Section */} +
+
+ + Configuration +
+ +
+ {/* Display Name */} +
+ + + handleProviderUpdate( + "name", + e.target.value, + ) + } + placeholder={ + PROVIDER_INFO[ + selectedProvider + .provider + ].label + } + className="h-9" + /> +
+ + {/* Credentials - different for Bedrock vs other providers */} + {selectedProvider.provider === + "bedrock" ? ( + <> + {/* AWS Access Key ID */} +
+ + + handleProviderUpdate( + "awsAccessKeyId", + e.target + .value, + ) + } + placeholder="AKIA..." + className="h-9 font-mono text-xs" + /> +
+ + {/* AWS Secret Access Key */} +
+ +
+ + handleProviderUpdate( + "awsSecretAccessKey", + e + .target + .value, + ) + } + placeholder="Enter your secret access key" + className="h-9 pr-10 font-mono text-xs" + /> + +
+
+ + {/* AWS Region */} +
+ + +
+ + {/* Test Button for Bedrock */} +
+ + {validationStatus === + "error" && + validationError && ( +

+ + { + validationError + } +

+ )} +
+ + ) : ( + <> + {/* API Key */} +
+ +
+
+ + handleProviderUpdate( + "apiKey", + e + .target + .value, + ) + } + placeholder="Enter your API key" + className="h-9 pr-10 font-mono text-xs" + /> + +
+ +
+ {validationStatus === + "error" && + validationError && ( +

+ + { + validationError + } +

+ )} +
+ + {/* Base URL */} +
+ + + handleProviderUpdate( + "baseUrl", + e.target + .value, + ) + } + placeholder={ + PROVIDER_INFO[ + selectedProvider + .provider + ] + .defaultBaseUrl || + "Custom endpoint URL" + } + className="h-9 font-mono text-xs" + /> +
+ + )} +
+
+ + {/* Models Section */} +
+
+
+ + Models +
+
+
+ { + setCustomModelInput( + e.target + .value, + ) + // Clear duplicate error when typing + if ( + duplicateError + ) { + setDuplicateError( + "", + ) + } + }} + onKeyDown={(e) => { + if ( + e.key === + "Enter" && + customModelInput.trim() + ) { + const success = + handleAddModel( + customModelInput.trim(), + ) + if ( + success + ) { + setCustomModelInput( + "", + ) + } + } + }} + className={cn( + "h-8 w-48 font-mono text-xs", + duplicateError && + "border-destructive focus-visible:ring-destructive", + )} + /> + {/* Show duplicate error for custom model input */} + {duplicateError && ( +

+ {duplicateError} +

+ )} +
+ + +
+
+ + {/* Model List */} +
+ {selectedProvider.models + .length === 0 ? ( +
+
+ +
+

+ No models configured +

+
+ ) : ( +
+ {selectedProvider.models.map( + (model, index) => ( +
+
+ {/* Status icon */} +
+ {validatingModelIndex !== + null && + index === + validatingModelIndex ? ( + // Currently validating +
+ +
+ ) : validatingModelIndex !== + null && + index > + validatingModelIndex && + model.validated === + undefined ? ( + // Queued +
+ +
+ ) : model.validated === + true ? ( + // Valid +
+ +
+ ) : model.validated === + false ? ( + // Invalid +
+ +
+ ) : ( + // Not validated yet +
+ +
+ )} +
+ { + // Allow free typing - validation happens on blur + // Clear edit error when typing + if ( + editError?.modelId === + model.id + ) { + setEditError( + null, + ) + } + updateModel( + selectedProviderId!, + model.id, + { + modelId: + e + .target + .value, + validated: + undefined, + validationError: + undefined, + }, + ) + }} + onKeyDown={( + e, + ) => { + if ( + e.key === + "Enter" + ) { + e.currentTarget.blur() + } + }} + onBlur={( + e, + ) => { + const newModelId = + e.target.value.trim() + + // Helper to show error with shake + const showError = + ( + message: string, + ) => { + setEditError( + { + modelId: + model.id, + message, + }, + ) + e.target.animate( + [ + { + transform: + "translateX(0)", + }, + { + transform: + "translateX(-4px)", + }, + { + transform: + "translateX(4px)", + }, + { + transform: + "translateX(-4px)", + }, + { + transform: + "translateX(4px)", + }, + { + transform: + "translateX(0)", + }, + ], + { + duration: 400, + easing: "ease-in-out", + }, + ) + e.target.focus() + } + + // Check for empty model name + if ( + !newModelId + ) { + showError( + "Model ID cannot be empty", + ) + return + } + + // Check for duplicate + const otherModelIds = + selectedProvider?.models + .filter( + ( + m, + ) => + m.id !== + model.id, + ) + .map( + ( + m, + ) => + m.modelId, + ) || + [] + if ( + otherModelIds.includes( + newModelId, + ) + ) { + showError( + "This model ID already exists", + ) + return + } + + // Clear error on valid blur + setEditError( + null, + ) + }} + className="flex-1 min-w-0 font-mono text-sm h-8 border-0 bg-transparent focus-visible:bg-background focus-visible:ring-1" + /> + +
+ {/* Show validation error inline */} + {model.validated === + false && + model.validationError && ( +

+ { + model.validationError + } +

+ )} + {/* Show edit error inline */} + {editError?.modelId === + model.id && ( +

+ { + editError.message + } +

+ )} +
+ ), + )} +
+ )} +
+
+ + {/* Danger Zone */} +
+ +
+
+
+ + ) : ( +
+
+ +
+

+ Configure AI Providers +

+

+ Select a provider from the list or add a new + one to configure API keys and models +

+
+ )} +
+
+ + {/* Footer */} +
+

+ + API keys are stored locally in your browser +

+
+ + + {/* Delete Confirmation Dialog */} + { + setDeleteConfirmOpen(open) + if (!open) setDeleteConfirmText("") + }} + > + + +
+ +
+ + Delete Provider + + + Are you sure you want to delete{" "} + + {selectedProvider + ? selectedProvider.name || + PROVIDER_INFO[selectedProvider.provider] + .label + : "this provider"} + + ? This will remove all configured models and cannot + be undone. + +
+ {selectedProvider && + selectedProvider.models.length >= 3 && ( +
+ + + setDeleteConfirmText(e.target.value) + } + placeholder="Type provider name..." + className="h-9" + /> +
+ )} + + Cancel + = 3 && + deleteConfirmText !== + (selectedProvider.name || + PROVIDER_INFO[selectedProvider.provider] + .label) + } + className="bg-destructive text-destructive-foreground hover:bg-destructive/90 disabled:opacity-50" + > + Delete + + +
+
+
+ ) +} diff --git a/components/model-selector.tsx b/components/model-selector.tsx new file mode 100644 index 0000000..37baa4c --- /dev/null +++ b/components/model-selector.tsx @@ -0,0 +1,216 @@ +"use client" + +import { Bot, Check, ChevronDown, Server, Settings2 } from "lucide-react" +import { useMemo, useState } from "react" +import { + ModelSelectorContent, + ModelSelectorEmpty, + ModelSelectorGroup, + ModelSelectorInput, + ModelSelectorItem, + ModelSelectorList, + ModelSelectorLogo, + ModelSelectorName, + ModelSelector as ModelSelectorRoot, + ModelSelectorSeparator, + ModelSelectorTrigger, +} from "@/components/ai-elements/model-selector" +import { ButtonWithTooltip } from "@/components/button-with-tooltip" +import type { FlattenedModel } from "@/lib/types/model-config" +import { cn } from "@/lib/utils" + +interface ModelSelectorProps { + models: FlattenedModel[] + selectedModelId: string | undefined + onSelect: (modelId: string | undefined) => void + onConfigure: () => void + disabled?: boolean +} + +// Map our provider names to models.dev logo names +const PROVIDER_LOGO_MAP: Record = { + openai: "openai", + anthropic: "anthropic", + google: "google", + azure: "azure", + bedrock: "amazon-bedrock", + openrouter: "openrouter", + deepseek: "deepseek", + siliconflow: "siliconflow", + gateway: "vercel", +} + +// Group models by providerLabel (handles duplicate providers) +function groupModelsByProvider( + models: FlattenedModel[], +): Map { + const groups = new Map< + string, + { provider: string; models: FlattenedModel[] } + >() + for (const model of models) { + const key = model.providerLabel + const existing = groups.get(key) + if (existing) { + existing.models.push(model) + } else { + groups.set(key, { provider: model.provider, models: [model] }) + } + } + return groups +} + +export function ModelSelector({ + models, + selectedModelId, + onSelect, + onConfigure, + disabled = false, +}: ModelSelectorProps) { + const [open, setOpen] = useState(false) + // Only show validated models in the selector + const validatedModels = useMemo( + () => models.filter((m) => m.validated === true), + [models], + ) + const groupedModels = useMemo( + () => groupModelsByProvider(validatedModels), + [validatedModels], + ) + + // Find selected model for display + const selectedModel = useMemo( + () => models.find((m) => m.id === selectedModelId), + [models, selectedModelId], + ) + + const handleSelect = (value: string) => { + if (value === "__configure__") { + onConfigure() + } else if (value === "__server_default__") { + onSelect(undefined) + } else { + onSelect(value) + } + setOpen(false) + } + + const tooltipContent = selectedModel + ? `${selectedModel.modelId} (click to change)` + : "Using server default model (click to change)" + + return ( + + + + + + {selectedModel ? selectedModel.modelId : "Default"} + + + + + + + + + {validatedModels.length === 0 && models.length > 0 + ? "No verified models. Test your models first." + : "No models found."} + + + {/* Server Default Option */} + + + + + + Server Default + + + + + {/* Configured Models by Provider */} + {Array.from(groupedModels.entries()).map( + ([ + providerLabel, + { provider, models: providerModels }, + ]) => ( + + {providerModels.map((model) => ( + handleSelect(model.id)} + className="cursor-pointer" + > + + + + {model.modelId} + + + ))} + + ), + )} + + {/* Configure Option */} + + + + + + Configure Models... + + + + {/* Info text */} +
+ Only verified models are shown +
+
+
+
+ ) +} diff --git a/components/settings-dialog.tsx b/components/settings-dialog.tsx index b2ba8b7..029fc8d 100644 --- a/components/settings-dialog.tsx +++ b/components/settings-dialog.tsx @@ -12,13 +12,6 @@ import { } from "@/components/ui/dialog" import { Input } from "@/components/ui/input" import { Label } from "@/components/ui/label" -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select" import { Switch } from "@/components/ui/switch" import { useDictionary } from "@/hooks/use-dictionary" @@ -35,10 +28,6 @@ interface SettingsDialogProps { export const STORAGE_ACCESS_CODE_KEY = "next-ai-draw-io-access-code" export const STORAGE_CLOSE_PROTECTION_KEY = "next-ai-draw-io-close-protection" const STORAGE_ACCESS_CODE_REQUIRED_KEY = "next-ai-draw-io-access-code-required" -export const STORAGE_AI_PROVIDER_KEY = "next-ai-draw-io-ai-provider" -export const STORAGE_AI_BASE_URL_KEY = "next-ai-draw-io-ai-base-url" -export const STORAGE_AI_API_KEY_KEY = "next-ai-draw-io-ai-api-key" -export const STORAGE_AI_MODEL_KEY = "next-ai-draw-io-ai-model" function getStoredAccessCodeRequired(): boolean | null { if (typeof window === "undefined") return null @@ -64,10 +53,6 @@ export function SettingsDialog({ const [accessCodeRequired, setAccessCodeRequired] = useState( () => getStoredAccessCodeRequired() ?? false, ) - const [provider, setProvider] = useState("") - const [baseUrl, setBaseUrl] = useState("") - const [apiKey, setApiKey] = useState("") - const [modelId, setModelId] = useState("") useEffect(() => { // Only fetch if not cached in localStorage @@ -104,12 +89,6 @@ export function SettingsDialog({ // Default to true if not set setCloseProtection(storedCloseProtection !== "false") - // Load AI provider settings - setProvider(localStorage.getItem(STORAGE_AI_PROVIDER_KEY) || "") - setBaseUrl(localStorage.getItem(STORAGE_AI_BASE_URL_KEY) || "") - setApiKey(localStorage.getItem(STORAGE_AI_API_KEY_KEY) || "") - setModelId(localStorage.getItem(STORAGE_AI_MODEL_KEY) || "") - setError("") } }, [open]) @@ -197,190 +176,6 @@ export function SettingsDialog({ )}
)} -
- -

- {dict.settings.aiProviderDescription} -

-
-
- - -
- {provider && provider !== "default" && ( - <> -
- - { - setModelId(e.target.value) - localStorage.setItem( - STORAGE_AI_MODEL_KEY, - e.target.value, - ) - }} - placeholder={ - provider === "openai" - ? "e.g., gpt-4o" - : provider === "anthropic" - ? "e.g., claude-sonnet-4-5" - : provider === "google" - ? "e.g., gemini-2.0-flash-exp" - : provider === - "deepseek" - ? "e.g., deepseek-chat" - : dict.settings - .modelId - } - /> -
-
- - { - setApiKey(e.target.value) - localStorage.setItem( - STORAGE_AI_API_KEY_KEY, - e.target.value, - ) - }} - placeholder={ - dict.settings.apiKeyPlaceholder - } - autoComplete="off" - /> -

- {dict.settings.overrides}{" "} - {provider === "openai" - ? "OPENAI_API_KEY" - : provider === "anthropic" - ? "ANTHROPIC_API_KEY" - : provider === "google" - ? "GOOGLE_GENERATIVE_AI_API_KEY" - : provider === "azure" - ? "AZURE_API_KEY" - : provider === - "openrouter" - ? "OPENROUTER_API_KEY" - : provider === - "deepseek" - ? "DEEPSEEK_API_KEY" - : provider === - "siliconflow" - ? "SILICONFLOW_API_KEY" - : "server API key"} -

-
-
- - { - setBaseUrl(e.target.value) - localStorage.setItem( - STORAGE_AI_BASE_URL_KEY, - e.target.value, - ) - }} - placeholder={ - provider === "anthropic" - ? "https://api.anthropic.com/v1" - : provider === "siliconflow" - ? "https://api.siliconflow.com/v1" - : dict.settings - .customEndpoint - } - /> -
- - - )} -
-
-