Mirror of https://github.com/DayuanJiang/next-ai-draw-io.git (synced 2026-01-02 14:22:28 +08:00)
fix: feature/sglang-provider (#302)
Co-authored-by: zhaochaojin <zhaochaojin@didiglobal.com>
Co-authored-by: dayuan.jiang <jdy.toh@gmail.com>
@@ -68,6 +68,10 @@ AI_MODEL=global.anthropic.claude-sonnet-4-5-20250929-v1:0
 # SILICONFLOW_API_KEY=sk-...
 # SILICONFLOW_BASE_URL=https://api.siliconflow.com/v1 # Optional: switch to https://api.siliconflow.cn/v1 if needed
 
+# SGLang Configuration (OpenAI-compatible)
+# SGLANG_API_KEY=your-sglang-api-key
+# SGLANG_BASE_URL=http://127.0.0.1:8000/v1 # Your SGLang endpoint
+
 # Vercel AI Gateway Configuration
 # Get your API key from: https://vercel.com/ai-gateway
 # Model format: "provider/model" e.g., "openai/gpt-4o", "anthropic/claude-sonnet-4-5"
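Putting these together, a working local setup might look like the following sketch of a `.env.local` (the model id is a hypothetical placeholder; use whatever model your SGLang instance actually serves):

    AI_PROVIDER=sglang
    AI_MODEL=qwen2.5-7b-instruct   # hypothetical: any model served by your SGLang instance
    SGLANG_API_KEY=your-sglang-api-key
    SGLANG_BASE_URL=http://127.0.0.1:8000/v1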
@@ -19,6 +19,7 @@ export type ProviderName =
     | "openrouter"
     | "deepseek"
     | "siliconflow"
+    | "sglang"
     | "gateway"
 
 interface ModelConfig {
@@ -50,6 +51,7 @@ const ALLOWED_CLIENT_PROVIDERS: ProviderName[] = [
     "openrouter",
     "deepseek",
     "siliconflow",
+    "sglang",
     "gateway",
 ]
 
@@ -343,6 +345,7 @@ function buildProviderOptions(
         case "deepseek":
         case "openrouter":
         case "siliconflow":
+        case "sglang":
         case "gateway": {
             // These providers don't have reasoning configs in AI SDK yet
             // Gateway passes through to underlying providers which handle their own configs
@@ -367,6 +370,7 @@ const PROVIDER_ENV_VARS: Record<ProviderName, string | null> = {
     openrouter: "OPENROUTER_API_KEY",
     deepseek: "DEEPSEEK_API_KEY",
     siliconflow: "SILICONFLOW_API_KEY",
+    sglang: "SGLANG_API_KEY",
     gateway: "AI_GATEWAY_API_KEY",
 }
 
@@ -432,7 +436,7 @@ function validateProviderCredentials(provider: ProviderName): void {
  * Get the AI model based on environment variables
  *
  * Environment variables:
- * - AI_PROVIDER: The provider to use (bedrock, openai, anthropic, google, azure, ollama, openrouter, deepseek, siliconflow)
+ * - AI_PROVIDER: The provider to use (bedrock, openai, anthropic, google, azure, ollama, openrouter, deepseek, siliconflow, sglang, gateway)
  * - AI_MODEL: The model ID/name for the selected provider
  *
  * Provider-specific env vars:
@@ -448,6 +452,8 @@ function validateProviderCredentials(provider: ProviderName): void {
  * - DEEPSEEK_BASE_URL: DeepSeek endpoint (optional)
  * - SILICONFLOW_API_KEY: SiliconFlow API key
  * - SILICONFLOW_BASE_URL: SiliconFlow endpoint (optional, defaults to https://api.siliconflow.com/v1)
+ * - SGLANG_API_KEY: SGLang API key
+ * - SGLANG_BASE_URL: SGLang endpoint (optional)
  */
 export function getAIModel(overrides?: ClientOverrides): ModelConfig {
     // SECURITY: Prevent SSRF attacks (GHSA-9qf7-mprq-9qgm)
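As a usage sketch: the `apiKey`/`baseUrl` override fields are read in the sglang case further down, but the full `ClientOverrides` and `ModelConfig` shapes are not shown in this diff, so treat the call below as illustrative only.

    // Sketch: assumes AI_PROVIDER=sglang and AI_MODEL are set in the environment.
    // Omitting the overrides falls back to SGLANG_API_KEY / SGLANG_BASE_URL.
    const config = getAIModel({
        apiKey: "your-sglang-api-key",       // hypothetical key
        baseUrl: "http://127.0.0.1:8000/v1", // hypothetical local endpoint
    })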
@@ -516,6 +522,7 @@ export function getAIModel(overrides?: ClientOverrides): ModelConfig {
                 `- OPENROUTER_API_KEY for OpenRouter\n` +
                 `- AZURE_API_KEY for Azure\n` +
                 `- SILICONFLOW_API_KEY for SiliconFlow\n` +
+                `- SGLANG_API_KEY for SGLang\n` +
                 `Or set AI_PROVIDER=ollama for local Ollama.`,
             )
         } else {
@@ -698,6 +705,112 @@ export function getAIModel(overrides?: ClientOverrides): ModelConfig {
             break
         }
 
+        case "sglang": {
+            const apiKey = overrides?.apiKey || process.env.SGLANG_API_KEY
+            const baseURL = overrides?.baseUrl || process.env.SGLANG_BASE_URL
+
+            const sglangProvider = createOpenAI({
+                apiKey,
+                baseURL,
+                // Add a custom fetch wrapper to intercept and fix the stream from sglang
+                fetch: async (url, options) => {
+                    const response = await fetch(url, options)
+                    if (!response.body) {
+                        return response
+                    }
+
+                    // Create a transform stream to fix the non-compliant sglang stream
+                    let buffer = ""
+                    const decoder = new TextDecoder()
+
+                    const transformStream = new TransformStream({
+                        transform(chunk, controller) {
+                            buffer += decoder.decode(chunk, { stream: true })
+                            // Process all complete messages in the buffer
+                            let messageEndPos
+                            while (
+                                (messageEndPos = buffer.indexOf("\n\n")) !== -1
+                            ) {
+                                const message = buffer.substring(
+                                    0,
+                                    messageEndPos,
+                                )
+                                buffer = buffer.substring(messageEndPos + 2) // Move past the '\n\n'
+
+                                if (message.startsWith("data: ")) {
+                                    const jsonStr = message.substring(6).trim()
+                                    if (jsonStr === "[DONE]") {
+                                        controller.enqueue(
+                                            new TextEncoder().encode(
+                                                message + "\n\n",
+                                            ),
+                                        )
+                                        continue
+                                    }
+                                    try {
+                                        const data = JSON.parse(jsonStr)
+                                        const delta = data.choices?.[0]?.delta
+
+                                        if (delta) {
+                                            // Fix 1: remove invalid empty role
+                                            if (delta.role === "") {
+                                                delete delta.role
+                                            }
+                                            // Fix 2: remove non-standard reasoning_content field
+                                            if ("reasoning_content" in delta) {
+                                                delete delta.reasoning_content
+                                            }
+                                        }
+
+                                        // Re-serialize and forward the corrected data with the correct SSE format
+                                        controller.enqueue(
+                                            new TextEncoder().encode(
+                                                `data: ${JSON.stringify(data)}\n\n`,
+                                            ),
+                                        )
+                                    } catch (e) {
+                                        // If parsing fails, forward the original message to avoid breaking the stream.
+                                        controller.enqueue(
+                                            new TextEncoder().encode(
+                                                message + "\n\n",
+                                            ),
+                                        )
+                                    }
+                                } else if (message.trim() !== "") {
+                                    // Pass through other message types (e.g., 'event: ...')
+                                    controller.enqueue(
+                                        new TextEncoder().encode(
+                                            message + "\n\n",
+                                        ),
+                                    )
+                                }
+                            }
+                        },
+                        flush(controller) {
+                            // If there's anything left in the buffer, forward it.
+                            if (buffer.trim()) {
+                                controller.enqueue(
+                                    new TextEncoder().encode(buffer),
+                                )
+                            }
+                        },
+                    })
+
+                    const transformedBody =
+                        response.body.pipeThrough(transformStream)
+
+                    // Return a new response with the transformed body
+                    return new Response(transformedBody, {
+                        status: response.status,
+                        statusText: response.statusText,
+                        headers: response.headers,
+                    })
+                },
+            })
+            model = sglangProvider.chat(modelId)
+            break
+        }
+
         case "gateway": {
             // Vercel AI Gateway - unified access to multiple AI providers
             // Model format: "provider/model" e.g., "openai/gpt-4o", "anthropic/claude-sonnet-4-5"
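To illustrate what the fetch wrapper does to a single SSE message, here is a minimal standalone sketch of the same per-delta fix-up (the payload is made up to exhibit both non-compliant fields the wrapper removes):

    // Hypothetical SGLang-style SSE message with an empty role and a
    // non-standard reasoning_content field.
    const raw =
        'data: {"choices":[{"delta":{"role":"","content":"Hi","reasoning_content":"..."}}]}'

    const data = JSON.parse(raw.substring(6).trim())
    const delta = data.choices?.[0]?.delta
    if (delta) {
        // Fix 1: remove invalid empty role
        if (delta.role === "") delete delta.role
        // Fix 2: remove non-standard reasoning_content field
        if ("reasoning_content" in delta) delete delta.reasoning_content
    }
    console.log(`data: ${JSON.stringify(data)}\n\n`)
    // -> data: {"choices":[{"delta":{"content":"Hi"}}]}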
@@ -721,7 +834,7 @@ export function getAIModel(overrides?: ClientOverrides): ModelConfig {
 
         default:
             throw new Error(
-                `Unknown AI provider: ${provider}. Supported providers: bedrock, openai, anthropic, google, azure, ollama, openrouter, deepseek, siliconflow, gateway`,
+                `Unknown AI provider: ${provider}. Supported providers: bedrock, openai, anthropic, google, azure, ollama, openrouter, deepseek, siliconflow, sglang, gateway`,
             )
     }