mirror of
https://github.com/DayuanJiang/next-ai-draw-io.git
synced 2026-01-02 14:22:28 +08:00
feat: add tool call JSON repair and Bedrock compatibility (#127)
- Add fixToolCallInputs() to fix Bedrock API requirement (JSON object, not string) - Add experimental_repairToolCall for malformed JSON from model - Add stepCountIs(5) limit to prevent infinite loops - Update edit_diagram tool description with JSON escaping warning Co-authored-by: dayuan.jiang <jiangdy@amazon.co.jp>
This commit is contained in:
@@ -2,6 +2,7 @@ import {
|
|||||||
convertToModelMessages,
|
convertToModelMessages,
|
||||||
createUIMessageStream,
|
createUIMessageStream,
|
||||||
createUIMessageStreamResponse,
|
createUIMessageStreamResponse,
|
||||||
|
stepCountIs,
|
||||||
streamText,
|
streamText,
|
||||||
} from "ai"
|
} from "ai"
|
||||||
import { z } from "zod"
|
import { z } from "zod"
|
||||||
@@ -63,6 +64,28 @@ function isMinimalDiagram(xml: string): boolean {
|
|||||||
return !stripped.includes('id="2"')
|
return !stripped.includes('id="2"')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to fix tool call inputs for Bedrock API
|
||||||
|
// Bedrock requires toolUse.input to be a JSON object, not a string
|
||||||
|
function fixToolCallInputs(messages: any[]): any[] {
|
||||||
|
return messages.map((msg) => {
|
||||||
|
if (msg.role !== "assistant" || !Array.isArray(msg.content)) {
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
const fixedContent = msg.content.map((part: any) => {
|
||||||
|
if (part.type === "tool-call" && typeof part.input === "string") {
|
||||||
|
try {
|
||||||
|
return { ...part, input: JSON.parse(part.input) }
|
||||||
|
} catch {
|
||||||
|
// If parsing fails, wrap the string in an object
|
||||||
|
return { ...part, input: { rawInput: part.input } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return part
|
||||||
|
})
|
||||||
|
return { ...msg, content: fixedContent }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to create cached stream response
|
// Helper function to create cached stream response
|
||||||
function createCachedStreamResponse(xml: string): Response {
|
function createCachedStreamResponse(xml: string): Response {
|
||||||
const toolCallId = `cached-${Date.now()}`
|
const toolCallId = `cached-${Date.now()}`
|
||||||
@@ -189,9 +212,12 @@ ${lastMessageText}
|
|||||||
// Convert UIMessages to ModelMessages and add system message
|
// Convert UIMessages to ModelMessages and add system message
|
||||||
const modelMessages = convertToModelMessages(messages)
|
const modelMessages = convertToModelMessages(messages)
|
||||||
|
|
||||||
|
// Fix tool call inputs for Bedrock API (requires JSON objects, not strings)
|
||||||
|
const fixedMessages = fixToolCallInputs(modelMessages)
|
||||||
|
|
||||||
// Filter out messages with empty content arrays (Bedrock API rejects these)
|
// Filter out messages with empty content arrays (Bedrock API rejects these)
|
||||||
// This is a safety measure - ideally convertToModelMessages should handle all cases
|
// This is a safety measure - ideally convertToModelMessages should handle all cases
|
||||||
let enhancedMessages = modelMessages.filter(
|
let enhancedMessages = fixedMessages.filter(
|
||||||
(msg: any) =>
|
(msg: any) =>
|
||||||
msg.content && Array.isArray(msg.content) && msg.content.length > 0,
|
msg.content && Array.isArray(msg.content) && msg.content.length > 0,
|
||||||
)
|
)
|
||||||
@@ -267,6 +293,7 @@ ${lastMessageText}
|
|||||||
|
|
||||||
const result = streamText({
|
const result = streamText({
|
||||||
model,
|
model,
|
||||||
|
stopWhen: stepCountIs(5),
|
||||||
messages: allMessages,
|
messages: allMessages,
|
||||||
...(providerOptions && { providerOptions }),
|
...(providerOptions && { providerOptions }),
|
||||||
...(headers && { headers }),
|
...(headers && { headers }),
|
||||||
@@ -277,6 +304,32 @@ ${lastMessageText}
|
|||||||
userId,
|
userId,
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
// Repair malformed tool calls (model sometimes generates invalid JSON with unescaped quotes)
|
||||||
|
experimental_repairToolCall: async ({ toolCall }) => {
|
||||||
|
// The toolCall.input contains the raw JSON string that failed to parse
|
||||||
|
const rawJson =
|
||||||
|
typeof toolCall.input === "string" ? toolCall.input : null
|
||||||
|
|
||||||
|
if (rawJson) {
|
||||||
|
try {
|
||||||
|
// Fix unescaped quotes: x="520" should be x=\"520\"
|
||||||
|
const fixed = rawJson.replace(
|
||||||
|
/([a-zA-Z])="(\d+)"/g,
|
||||||
|
'$1=\\"$2\\"',
|
||||||
|
)
|
||||||
|
const parsed = JSON.parse(fixed)
|
||||||
|
return {
|
||||||
|
type: "tool-call" as const,
|
||||||
|
toolCallId: toolCall.toolCallId,
|
||||||
|
toolName: toolCall.toolName,
|
||||||
|
input: JSON.stringify(parsed),
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Repair failed, return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
},
|
||||||
onFinish: ({ text, usage, providerMetadata }) => {
|
onFinish: ({ text, usage, providerMetadata }) => {
|
||||||
console.log(
|
console.log(
|
||||||
"[Cache] Full providerMetadata:",
|
"[Cache] Full providerMetadata:",
|
||||||
@@ -342,7 +395,9 @@ IMPORTANT: Keep edits concise:
|
|||||||
- Only include the lines that are changing, plus 1-2 surrounding lines for context if needed
|
- Only include the lines that are changing, plus 1-2 surrounding lines for context if needed
|
||||||
- Break large changes into multiple smaller edits
|
- Break large changes into multiple smaller edits
|
||||||
- Each search must contain complete lines (never truncate mid-line)
|
- Each search must contain complete lines (never truncate mid-line)
|
||||||
- First match only - be specific enough to target the right element`,
|
- First match only - be specific enough to target the right element
|
||||||
|
|
||||||
|
⚠️ JSON ESCAPING: Every " inside string values MUST be escaped as \\". Example: x=\\"100\\" y=\\"200\\" - BOTH quotes need backslashes!`,
|
||||||
inputSchema: z.object({
|
inputSchema: z.object({
|
||||||
edits: z
|
edits: z
|
||||||
.array(
|
.array(
|
||||||
|
|||||||
Reference in New Issue
Block a user