8 Commits

Author SHA1 Message Date
fawney19
dd2fbf4424 style(ui): 调整模型详情抽屉关联提供商表格列宽 2026-01-08 13:37:41 +08:00
fawney19
99b12a49c6 Merge pull request #78 from fawney19/perf/optimize
perf: 优化 HTTP 客户端连接池复用
2026-01-08 13:37:13 +08:00
fawney19
ea35efe440 perf: 优化 HTTP 客户端连接池复用
- 新增 get_proxy_client() 方法,相同代理配置复用同一客户端
- 添加 LRU 淘汰策略,代理客户端上限 50 个防止内存泄漏
- 新增 get_default_client_async() 异步线程安全版本
- 使用模块级锁避免类属性初始化竞态条件
- 优化 ConcurrencyManager 使用 Redis MGET 批量获取减少往返
- 添加 get_pool_stats() 连接池统计信息接口
2026-01-08 13:34:59 +08:00
fawney19
bf09e740e9 fix(ui): 优化提供商详情页的交互体验
- 模型列表删除按钮仅在 hover 时显示红色
- 批量关联模型对话框:只有全局模型时展开,有多个分组时全部折叠
2026-01-08 11:25:52 +08:00
fawney19
60c77cec56 Merge pull request #77 from AAEE86/ui
style(ui): improve text visibility in dark mode for model badges
2026-01-08 10:52:54 +08:00
fawney19
0e4a1dddb5 refactor(ui): 优化批量端点创建的 UI 和性能
- 调整布局: API URL 移至顶部, API 格式选择移至下方
- 优化 checkbox 样式: 使用自定义勾选框替代原生样式
- API 格式按列排序: 基础格式和对应 CLI 格式上下对齐
- 请求配置改为 4 列布局, 更紧凑
- 使用 Promise.allSettled 并发创建端点, 提升性能
- 改进错误提示: 失败时直接展示具体错误信息给用户
- 清理未使用的 Select 组件导入和 selectOpen 变量
2026-01-08 10:50:25 +08:00
AAEE86
1cf18b6e12 feat(ui): support batch endpoint creation with multiple API formats (#76)
Replace single API format selector with multi-select checkbox interface in endpoint creation dialog. Users can now select multiple API formats to create multiple endpoints simultaneously with shared configuration (URL, path, timeout, etc.).

- Change API format selection from dropdown to checkbox grid layout
- Add selectedFormats array to track multiple format selections
- Implement batch creation logic with individual error handling
- Update submit button to show endpoint count being created
- Adjust form layout to improve visual hierarchy
- Display appropriate success/failure messages for batch operations
- Reset selectedFormats on form reset
2026-01-08 10:42:14 +08:00
AAEE86
f9a8be898a style(ui): improve text visibility in dark mode for model badges 2026-01-08 10:26:58 +08:00
9 changed files with 483 additions and 201 deletions

View File

@@ -460,13 +460,13 @@
<TableHead class="h-10 font-semibold"> <TableHead class="h-10 font-semibold">
Provider Provider
</TableHead> </TableHead>
<TableHead class="w-[120px] h-10 font-semibold"> <TableHead class="w-[100px] h-10 font-semibold">
能力 能力
</TableHead> </TableHead>
<TableHead class="w-[180px] h-10 font-semibold"> <TableHead class="w-[200px] h-10 font-semibold">
价格 ($/M) 价格 ($/M)
</TableHead> </TableHead>
<TableHead class="w-[80px] h-10 font-semibold text-center"> <TableHead class="w-[100px] h-10 font-semibold text-center">
操作 操作
</TableHead> </TableHead>
</TableRow> </TableRow>

View File

@@ -531,20 +531,23 @@ watch(() => props.open, async (isOpen) => {
// 加载数据 // 加载数据
async function loadData() { async function loadData() {
await Promise.all([loadGlobalModels(), loadExistingModels()]) await Promise.all([loadGlobalModels(), loadExistingModels()])
// 默认折叠全局模型组
collapsedGroups.value = new Set(['global'])
// 检查缓存,如果有缓存数据则直接使用 // 检查缓存,如果有缓存数据则直接使用
const cachedModels = getCachedModels(props.providerId) const cachedModels = getCachedModels(props.providerId)
if (cachedModels) { if (cachedModels && cachedModels.length > 0) {
upstreamModels.value = cachedModels upstreamModels.value = cachedModels
upstreamModelsLoaded.value = true upstreamModelsLoaded.value = true
// 折叠所有上游模型组 // 有多个分组时全部折叠
const allGroups = new Set(['global'])
for (const model of cachedModels) { for (const model of cachedModels) {
if (model.api_format) { if (model.api_format) {
collapsedGroups.value.add(model.api_format) allGroups.add(model.api_format)
} }
} }
collapsedGroups.value = allGroups
} else {
// 只有全局模型时展开
collapsedGroups.value = new Set()
} }
} }
@@ -585,8 +588,8 @@ async function fetchUpstreamModels(forceRefresh = false) {
} else { } else {
upstreamModels.value = result.models upstreamModels.value = result.models
upstreamModelsLoaded.value = true upstreamModelsLoaded.value = true
// 折叠所有上游模型组 // 有多个分组时全部折叠
const allGroups = new Set(collapsedGroups.value) const allGroups = new Set(['global'])
for (const model of result.models) { for (const model of result.models) {
if (model.api_format) { if (model.api_format) {
allGroups.add(model.api_format) allGroups.add(model.api_format)

View File

@@ -20,44 +20,8 @@
API 配置 API 配置
</h3> </h3>
<!-- API URL 和自定义路径 -->
<div class="grid grid-cols-2 gap-4"> <div class="grid grid-cols-2 gap-4">
<!-- API 格式 -->
<div class="space-y-2">
<Label for="api_format">API 格式 *</Label>
<template v-if="isEditMode">
<Input
id="api_format"
v-model="form.api_format"
disabled
class="bg-muted"
/>
<p class="text-xs text-muted-foreground">
API 格式创建后不可修改
</p>
</template>
<template v-else>
<Select
v-model="form.api_format"
v-model:open="selectOpen"
required
>
<SelectTrigger>
<SelectValue placeholder="请选择 API 格式" />
</SelectTrigger>
<SelectContent>
<SelectItem
v-for="format in apiFormats"
:key="format.value"
:value="format.value"
>
{{ format.label }}
</SelectItem>
</SelectContent>
</Select>
</template>
</div>
<!-- API URL -->
<div class="space-y-2"> <div class="space-y-2">
<Label for="base_url">API URL *</Label> <Label for="base_url">API URL *</Label>
<Input <Input
@@ -67,16 +31,70 @@
required required
/> />
</div> </div>
<div class="space-y-2">
<Label for="custom_path">自定义请求路径可选</Label>
<Input
id="custom_path"
v-model="form.custom_path"
:placeholder="isEditMode ? defaultPathPlaceholder : '留空使用各格式的默认路径'"
/>
</div>
</div> </div>
<!-- 自定义路径 --> <!-- API 格式 -->
<div class="space-y-2"> <div class="space-y-2">
<Label for="custom_path">自定义请求路径可选</Label> <Label for="api_format">API 格式 *</Label>
<Input <template v-if="isEditMode">
id="custom_path" <Input
v-model="form.custom_path" id="api_format"
:placeholder="defaultPathPlaceholder" v-model="form.api_format"
/> disabled
class="bg-muted"
/>
<p class="text-xs text-muted-foreground">
API 格式创建后不可修改
</p>
</template>
<template v-else>
<div class="grid grid-cols-3 grid-flow-col grid-rows-2 gap-2">
<label
v-for="format in sortedApiFormats"
:key="format.value"
class="flex items-center gap-2 rounded-md border px-3 py-2 cursor-pointer transition-all text-sm"
:class="selectedFormats.includes(format.value)
? 'border-primary bg-primary/10 text-primary font-medium'
: 'border-border hover:border-primary/50 hover:bg-accent'"
>
<input
type="checkbox"
:value="format.value"
v-model="selectedFormats"
class="sr-only"
/>
<span
class="flex h-4 w-4 shrink-0 items-center justify-center rounded border transition-colors"
:class="selectedFormats.includes(format.value)
? 'border-primary bg-primary text-primary-foreground'
: 'border-muted-foreground/30'"
>
<svg
v-if="selectedFormats.includes(format.value)"
class="h-3 w-3"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="3"
stroke-linecap="round"
stroke-linejoin="round"
>
<polyline points="20 6 9 17 4 12" />
</svg>
</span>
<span>{{ format.label }}</span>
</label>
</div>
</template>
</div> </div>
</div> </div>
@@ -86,7 +104,7 @@
请求配置 请求配置
</h3> </h3>
<div class="grid grid-cols-3 gap-4"> <div class="grid grid-cols-4 gap-4">
<div class="space-y-2"> <div class="space-y-2">
<Label for="timeout">超时</Label> <Label for="timeout">超时</Label>
<Input <Input
@@ -117,11 +135,9 @@
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)" @update:model-value="(v) => form.max_concurrent = parseNumberInput(v)"
/> />
</div> </div>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2"> <div class="space-y-2">
<Label for="rate_limit">速率限制(请求/分钟)</Label> <Label for="rate_limit">速率限制(/分钟)</Label>
<Input <Input
id="rate_limit" id="rate_limit"
:model-value="form.rate_limit ?? ''" :model-value="form.rate_limit ?? ''"
@@ -217,10 +233,10 @@
取消 取消
</Button> </Button>
<Button <Button
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)" :disabled="loading || !form.base_url || (!isEditMode && selectedFormats.length === 0)"
@click="handleSubmit()" @click="handleSubmit()"
> >
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }} {{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : `创建 ${selectedFormats.length} 个端点`) }}
</Button> </Button>
</template> </template>
</Dialog> </Dialog>
@@ -245,11 +261,6 @@ import {
Button, Button,
Input, Input,
Label, Label,
Select,
SelectTrigger,
SelectValue,
SelectContent,
SelectItem,
Switch, Switch,
} from '@/components/ui' } from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue' import AlertDialog from '@/components/common/AlertDialog.vue'
@@ -280,7 +291,6 @@ const emit = defineEmits<{
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
const loading = ref(false) const loading = ref(false)
const selectOpen = ref(false)
const proxyEnabled = ref(false) const proxyEnabled = ref(false)
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框 const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
@@ -306,9 +316,28 @@ const form = ref({
proxy_password: '', proxy_password: '',
}) })
// 选中的 API 格式(多选)
const selectedFormats = ref<string[]>([])
// API 格式列表 // API 格式列表
const apiFormats = ref<Array<{ value: string; label: string; default_path: string; aliases: string[] }>>([]) const apiFormats = ref<Array<{ value: string; label: string; default_path: string; aliases: string[] }>>([])
// API formats ordered for the checkbox grid (two rows, column flow): every
// base format is immediately followed by its "<base>_cli" counterpart, so the
// pair stacks in one column: base1, cli1, base2, cli2, base3, cli3.
const sortedApiFormats = computed(() => {
  const isCli = (value: string) => value.endsWith('_cli')

  // Index CLI formats by value once instead of scanning the list per base
  // format. The first occurrence wins, matching Array.prototype.find.
  const cliByValue = new Map<string, (typeof apiFormats.value)[number]>()
  for (const format of apiFormats.value) {
    if (isCli(format.value) && !cliByValue.has(format.value)) {
      cliByValue.set(format.value, format)
    }
  }

  const ordered: typeof apiFormats.value = []
  for (const base of apiFormats.value) {
    if (isCli(base.value)) continue
    ordered.push(base)
    const cli = cliByValue.get(`${base.value}_cli`)
    if (cli) {
      ordered.push(cli)
      cliByValue.delete(cli.value)
    }
  }

  // A CLI format without a matching base format would otherwise vanish from
  // the grid and could never be selected; keep it, after the paired entries.
  ordered.push(...cliByValue.values())
  return ordered
})
// 加载API格式列表 // 加载API格式列表
const loadApiFormats = async () => { const loadApiFormats = async () => {
try { try {
@@ -330,7 +359,7 @@ const defaultPath = computed(() => {
// 动态 placeholder // 动态 placeholder
const defaultPathPlaceholder = computed(() => { const defaultPathPlaceholder = computed(() => {
return `留空使用默认路径:${defaultPath.value}` return defaultPath.value
}) })
// 检查是否有已保存的密码(后端返回 *** 表示有密码) // 检查是否有已保存的密码(后端返回 *** 表示有密码)
@@ -400,6 +429,7 @@ function resetForm() {
proxy_username: '', proxy_username: '',
proxy_password: '', proxy_password: '',
} }
selectedFormats.value = []
proxyEnabled.value = false proxyEnabled.value = false
} }
@@ -479,6 +509,8 @@ const handleSubmit = async (skipCredentialCheck = false) => {
} }
loading.value = true loading.value = true
let successCount = 0
try { try {
const proxyConfig = buildProxyConfig() const proxyConfig = buildProxyConfig()
@@ -497,27 +529,56 @@ const handleSubmit = async (skipCredentialCheck = false) => {
success('端点已更新', '保存成功') success('端点已更新', '保存成功')
emit('endpointUpdated') emit('endpointUpdated')
emit('update:modelValue', false)
} else if (props.provider) { } else if (props.provider) {
// 创建端点 // 批量创建端点 - 使用并发请求提升性能
await createEndpoint(props.provider.id, { const results = await Promise.allSettled(
provider_id: props.provider.id, selectedFormats.value.map(apiFormat =>
api_format: form.value.api_format, createEndpoint(props.provider!.id, {
base_url: form.value.base_url, provider_id: props.provider!.id,
custom_path: form.value.custom_path || undefined, api_format: apiFormat,
timeout: form.value.timeout, base_url: form.value.base_url,
max_retries: form.value.max_retries, custom_path: form.value.custom_path || undefined,
max_concurrent: form.value.max_concurrent, timeout: form.value.timeout,
rate_limit: form.value.rate_limit, max_retries: form.value.max_retries,
is_active: form.value.is_active, max_concurrent: form.value.max_concurrent,
proxy: proxyConfig, rate_limit: form.value.rate_limit,
is_active: form.value.is_active,
proxy: proxyConfig,
})
)
)
// 统计结果
const errors: string[] = []
results.forEach((result, index) => {
if (result.status === 'fulfilled') {
successCount++
} else {
const apiFormat = selectedFormats.value[index]
const formatLabel = apiFormats.value.find((f: any) => f.value === apiFormat)?.label || apiFormat
const errorMsg = result.reason?.response?.data?.detail || '创建失败'
errors.push(`${formatLabel}: ${errorMsg}`)
}
}) })
success('端点创建成功', '成功') const failCount = errors.length
emit('endpointCreated')
resetForm()
}
emit('update:modelValue', false) // 显示结果
if (successCount > 0 && failCount === 0) {
success(`成功创建 ${successCount} 个端点`, '创建成功')
} else if (successCount > 0 && failCount > 0) {
showError(`${failCount} 个端点创建失败:\n${errors.join('\n')}`, `${successCount} 个成功,${failCount} 个失败`)
} else {
showError(errors.join('\n') || '创建端点失败', '创建失败')
}
if (successCount > 0) {
emit('endpointCreated')
resetForm()
emit('update:modelValue', false)
}
}
} catch (error: any) { } catch (error: any) {
const action = isEditMode.value ? '更新' : '创建' const action = isEditMode.value ? '更新' : '创建'
showError(error.response?.data?.detail || `${action}端点失败`, '错误') showError(error.response?.data?.detail || `${action}端点失败`, '错误')

View File

@@ -32,11 +32,11 @@
v-for="modelName in selectedModels" v-for="modelName in selectedModels"
:key="modelName" :key="modelName"
variant="secondary" variant="secondary"
class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm" class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm text-foreground dark:text-white"
> >
{{ getModelLabel(modelName) }} {{ getModelLabel(modelName) }}
<button <button
class="ml-0.5 hover:text-destructive focus:outline-none" class="ml-0.5 hover:text-destructive focus:outline-none text-foreground dark:text-white"
@click.stop="toggleModel(modelName, false)" @click.stop="toggleModel(modelName, false)"
> >
&times; &times;

View File

@@ -178,7 +178,7 @@
<Button <Button
variant="ghost" variant="ghost"
size="icon" size="icon"
class="h-8 w-8 text-destructive hover:text-destructive" class="h-8 w-8 hover:text-destructive"
title="删除" title="删除"
@click="deleteModel(model)" @click="deleteModel(model)"
> >

View File

@@ -691,64 +691,70 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}" f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
) )
# 创建 HTTP 客户端(支持代理配置) # 获取复用的 HTTP 客户端(支持代理配置)
# endpoint.timeout 作为整体请求超时 # 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端
from src.clients.http_client import HTTPClientPool from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300) request_timeout = float(endpoint.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy( http_client = await HTTPClientPool.get_proxy_client(
proxy_config=endpoint.proxy, proxy_config=endpoint.proxy,
)
# 注意:不使用 async with,因为复用的客户端不应该被关闭
# 超时通过 timeout 参数控制
resp = await http_client.post(
url,
json=provider_payload,
headers=provider_hdrs,
timeout=httpx.Timeout(request_timeout), timeout=httpx.Timeout(request_timeout),
) )
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
status_code = resp.status_code status_code = resp.status_code
response_headers = dict(resp.headers) response_headers = dict(resp.headers)
if resp.status_code == 401: if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}") raise ProviderAuthException(f"提供商认证失败: {provider.name}")
elif resp.status_code == 429: elif resp.status_code == 429:
raise ProviderRateLimitException( raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}", f"提供商速率限制: {provider.name}",
provider_name=str(provider.name), provider_name=str(provider.name),
response_headers=response_headers, response_headers=response_headers,
)
elif resp.status_code >= 500:
# 记录响应体以便调试
error_body = ""
try:
error_body = resp.text[:1000]
logger.error(
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
) )
elif resp.status_code >= 500: except Exception:
# 记录响应体以便调试 pass
error_body = "" raise ProviderNotAvailableException(
try: f"提供商服务不可用: {provider.name}",
error_body = resp.text[:1000] provider_name=str(provider.name),
logger.error( upstream_status=resp.status_code,
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}" upstream_response=error_body,
) )
except Exception: elif resp.status_code != 200:
pass # 记录非200响应以便调试
raise ProviderNotAvailableException( error_body = ""
f"提供商服务不可用: {provider.name}", try:
provider_name=str(provider.name), error_body = resp.text[:1000]
upstream_status=resp.status_code, logger.warning(
upstream_response=error_body, f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
)
elif resp.status_code != 200:
# 记录非200响应以便调试
error_body = ""
try:
error_body = resp.text[:1000]
logger.warning(
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
)
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
) )
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
response_json = resp.json() response_json = resp.json()
return response_json if isinstance(response_json, dict) else {} return response_json if isinstance(response_json, dict) else {}
try: try:
# 解析能力需求 # 解析能力需求

