diff --git a/components/chat-message-display.tsx b/components/chat-message-display.tsx index 56e720c..2a9eb3a 100644 --- a/components/chat-message-display.tsx +++ b/components/chat-message-display.tsx @@ -1,7 +1,7 @@ "use client"; import type React from "react"; -import { useRef, useEffect, useState } from "react"; +import { useRef, useEffect, useState, useCallback } from "react"; import Image from "next/image"; import { ScrollArea } from "@/components/ui/scroll-area"; import ExamplePanel from "./chat-example-panel"; @@ -26,36 +26,11 @@ export function ChatMessageDisplay({ const { chartXML, loadDiagram: onDisplayChart } = useDiagram(); const messagesEndRef = useRef(null); const previousXML = useRef(""); + const processedToolCalls = useRef>(new Set()); const [expandedTools, setExpandedTools] = useState>( {} ); - useEffect(() => { - if (messagesEndRef.current) { - messagesEndRef.current.scrollIntoView({ behavior: "smooth" }); - } - }, [messages]); - - // Auto-collapse args when diagrams are generated - useEffect(() => { - messages.forEach((message) => { - if (message.parts) { - message.parts.forEach((part) => { - if ( - part.type === "tool-invocation" && - part.toolInvocation.state === "result" - ) { - const callId = part.toolInvocation.toolCallId; - setExpandedTools((prev) => ({ - ...prev, - [callId]: false, - })); - } - }); - } - }); - }, [messages]); - - function handleDisplayChart(xml: string) { + const handleDisplayChart = useCallback((xml: string) => { const currentXml = xml || ""; const convertedXml = convertToLegalXml(currentXml); if (convertedXml !== previousXML.current) { @@ -63,13 +38,52 @@ export function ChatMessageDisplay({ const replacedXML = replaceNodes(chartXML, convertedXml); onDisplayChart(replacedXML); } - } + }, [chartXML, onDisplayChart]); + + useEffect(() => { + if (messagesEndRef.current) { + messagesEndRef.current.scrollIntoView({ behavior: "smooth" }); + } + }, [messages]); + + // Handle tool invocations and update diagram when needed + useEffect(() => { + messages.forEach((message) => { + if (message.parts) { + message.parts.forEach((part) => { + if (part.type === "tool-invocation") { + const { toolCallId, state, args, toolName } = part.toolInvocation; + + // Auto-collapse args when diagrams are generated + if (state === "result") { + setExpandedTools((prev) => ({ + ...prev, + [toolCallId]: false, + })); + } + + // Handle diagram updates for display_diagram tool + if (toolName === "display_diagram" && args?.xml) { + // For partial calls, always update to show streaming + if (state === "partial-call") { + handleDisplayChart(args.xml); + } + // For completed calls, only update if not processed yet + else if (state === "result" && !processedToolCalls.current.has(toolCallId)) { + handleDisplayChart(args.xml); + processedToolCalls.current.add(toolCallId); + } + } + } + }); + } + }); + }, [messages, handleDisplayChart]); const renderToolInvocation = (toolInvocation: any) => { const callId = toolInvocation.toolCallId; const { toolName, args, state } = toolInvocation; const isExpanded = expandedTools[callId] ?? true; - handleDisplayChart(args?.xml); const toggleExpanded = () => { setExpandedTools((prev) => ({