diff --git a/alembic/versions/20251218_0631_f30f9936f6a2_add_proxy_field_to_provider_endpoints.py b/alembic/versions/20251218_0631_f30f9936f6a2_add_proxy_field_to_provider_endpoints.py new file mode 100644 index 0000000..1e7e2df --- /dev/null +++ b/alembic/versions/20251218_0631_f30f9936f6a2_add_proxy_field_to_provider_endpoints.py @@ -0,0 +1,57 @@ +"""add proxy field to provider_endpoints + +Revision ID: f30f9936f6a2 +Revises: 1cc6942cf06f +Create Date: 2025-12-18 06:31:58.451112+00:00 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy import inspect + +# revision identifiers, used by Alembic. +revision = 'f30f9936f6a2' +down_revision = '1cc6942cf06f' +branch_labels = None +depends_on = None + + +def column_exists(table_name: str, column_name: str) -> bool: + """检查列是否存在""" + bind = op.get_bind() + inspector = inspect(bind) + columns = [col['name'] for col in inspector.get_columns(table_name)] + return column_name in columns + + +def get_column_type(table_name: str, column_name: str) -> str: + """获取列的类型""" + bind = op.get_bind() + inspector = inspect(bind) + for col in inspector.get_columns(table_name): + if col['name'] == column_name: + return str(col['type']).upper() + return '' + + +def upgrade() -> None: + """添加 proxy 字段到 provider_endpoints 表""" + if not column_exists('provider_endpoints', 'proxy'): + # 字段不存在,直接添加 JSONB 类型 + op.add_column('provider_endpoints', sa.Column('proxy', JSONB(), nullable=True)) + else: + # 字段已存在,检查是否需要转换类型 + col_type = get_column_type('provider_endpoints', 'proxy') + if 'JSONB' not in col_type: + # 如果是 JSON 类型,转换为 JSONB + op.execute( + 'ALTER TABLE provider_endpoints ' + 'ALTER COLUMN proxy TYPE JSONB USING proxy::jsonb' + ) + + +def downgrade() -> None: + """移除 proxy 字段""" + if column_exists('provider_endpoints', 'proxy'): + op.drop_column('provider_endpoints', 'proxy') diff --git a/frontend/src/api/endpoints/endpoints.ts b/frontend/src/api/endpoints/endpoints.ts index 642048f..a0775d4 100644 --- a/frontend/src/api/endpoints/endpoints.ts +++ b/frontend/src/api/endpoints/endpoints.ts @@ -1,5 +1,5 @@ import client from '../client' -import type { ProviderEndpoint } from './types' +import type { ProviderEndpoint, ProxyConfig } from './types' /** * 获取指定 Provider 的所有 Endpoints @@ -38,6 +38,7 @@ export async function createEndpoint( rate_limit?: number is_active?: boolean config?: Record + proxy?: ProxyConfig | null } ): Promise { const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data) @@ -63,6 +64,7 @@ export async function updateEndpoint( rate_limit: number is_active: boolean config: Record + proxy: ProxyConfig | null }> ): Promise { const response = await client.put(`/api/admin/endpoints/${endpointId}`, data) diff --git a/frontend/src/api/endpoints/types.ts b/frontend/src/api/endpoints/types.ts index ea8ee22..ef13cb7 100644 --- a/frontend/src/api/endpoints/types.ts +++ b/frontend/src/api/endpoints/types.ts @@ -20,6 +20,16 @@ export const API_FORMAT_LABELS: Record = { [API_FORMATS.GEMINI_CLI]: 'Gemini CLI', } +/** + * 代理配置类型 + */ +export interface ProxyConfig { + url: string + username?: string + password?: string + enabled?: boolean // 是否启用代理(false 时保留配置但不使用) +} + export interface ProviderEndpoint { id: string provider_id: string @@ -41,6 +51,7 @@ export interface ProviderEndpoint { last_failure_at?: string is_active: boolean config?: Record + proxy?: ProxyConfig | null total_keys: number active_keys: number created_at: string diff --git a/frontend/src/features/providers/components/EndpointFormDialog.vue b/frontend/src/features/providers/components/EndpointFormDialog.vue index c2a8c0e..3f4c200 100644 --- a/frontend/src/features/providers/components/EndpointFormDialog.vue +++ b/frontend/src/features/providers/components/EndpointFormDialog.vue @@ -9,7 +9,7 @@ >
@@ -132,6 +132,79 @@
+ + +
+
+

+ 代理配置 +

+
+ + 启用代理 +
+
+ +
+
+ + +

+ {{ proxyUrlError }} +

+

+ 支持 HTTP、HTTPS、SOCKS5 代理 +

+
+ +
+
+ + +
+ +
+ + +
+
+
+
+ + + diff --git a/src/api/admin/endpoints/routes.py b/src/api/admin/endpoints/routes.py index 9fd6bea..3d22e67 100644 --- a/src/api/admin/endpoints/routes.py +++ b/src/api/admin/endpoints/routes.py @@ -5,7 +5,7 @@ ProviderEndpoint CRUD 管理 API import uuid from dataclasses import dataclass from datetime import datetime, timezone -from typing import List +from typing import List, Optional from fastapi import APIRouter, Depends, Query, Request from sqlalchemy import and_, func @@ -27,6 +27,16 @@ router = APIRouter(tags=["Endpoint Management"]) pipeline = ApiRequestPipeline() +def mask_proxy_password(proxy_config: Optional[dict]) -> Optional[dict]: + """对代理配置中的密码进行脱敏处理""" + if not proxy_config: + return None + masked = dict(proxy_config) + if masked.get("password"): + masked["password"] = "***" + return masked + + @router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse]) async def list_provider_endpoints( provider_id: str, @@ -153,6 +163,7 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter): "api_format": endpoint.api_format, "total_keys": total_keys_map.get(endpoint.id, 0), "active_keys": active_keys_map.get(endpoint.id, 0), + "proxy": mask_proxy_password(endpoint.proxy), } endpoint_dict.pop("_sa_instance_state", None) result.append(ProviderEndpointResponse(**endpoint_dict)) @@ -202,6 +213,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter): rate_limit=self.endpoint_data.rate_limit, is_active=True, config=self.endpoint_data.config, + proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None, created_at=now, updated_at=now, ) @@ -215,12 +227,13 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter): endpoint_dict = { k: v for k, v in new_endpoint.__dict__.items() - if k not in {"api_format", "_sa_instance_state"} + if k not in {"api_format", "_sa_instance_state", "proxy"} } return ProviderEndpointResponse( **endpoint_dict, provider_name=provider.name, api_format=new_endpoint.api_format, + proxy=mask_proxy_password(new_endpoint.proxy), total_keys=0, active_keys=0, ) @@ -259,12 +272,13 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter): endpoint_dict = { k: v for k, v in endpoint_obj.__dict__.items() - if k not in {"api_format", "_sa_instance_state"} + if k not in {"api_format", "_sa_instance_state", "proxy"} } return ProviderEndpointResponse( **endpoint_dict, provider_name=provider.name, api_format=endpoint_obj.api_format, + proxy=mask_proxy_password(endpoint_obj.proxy), total_keys=total_keys, active_keys=active_keys, ) @@ -284,6 +298,17 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter): raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在") update_data = self.endpoint_data.model_dump(exclude_unset=True) + # 把 proxy 转换为 dict 存储,支持显式设置为 None 清除代理 + if "proxy" in update_data: + if update_data["proxy"] is not None: + new_proxy = dict(update_data["proxy"]) + # 只有当密码字段未提供时才保留原密码(空字符串视为显式清除) + if "password" not in new_proxy and endpoint.proxy: + old_password = endpoint.proxy.get("password") + if old_password: + new_proxy["password"] = old_password + update_data["proxy"] = new_proxy + # proxy 为 None 时保留,用于清除代理配置 for field, value in update_data.items(): setattr(endpoint, field, value) endpoint.updated_at = datetime.now(timezone.utc) @@ -311,12 +336,13 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter): endpoint_dict = { k: v for k, v in endpoint.__dict__.items() - if k not in {"api_format", "_sa_instance_state"} + if k not in {"api_format", "_sa_instance_state", "proxy"} } return ProviderEndpointResponse( **endpoint_dict, provider_name=provider.name if provider else "Unknown", api_format=endpoint.api_format, + proxy=mask_proxy_password(endpoint.proxy), total_keys=total_keys, active_keys=active_keys, ) diff --git a/src/api/handlers/base/chat_handler_base.py b/src/api/handlers/base/chat_handler_base.py index dee496b..2dc0501 100644 --- a/src/api/handlers/base/chat_handler_base.py +++ b/src/api/handlers/base/chat_handler_base.py @@ -466,7 +466,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC): pool=config.http_pool_timeout, ) - http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True) + # 创建 HTTP 客户端(支持代理配置) + from src.clients.http_client import HTTPClientPool + + http_client = HTTPClientPool.create_client_with_proxy( + proxy_config=endpoint.proxy, + timeout=timeout_config, + ) try: response_ctx = http_client.stream( "POST", url, json=provider_payload, headers=provider_headers @@ -634,10 +640,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC): logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, " f"模型={model} -> {mapped_model or '无映射'}") - async with httpx.AsyncClient( - timeout=float(endpoint.timeout), - follow_redirects=True, - ) as http_client: + # 创建 HTTP 客户端(支持代理配置) + from src.clients.http_client import HTTPClientPool + + http_client = HTTPClientPool.create_client_with_proxy( + proxy_config=endpoint.proxy, + timeout=httpx.Timeout(float(endpoint.timeout)), + ) + async with http_client: resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs) status_code = resp.status_code diff --git a/src/api/handlers/base/cli_handler_base.py b/src/api/handlers/base/cli_handler_base.py index 19ab91d..f9f569c 100644 --- a/src/api/handlers/base/cli_handler_base.py +++ b/src/api/handlers/base/cli_handler_base.py @@ -454,7 +454,13 @@ class CliMessageHandlerBase(BaseMessageHandler): f"Key=***{key.api_key[-4:]}, " f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}") - http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True) + # 创建 HTTP 客户端(支持代理配置) + from src.clients.http_client import HTTPClientPool + + http_client = HTTPClientPool.create_client_with_proxy( + proxy_config=endpoint.proxy, + timeout=timeout_config, + ) try: response_ctx = http_client.stream( "POST", url, json=provider_payload, headers=provider_headers @@ -1419,10 +1425,14 @@ class CliMessageHandlerBase(BaseMessageHandler): f"Key=***{key.api_key[-4:]}, " f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}") - async with httpx.AsyncClient( - timeout=float(endpoint.timeout), - follow_redirects=True, - ) as http_client: + # 创建 HTTP 客户端(支持代理配置) + from src.clients.http_client import HTTPClientPool + + http_client = HTTPClientPool.create_client_with_proxy( + proxy_config=endpoint.proxy, + timeout=httpx.Timeout(float(endpoint.timeout)), + ) + async with http_client: resp = await http_client.post(url, json=provider_payload, headers=provider_headers) status_code = resp.status_code diff --git a/src/clients/http_client.py b/src/clients/http_client.py index 8ffdfc0..9d419dd 100644 --- a/src/clients/http_client.py +++ b/src/clients/http_client.py @@ -5,12 +5,55 @@ from contextlib import asynccontextmanager from typing import Any, Dict, Optional +from urllib.parse import quote, urlparse import httpx from src.core.logger import logger +def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]: + """ + 根据代理配置构建完整的代理 URL + + Args: + proxy_config: 代理配置字典,包含 url, username, password, enabled + + Returns: + 完整的代理 URL,如 socks5://user:pass@host:port + 如果 enabled=False 或无配置,返回 None + """ + if not proxy_config: + return None + + # 检查 enabled 字段,默认为 True(兼容旧数据) + if not proxy_config.get("enabled", True): + return None + + proxy_url = proxy_config.get("url") + if not proxy_url: + return None + + username = proxy_config.get("username") + password = proxy_config.get("password") + + # 只要有用户名就添加认证信息(密码可以为空) + if username: + parsed = urlparse(proxy_url) + # URL 编码用户名和密码,处理特殊字符(如 @, :, /) + encoded_username = quote(username, safe="") + encoded_password = quote(password, safe="") if password else "" + # 重新构建带认证的代理 URL + if encoded_password: + auth_proxy = f"{parsed.scheme}://{encoded_username}:{encoded_password}@{parsed.netloc}" + else: + auth_proxy = f"{parsed.scheme}://{encoded_username}@{parsed.netloc}" + if parsed.path: + auth_proxy += parsed.path + return auth_proxy + + return proxy_url + class HTTPClientPool: """ @@ -121,6 +164,44 @@ class HTTPClientPool: finally: await client.aclose() + @classmethod + def create_client_with_proxy( + cls, + proxy_config: Optional[Dict[str, Any]] = None, + timeout: Optional[httpx.Timeout] = None, + **kwargs: Any, + ) -> httpx.AsyncClient: + """ + 创建带代理配置的HTTP客户端 + + Args: + proxy_config: 代理配置字典,包含 url, username, password + timeout: 超时配置 + **kwargs: 其他 httpx.AsyncClient 配置参数 + + Returns: + 配置好的 httpx.AsyncClient 实例 + """ + config: Dict[str, Any] = { + "http2": False, + "verify": True, + "follow_redirects": True, + } + + if timeout: + config["timeout"] = timeout + else: + config["timeout"] = httpx.Timeout(10.0, read=300.0) + + # 添加代理配置 + proxy_url = build_proxy_url(proxy_config) if proxy_config else None + if proxy_url: + config["proxy"] = proxy_url + logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}") + + config.update(kwargs) + return httpx.AsyncClient(**config) + # 便捷访问函数 def get_http_client() -> httpx.AsyncClient: diff --git a/src/models/admin_requests.py b/src/models/admin_requests.py index c20e05f..35eb995 100644 --- a/src/models/admin_requests.py +++ b/src/models/admin_requests.py @@ -13,6 +13,42 @@ from pydantic import BaseModel, Field, field_validator, model_validator from src.core.enums import APIFormat, ProviderBillingType +class ProxyConfig(BaseModel): + """代理配置""" + + url: str = Field(..., description="代理 URL (http://, https://, socks5://)") + username: Optional[str] = Field(None, max_length=255, description="代理用户名") + password: Optional[str] = Field(None, max_length=500, description="代理密码") + enabled: bool = Field(True, description="是否启用代理(false 时保留配置但不使用)") + + @field_validator("url") + @classmethod + def validate_proxy_url(cls, v: str) -> str: + """验证代理 URL 格式""" + from urllib.parse import urlparse + + v = v.strip() + + # 检查禁止的字符(防止注入) + if "\n" in v or "\r" in v: + raise ValueError("代理 URL 包含非法字符") + + # 验证协议(不支持 SOCKS4) + if not re.match(r"^(http|https|socks5)://", v, re.IGNORECASE): + raise ValueError("代理 URL 必须以 http://, https:// 或 socks5:// 开头") + + # 验证 URL 结构 + parsed = urlparse(v) + if not parsed.netloc: + raise ValueError("代理 URL 必须包含有效的 host") + + # 禁止 URL 中内嵌认证信息,强制使用独立字段 + if parsed.username or parsed.password: + raise ValueError("请勿在 URL 中包含用户名和密码,请使用独立的认证字段") + + return v + + class CreateProviderRequest(BaseModel): """创建 Provider 请求""" @@ -165,6 +201,7 @@ class CreateEndpointRequest(BaseModel): rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制") concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制") config: Optional[Dict[str, Any]] = Field(None, description="其他配置") + proxy: Optional[ProxyConfig] = Field(None, description="代理配置") @field_validator("name") @classmethod @@ -220,6 +257,7 @@ class UpdateEndpointRequest(BaseModel): rpm_limit: Optional[int] = Field(None, ge=0) concurrent_limit: Optional[int] = Field(None, ge=0) config: Optional[Dict[str, Any]] = None + proxy: Optional[ProxyConfig] = Field(None, description="代理配置") # 复用验证器 _validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__) diff --git a/src/models/database.py b/src/models/database.py index e0027d9..5317fff 100644 --- a/src/models/database.py +++ b/src/models/database.py @@ -538,6 +538,9 @@ class ProviderEndpoint(Base): # 额外配置 config = Column(JSON, nullable=True) # 端点特定配置(不推荐使用,优先使用专用字段) + # 代理配置 + proxy = Column(JSONB, nullable=True) # 代理配置: {url, username, password} + # 时间戳 created_at = Column( DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False diff --git a/src/models/endpoint_models.py b/src/models/endpoint_models.py index b61f091..cbbf025 100644 --- a/src/models/endpoint_models.py +++ b/src/models/endpoint_models.py @@ -8,6 +8,8 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, ConfigDict, Field, field_validator +from src.models.admin_requests import ProxyConfig + # ========== ProviderEndpoint CRUD ========== @@ -30,6 +32,9 @@ class ProviderEndpointCreate(BaseModel): # 额外配置 config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)") + # 代理配置 + proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置") + @field_validator("api_format") @classmethod def validate_api_format(cls, v: str) -> str: @@ -64,6 +69,7 @@ class ProviderEndpointUpdate(BaseModel): rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制") is_active: Optional[bool] = Field(default=None, description="是否启用") config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置") + proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置") @field_validator("base_url") @classmethod @@ -104,6 +110,9 @@ class ProviderEndpointResponse(BaseModel): # 额外配置 config: Optional[Dict[str, Any]] = None + # 代理配置(响应中密码已脱敏) + proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置(密码已脱敏)") + # 统计(从 Keys 聚合) total_keys: int = Field(default=0, description="总 Key 数量") active_keys: int = Field(default=0, description="活跃 Key 数量")