View File

@@ -1534,72 +1534,78 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}" f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
) )
# 创建 HTTP 客户端(支持代理配置) # 获取复用的 HTTP 客户端(支持代理配置)
# endpoint.timeout 作为整体请求超时 # 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端
from src.clients.http_client import HTTPClientPool from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300) request_timeout = float(endpoint.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy( http_client = await HTTPClientPool.get_proxy_client(
proxy_config=endpoint.proxy, proxy_config=endpoint.proxy,
)
# 注意:不使用 async with,因为复用的客户端不应该被关闭
# 超时通过 timeout 参数控制
resp = await http_client.post(
url,
json=provider_payload,
headers=provider_headers,
timeout=httpx.Timeout(request_timeout), timeout=httpx.Timeout(request_timeout),
) )
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code status_code = resp.status_code
response_headers = dict(resp.headers) response_headers = dict(resp.headers)
if resp.status_code == 401: if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}") raise ProviderAuthException(f"提供商认证失败: {provider.name}")
elif resp.status_code == 429: elif resp.status_code == 429:
raise ProviderRateLimitException( raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}", f"提供商速率限制: {provider.name}",
provider_name=str(provider.name), provider_name=str(provider.name),
response_headers=response_headers, response_headers=response_headers,
retry_after=int(resp.headers.get("retry-after", 0)) or None, retry_after=int(resp.headers.get("retry-after", 0)) or None,
) )
elif resp.status_code >= 500: elif resp.status_code >= 500:
error_text = resp.text error_text = resp.text
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}", f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name), provider_name=str(provider.name),
upstream_status=resp.status_code, upstream_status=resp.status_code,
upstream_response=error_text, upstream_response=error_text,
) )
elif 300 <= resp.status_code < 400: elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown") redirect_url = resp.headers.get("location", "unknown")
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}" f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
) )
elif resp.status_code != 200: elif resp.status_code != 200:
error_text = resp.text error_text = resp.text
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}", f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name), provider_name=str(provider.name),
upstream_status=resp.status_code, upstream_status=resp.status_code,
upstream_response=error_text, upstream_response=error_text,
) )
# 安全解析 JSON 响应,处理可能的编码错误 # 安全解析 JSON 响应,处理可能的编码错误
try: try:
response_json = resp.json() response_json = resp.json()
except (UnicodeDecodeError, json.JSONDecodeError) as e: except (UnicodeDecodeError, json.JSONDecodeError) as e:
# 记录原始响应信息用于调试 # 记录原始响应信息用于调试
content_type = resp.headers.get("content-type", "unknown") content_type = resp.headers.get("content-type", "unknown")
content_encoding = resp.headers.get("content-encoding", "none") content_encoding = resp.headers.get("content-encoding", "none")
logger.error( logger.error(
f"[{self.request_id}] 无法解析响应 JSON: {e}, " f"[{self.request_id}] 无法解析响应 JSON: {e}, "
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, " f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
f"响应长度: {len(resp.content)} bytes" f"响应长度: {len(resp.content)} bytes"
) )
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}" f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
) )
# 提取 Provider 响应元数据(子类可覆盖) # 提取 Provider 响应元数据(子类可覆盖)
response_metadata_result = self._extract_response_metadata(response_json) response_metadata_result = self._extract_response_metadata(response_json)
return response_json if isinstance(response_json, dict) else {} return response_json if isinstance(response_json, dict) else {}
try: try:
# 解析能力需求 # 解析能力需求

View File

@@ -1,10 +1,18 @@
""" """
全局HTTP客户端池管理 全局HTTP客户端池管理
避免每次请求都创建新的AsyncClient,提高性能 避免每次请求都创建新的AsyncClient,提高性能
性能优化说明:
1. 默认客户端:无代理场景,全局复用单一客户端
2. 代理客户端缓存:相同代理配置复用同一客户端,避免重复创建
3. 连接池复用:Keep-alive 连接减少 TCP 握手开销
""" """
import asyncio
import hashlib
import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Tuple
from urllib.parse import quote, urlparse from urllib.parse import quote, urlparse
import httpx import httpx
@@ -12,6 +20,32 @@ import httpx
from src.config import config from src.config import config
from src.core.logger import logger from src.core.logger import logger
# Module-level locks, so the locks themselves never have to be lazily
# initialized as class attributes (which would be a race condition of its own).
# NOTE(review): both locks are constructed at import time. On Python < 3.10
# asyncio primitives bind to an event loop when they are constructed --
# confirm the minimum supported Python version.
_proxy_clients_lock = asyncio.Lock()
_default_client_lock = asyncio.Lock()
def _compute_proxy_cache_key(proxy_config: Optional[Dict[str, Any]]) -> str:
"""
计算代理配置的缓存键
Args:
proxy_config: 代理配置字典
Returns:
缓存键字符串,无代理时返回 "__no_proxy__"
"""
if not proxy_config:
return "__no_proxy__"
# 构建代理 URL 作为缓存键的基础
proxy_url = build_proxy_url(proxy_config)
if not proxy_url:
return "__no_proxy__"
# 使用 MD5 哈希来避免过长的键名
return f"proxy:{hashlib.md5(proxy_url.encode()).hexdigest()[:16]}"
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]: def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
""" """
@@ -61,11 +95,20 @@ class HTTPClientPool:
全局HTTP客户端池单例 全局HTTP客户端池单例
管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接 管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接
性能优化:
1. 默认客户端:无代理场景复用
2. 代理客户端缓存:相同代理配置复用同一客户端
3. LRU 淘汰:代理客户端超过上限时淘汰最久未使用的
""" """
_instance: Optional["HTTPClientPool"] = None _instance: Optional["HTTPClientPool"] = None
_default_client: Optional[httpx.AsyncClient] = None _default_client: Optional[httpx.AsyncClient] = None
_clients: Dict[str, httpx.AsyncClient] = {} _clients: Dict[str, httpx.AsyncClient] = {}
# 代理客户端缓存:{cache_key: (client, last_used_time)}
_proxy_clients: Dict[str, Tuple[httpx.AsyncClient, float]] = {}
# 代理客户端缓存上限(避免内存泄漏)
_max_proxy_clients: int = 50
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
@@ -73,12 +116,50 @@ class HTTPClientPool:
return cls._instance return cls._instance
@classmethod @classmethod
def get_default_client(cls) -> httpx.AsyncClient: async def get_default_client_async(cls) -> httpx.AsyncClient:
""" """
获取默认的HTTP客户端 获取默认的HTTP客户端(异步线程安全版本)
用于大多数HTTP请求,具有合理的默认配置 用于大多数HTTP请求,具有合理的默认配置
""" """
if cls._default_client is not None:
return cls._default_client
async with _default_client_lock:
# 双重检查,避免重复创建
if cls._default_client is None:
cls._default_client = httpx.AsyncClient(
http2=False, # 暂时禁用HTTP/2以提高兼容性
verify=True, # 启用SSL验证
timeout=httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
limits=httpx.Limits(
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
follow_redirects=True, # 跟随重定向
)
logger.info(
f"全局HTTP客户端池已初始化: "
f"max_connections={config.http_max_connections}, "
f"keepalive={config.http_keepalive_connections}, "
f"keepalive_expiry={config.http_keepalive_expiry}s"
)
return cls._default_client
@classmethod
def get_default_client(cls) -> httpx.AsyncClient:
"""
获取默认的HTTP客户端(同步版本,向后兼容)
⚠️ 注意:此方法在高并发首次调用时可能存在竞态条件,
推荐使用 get_default_client_async() 异步版本。
"""
if cls._default_client is None: if cls._default_client is None:
cls._default_client = httpx.AsyncClient( cls._default_client = httpx.AsyncClient(
http2=False, # 暂时禁用HTTP/2以提高兼容性 http2=False, # 暂时禁用HTTP/2以提高兼容性
@@ -135,6 +216,101 @@ class HTTPClientPool:
return cls._clients[name] return cls._clients[name]
@classmethod
def _get_proxy_clients_lock(cls) -> asyncio.Lock:
    """
    Return the lock that guards the proxy client cache.

    The lock is the module-level ``_proxy_clients_lock`` object, so every
    caller receives the same instance.
    """
    guard = _proxy_clients_lock
    return guard
@classmethod
async def _evict_lru_proxy_client(cls) -> None:
    """
    Evict the least recently used proxy client once the cache is full.

    Does nothing while the cache holds fewer than ``_max_proxy_clients``
    entries. Otherwise the entry with the smallest last-used timestamp is
    removed and its client is closed; a failure while closing is logged as a
    warning and not re-raised.

    NOTE(review): the evicted client is closed right away, while
    get_proxy_client() hands the same object out to callers without tracking
    who still holds it -- confirm that in-flight requests cannot be hit.
    """
    cache = cls._proxy_clients
    if len(cache) >= cls._max_proxy_clients:
        # Each cache value is (client, last_used_time); pick the oldest one.
        stalest_key, (stale_client, _) = min(
            cache.items(), key=lambda entry: entry[1][1]
        )
        del cache[stalest_key]

        try:
            await stale_client.aclose()
            logger.debug(f"淘汰代理客户端: {stalest_key}")
        except Exception as e:
            logger.warning(f"关闭代理客户端失败: {e}")
@classmethod
async def get_proxy_client(
    cls,
    proxy_config: Optional[Dict[str, Any]] = None,
) -> httpx.AsyncClient:
    """
    Return a cached, reusable client for the given proxy configuration.

    Configurations that map to the same cache key share one client. The
    client is built with the default timeouts from ``config``; pass
    ``timeout`` on the individual request to override them.

    Args:
        proxy_config: Proxy configuration dict (url, username, password).

    Returns:
        The default client when no proxy applies, otherwise the cached (or
        newly created and cached) proxy client.
    """
    cache_key = _compute_proxy_cache_key(proxy_config)
    if cache_key == "__no_proxy__":
        # No proxy: the shared default client is used instead of this cache.
        return await cls.get_default_client_async()

    async with cls._get_proxy_clients_lock():
        cached = cls._proxy_clients.get(cache_key)
        if cached is not None:
            existing = cached[0]
            if not existing.is_closed:
                # Cache hit: refresh the last-used timestamp.
                cls._proxy_clients[cache_key] = (existing, time.time())
                return existing
            # A closed client is useless; drop it and build a fresh one below.
            del cls._proxy_clients[cache_key]
            logger.debug(f"代理客户端已关闭,将重新创建: {cache_key}")

        # Make room first if the cache has reached its limit.
        await cls._evict_lru_proxy_client()

        pool_limits = httpx.Limits(
            max_connections=config.http_max_connections,
            max_keepalive_connections=config.http_keepalive_connections,
            keepalive_expiry=config.http_keepalive_expiry,
        )
        default_timeout = httpx.Timeout(
            connect=config.http_connect_timeout,
            read=config.http_read_timeout,
            write=config.http_write_timeout,
            pool=config.http_pool_timeout,
        )
        client_kwargs: Dict[str, Any] = {
            "http2": False,
            "verify": True,
            "follow_redirects": True,
            "limits": pool_limits,
            "timeout": default_timeout,
        }

        proxy_url = build_proxy_url(proxy_config) if proxy_config else None
        if proxy_url:
            client_kwargs["proxy"] = proxy_url

        new_client = httpx.AsyncClient(**client_kwargs)
        cls._proxy_clients[cache_key] = (new_client, time.time())

        logger.debug(
            f"创建代理客户端(缓存): {proxy_config.get('url', 'unknown') if proxy_config else 'none'}, "
            f"缓存数量: {len(cls._proxy_clients)}"
        )
        return new_client
@classmethod @classmethod
async def close_all(cls): async def close_all(cls):
"""关闭所有HTTP客户端""" """关闭所有HTTP客户端"""
@@ -148,6 +324,16 @@ class HTTPClientPool:
logger.debug(f"命名HTTP客户端已关闭: {name}") logger.debug(f"命名HTTP客户端已关闭: {name}")
cls._clients.clear() cls._clients.clear()
# 关闭代理客户端缓存
for cache_key, (client, _) in cls._proxy_clients.items():
try:
await client.aclose()
logger.debug(f"代理客户端已关闭: {cache_key}")
except Exception as e:
logger.warning(f"关闭代理客户端失败: {e}")
cls._proxy_clients.clear()
logger.info("所有HTTP客户端已关闭") logger.info("所有HTTP客户端已关闭")
@classmethod @classmethod
@@ -190,13 +376,15 @@ class HTTPClientPool:
""" """
创建带代理配置的HTTP客户端 创建带代理配置的HTTP客户端
⚠️ 性能警告:此方法每次都创建新客户端,推荐使用 get_proxy_client() 复用连接。
Args: Args:
proxy_config: 代理配置字典,包含 url, username, password proxy_config: 代理配置字典,包含 url, username, password
timeout: 超时配置 timeout: 超时配置
**kwargs: 其他 httpx.AsyncClient 配置参数 **kwargs: 其他 httpx.AsyncClient 配置参数
Returns: Returns:
配置好的 httpx.AsyncClient 实例 配置好的 httpx.AsyncClient 实例(调用者需要负责关闭)
""" """
client_config: Dict[str, Any] = { client_config: Dict[str, Any] = {
"http2": False, "http2": False,
@@ -218,11 +406,21 @@ class HTTPClientPool:
proxy_url = build_proxy_url(proxy_config) if proxy_config else None proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url: if proxy_url:
client_config["proxy"] = proxy_url client_config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}") logger.debug(f"创建带代理的HTTP客户端(一次性): {proxy_config.get('url', 'unknown')}")
client_config.update(kwargs) client_config.update(kwargs)
return httpx.AsyncClient(**client_config) return httpx.AsyncClient(**client_config)
@classmethod
def get_pool_stats(cls) -> Dict[str, Any]:
    """
    Report the current size of each client cache.

    Returns:
        A dict with whether the default client exists, how many named clients
        and proxy clients are cached, and the proxy client cache limit.
    """
    stats: Dict[str, Any] = dict(
        default_client_active=cls._default_client is not None,
        named_clients_count=len(cls._clients),
        proxy_clients_count=len(cls._proxy_clients),
        max_proxy_clients=cls._max_proxy_clients,
    )
    return stats
# 便捷访问函数 # 便捷访问函数
def get_http_client() -> httpx.AsyncClient: def get_http_client() -> httpx.AsyncClient:

View File

@@ -85,6 +85,8 @@ class ConcurrencyManager:
""" """
获取当前并发数 获取当前并发数
性能优化:使用 MGET 批量获取,减少 Redis 往返次数
Args: Args:
endpoint_id: Endpoint ID可选 endpoint_id: Endpoint ID可选
key_id: ProviderAPIKey ID可选 key_id: ProviderAPIKey ID可选
@@ -104,15 +106,21 @@ class ConcurrencyManager:
key_count = 0 key_count = 0
try: try:
# 使用 MGET 批量获取,减少 Redis 往返2 次 GET -> 1 次 MGET
keys_to_fetch = []
if endpoint_id: if endpoint_id:
endpoint_key = self._get_endpoint_key(endpoint_id) keys_to_fetch.append(self._get_endpoint_key(endpoint_id))
result = await self._redis.get(endpoint_key)
endpoint_count = int(result) if result else 0
if key_id: if key_id:
key_key = self._get_key_key(key_id) keys_to_fetch.append(self._get_key_key(key_id))
result = await self._redis.get(key_key)
key_count = int(result) if result else 0 if keys_to_fetch:
results = await self._redis.mget(keys_to_fetch)
idx = 0
if endpoint_id:
endpoint_count = int(results[idx]) if results[idx] else 0
idx += 1
if key_id:
key_count = int(results[idx]) if results[idx] else 0
except Exception as e: except Exception as e:
logger.error(f"获取并发数失败: {e}") logger.error(f"获取并发数失败: {e}")