2025-12-10 20:52:44 +08:00
|
|
|
|
"""
|
|
|
|
|
|
管理接口的 Pydantic 请求模型
|
|
|
|
|
|
|
|
|
|
|
|
提供完整的输入验证和安全过滤
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import re
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
|
|
|
|
|
|
|
|
|
|
from src.core.enums import APIFormat, ProviderBillingType
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-12-18 14:42:06 +08:00
|
|
|
|
class ProxyConfig(BaseModel):
|
|
|
|
|
|
"""代理配置"""
|
|
|
|
|
|
|
|
|
|
|
|
url: str = Field(..., description="代理 URL (http://, https://, socks5://)")
|
|
|
|
|
|
username: Optional[str] = Field(None, max_length=255, description="代理用户名")
|
|
|
|
|
|
password: Optional[str] = Field(None, max_length=500, description="代理密码")
|
2025-12-18 16:14:37 +08:00
|
|
|
|
enabled: bool = Field(True, description="是否启用代理(false 时保留配置但不使用)")
|
2025-12-18 14:42:06 +08:00
|
|
|
|
|
|
|
|
|
|
@field_validator("url")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_proxy_url(cls, v: str) -> str:
|
|
|
|
|
|
"""验证代理 URL 格式"""
|
2025-12-18 16:14:37 +08:00
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
|
2025-12-18 14:42:06 +08:00
|
|
|
|
v = v.strip()
|
2025-12-18 16:14:37 +08:00
|
|
|
|
|
|
|
|
|
|
# 检查禁止的字符(防止注入)
|
|
|
|
|
|
if "\n" in v or "\r" in v:
|
|
|
|
|
|
raise ValueError("代理 URL 包含非法字符")
|
|
|
|
|
|
|
|
|
|
|
|
# 验证协议(不支持 SOCKS4)
|
|
|
|
|
|
if not re.match(r"^(http|https|socks5)://", v, re.IGNORECASE):
|
|
|
|
|
|
raise ValueError("代理 URL 必须以 http://, https:// 或 socks5:// 开头")
|
|
|
|
|
|
|
|
|
|
|
|
# 验证 URL 结构
|
|
|
|
|
|
parsed = urlparse(v)
|
|
|
|
|
|
if not parsed.netloc:
|
|
|
|
|
|
raise ValueError("代理 URL 必须包含有效的 host")
|
|
|
|
|
|
|
|
|
|
|
|
# 禁止 URL 中内嵌认证信息,强制使用独立字段
|
|
|
|
|
|
if parsed.username or parsed.password:
|
|
|
|
|
|
raise ValueError("请勿在 URL 中包含用户名和密码,请使用独立的认证字段")
|
|
|
|
|
|
|
2025-12-18 14:42:06 +08:00
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-12-10 20:52:44 +08:00
|
|
|
|
class CreateProviderRequest(BaseModel):
|
|
|
|
|
|
"""创建 Provider 请求"""
|
|
|
|
|
|
|
|
|
|
|
|
name: str = Field(
|
|
|
|
|
|
...,
|
|
|
|
|
|
min_length=1,
|
|
|
|
|
|
max_length=100,
|
|
|
|
|
|
description="Provider 名称(英文字母、数字、下划线、连字符)",
|
|
|
|
|
|
)
|
|
|
|
|
|
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
|
|
|
|
|
description: Optional[str] = Field(None, max_length=1000, description="描述")
|
|
|
|
|
|
website: Optional[str] = Field(None, max_length=500, description="官网地址")
|
|
|
|
|
|
billing_type: Optional[str] = Field(
|
|
|
|
|
|
ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
|
|
|
|
|
|
)
|
|
|
|
|
|
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="周期配额(美元)")
|
|
|
|
|
|
quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
|
|
|
|
|
|
quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
|
|
|
|
|
|
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
|
|
|
|
|
|
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
|
|
|
|
|
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
|
|
|
|
|
|
is_active: Optional[bool] = Field(True, description="是否启用")
|
|
|
|
|
|
rate_limit: Optional[int] = Field(None, ge=0, description="速率限制")
|
|
|
|
|
|
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
|
|
|
|
|
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("name")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_name(cls, v: str) -> str:
|
|
|
|
|
|
"""验证名称格式"""
|
|
|
|
|
|
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
|
|
|
|
|
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
|
|
|
|
|
|
|
|
|
|
|
|
# SQL 注入防护:检查危险关键字
|
|
|
|
|
|
dangerous_keywords = [
|
|
|
|
|
|
"SELECT",
|
|
|
|
|
|
"INSERT",
|
|
|
|
|
|
"UPDATE",
|
|
|
|
|
|
"DELETE",
|
|
|
|
|
|
"DROP",
|
|
|
|
|
|
"CREATE",
|
|
|
|
|
|
"ALTER",
|
|
|
|
|
|
"EXEC",
|
|
|
|
|
|
"UNION",
|
|
|
|
|
|
"OR",
|
|
|
|
|
|
"AND",
|
|
|
|
|
|
"--",
|
|
|
|
|
|
";",
|
|
|
|
|
|
"'",
|
|
|
|
|
|
'"',
|
|
|
|
|
|
"<",
|
|
|
|
|
|
">",
|
|
|
|
|
|
]
|
|
|
|
|
|
upper_name = v.upper()
|
|
|
|
|
|
for keyword in dangerous_keywords:
|
|
|
|
|
|
if keyword in upper_name:
|
|
|
|
|
|
raise ValueError(f"名称包含禁止的字符或关键字: {keyword}")
|
|
|
|
|
|
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("display_name", "description")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
|
|
|
|
|
|
"""清理文本输入,防止 XSS"""
|
|
|
|
|
|
if v is None:
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
# 移除潜在的脚本标签
|
|
|
|
|
|
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
|
|
|
|
|
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
|
|
|
|
|
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
|
|
|
|
|
|
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE) # 移除事件处理器
|
|
|
|
|
|
|
|
|
|
|
|
# 移除危险的 HTML 标签
|
|
|
|
|
|
dangerous_tags = ["script", "iframe", "object", "embed", "link", "style"]
|
|
|
|
|
|
for tag in dangerous_tags:
|
|
|
|
|
|
v = re.sub(rf"<{tag}[^>]*>", "", v, flags=re.IGNORECASE)
|
|
|
|
|
|
v = re.sub(rf"</{tag}>", "", v, flags=re.IGNORECASE)
|
|
|
|
|
|
|
|
|
|
|
|
return v.strip()
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("website")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_website(cls, v: Optional[str]) -> Optional[str]:
|
|
|
|
|
|
"""验证网站地址"""
|
|
|
|
|
|
if v is None or v.strip() == "":
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
v = v.strip()
|
|
|
|
|
|
|
|
|
|
|
|
# 自动补全 https:// 前缀
|
|
|
|
|
|
if not re.match(r"^https?://", v, re.IGNORECASE):
|
|
|
|
|
|
v = f"https://{v}"
|
|
|
|
|
|
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("billing_type")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_billing_type(cls, v: Optional[str]) -> Optional[str]:
|
|
|
|
|
|
"""验证计费类型"""
|
|
|
|
|
|
if v is None:
|
|
|
|
|
|
return ProviderBillingType.PAY_AS_YOU_GO.value
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
ProviderBillingType(v)
|
|
|
|
|
|
return v
|
|
|
|
|
|
except ValueError:
|
|
|
|
|
|
valid_types = [t.value for t in ProviderBillingType]
|
|
|
|
|
|
raise ValueError(f"无效的计费类型,有效值为: {', '.join(valid_types)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UpdateProviderRequest(BaseModel):
|
|
|
|
|
|
"""更新 Provider 请求"""
|
|
|
|
|
|
|
|
|
|
|
|
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
|
|
|
|
|
description: Optional[str] = Field(None, max_length=1000)
|
|
|
|
|
|
website: Optional[str] = Field(None, max_length=500)
|
|
|
|
|
|
billing_type: Optional[str] = None
|
|
|
|
|
|
monthly_quota_usd: Optional[float] = Field(None, ge=0)
|
|
|
|
|
|
quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
|
|
|
|
|
|
quota_last_reset_at: Optional[datetime] = None
|
|
|
|
|
|
quota_expires_at: Optional[datetime] = None
|
|
|
|
|
|
rpm_limit: Optional[int] = Field(None, ge=0)
|
|
|
|
|
|
provider_priority: Optional[int] = Field(None, ge=0, le=1000)
|
|
|
|
|
|
is_active: Optional[bool] = None
|
|
|
|
|
|
rate_limit: Optional[int] = Field(None, ge=0)
|
|
|
|
|
|
concurrent_limit: Optional[int] = Field(None, ge=0)
|
|
|
|
|
|
config: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
|
|
|
|
# 复用相同的验证器
|
|
|
|
|
|
_sanitize_text = field_validator("display_name", "description")(
|
|
|
|
|
|
CreateProviderRequest.sanitize_text.__func__
|
|
|
|
|
|
)
|
|
|
|
|
|
_validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)
|
|
|
|
|
|
_validate_billing_type = field_validator("billing_type")(
|
|
|
|
|
|
CreateProviderRequest.validate_billing_type.__func__
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CreateEndpointRequest(BaseModel):
|
|
|
|
|
|
"""创建 Endpoint 请求"""
|
|
|
|
|
|
|
|
|
|
|
|
provider_id: str = Field(..., description="Provider ID")
|
|
|
|
|
|
name: str = Field(..., min_length=1, max_length=100, description="Endpoint 名称")
|
|
|
|
|
|
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
|
|
|
|
|
|
api_format: str = Field(..., description="API 格式(CLAUDE 或 OPENAI)")
|
|
|
|
|
|
custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
|
|
|
|
|
|
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
|
|
|
|
|
is_active: Optional[bool] = Field(True, description="是否启用")
|
|
|
|
|
|
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
|
|
|
|
|
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
|
|
|
|
|
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
2025-12-18 14:42:06 +08:00
|
|
|
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
2025-12-10 20:52:44 +08:00
|
|
|
|
|
|
|
|
|
|
@field_validator("name")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_name(cls, v: str) -> str:
|
|
|
|
|
|
"""验证名称"""
|
|
|
|
|
|
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
|
|
|
|
|
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("base_url")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_base_url(cls, v: str) -> str:
|
|
|
|
|
|
"""验证 API URL"""
|
|
|
|
|
|
if not re.match(r"^https?://", v, re.IGNORECASE):
|
|
|
|
|
|
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
|
|
|
|
|
|
|
|
|
|
|
return v.rstrip("/") # 移除末尾斜杠
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("api_format")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_api_format(cls, v: str) -> str:
|
|
|
|
|
|
"""验证 API 格式"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
APIFormat(v)
|
|
|
|
|
|
return v
|
|
|
|
|
|
except ValueError:
|
|
|
|
|
|
valid_formats = [f.value for f in APIFormat]
|
|
|
|
|
|
raise ValueError(f"无效的 API 格式,有效值为: {', '.join(valid_formats)}")
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("custom_path")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_custom_path(cls, v: Optional[str]) -> Optional[str]:
|
|
|
|
|
|
"""验证自定义路径"""
|
|
|
|
|
|
if v is None:
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
# 确保路径不包含危险字符
|
|
|
|
|
|
if not re.match(r"^[/a-zA-Z0-9_-]+$", v):
|
|
|
|
|
|
raise ValueError("路径只能包含字母、数字、斜杠、下划线和连字符")
|
|
|
|
|
|
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UpdateEndpointRequest(BaseModel):
|
|
|
|
|
|
"""更新 Endpoint 请求"""
|
|
|
|
|
|
|
|
|
|
|
|
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
|
|
|
|
|
base_url: Optional[str] = Field(None, min_length=1, max_length=500)
|
|
|
|
|
|
api_format: Optional[str] = None
|
|
|
|
|
|
custom_path: Optional[str] = Field(None, max_length=200)
|
|
|
|
|
|
priority: Optional[int] = Field(None, ge=0, le=1000)
|
|
|
|
|
|
is_active: Optional[bool] = None
|
|
|
|
|
|
rpm_limit: Optional[int] = Field(None, ge=0)
|
|
|
|
|
|
concurrent_limit: Optional[int] = Field(None, ge=0)
|
|
|
|
|
|
config: Optional[Dict[str, Any]] = None
|
2025-12-18 14:42:06 +08:00
|
|
|
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
2025-12-10 20:52:44 +08:00
|
|
|
|
|
|
|
|
|
|
# 复用验证器
|
|
|
|
|
|
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
|
|
|
|
|
|
_validate_base_url = field_validator("base_url")(
|
|
|
|
|
|
CreateEndpointRequest.validate_base_url.__func__
|
|
|
|
|
|
)
|
|
|
|
|
|
_validate_api_format = field_validator("api_format")(
|
|
|
|
|
|
CreateEndpointRequest.validate_api_format.__func__
|
|
|
|
|
|
)
|
|
|
|
|
|
_validate_custom_path = field_validator("custom_path")(
|
|
|
|
|
|
CreateEndpointRequest.validate_custom_path.__func__
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CreateAPIKeyRequest(BaseModel):
|
|
|
|
|
|
"""创建 API Key 请求"""
|
|
|
|
|
|
|
|
|
|
|
|
endpoint_id: str = Field(..., description="Endpoint ID")
|
|
|
|
|
|
api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
|
|
|
|
|
|
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
|
|
|
|
|
is_active: Optional[bool] = Field(True, description="是否启用")
|
|
|
|
|
|
max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM")
|
|
|
|
|
|
max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
|
|
|
|
|
|
notes: Optional[str] = Field(None, max_length=500, description="备注")
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("api_key")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_api_key(cls, v: str) -> str:
|
|
|
|
|
|
"""验证 API Key"""
|
|
|
|
|
|
# 移除首尾空白
|
|
|
|
|
|
v = v.strip()
|
|
|
|
|
|
|
|
|
|
|
|
# 检查最小长度
|
|
|
|
|
|
if len(v) < 10:
|
|
|
|
|
|
raise ValueError("API Key 长度不能少于 10 个字符")
|
|
|
|
|
|
|
|
|
|
|
|
# 检查危险字符(不应包含 SQL 注入字符)
|
|
|
|
|
|
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
|
|
|
|
|
|
for char in dangerous_chars:
|
|
|
|
|
|
if char in v:
|
|
|
|
|
|
raise ValueError(f"API Key 包含非法字符: {char}")
|
|
|
|
|
|
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("notes")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def sanitize_notes(cls, v: Optional[str]) -> Optional[str]:
|
|
|
|
|
|
"""清理备注"""
|
|
|
|
|
|
if v is None:
|
|
|
|
|
|
return v
|
|
|
|
|
|
# 复用文本清理逻辑
|
|
|
|
|
|
return CreateProviderRequest.sanitize_text(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UpdateUserRequest(BaseModel):
|
|
|
|
|
|
"""更新用户请求"""
|
|
|
|
|
|
|
|
|
|
|
|
username: Optional[str] = Field(None, min_length=1, max_length=50)
|
|
|
|
|
|
email: Optional[str] = Field(None, max_length=100)
|
|
|
|
|
|
quota_usd: Optional[float] = Field(None, ge=0)
|
|
|
|
|
|
is_active: Optional[bool] = None
|
|
|
|
|
|
role: Optional[str] = None
|
|
|
|
|
|
allowed_providers: Optional[List[str]] = Field(None, description="允许使用的提供商 ID 列表")
|
|
|
|
|
|
allowed_endpoints: Optional[List[str]] = Field(None, description="允许使用的端点 ID 列表")
|
|
|
|
|
|
allowed_models: Optional[List[str]] = Field(None, description="允许使用的模型名称列表")
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("username")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_username(cls, v: Optional[str]) -> Optional[str]:
|
|
|
|
|
|
"""验证用户名"""
|
|
|
|
|
|
if v is None:
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
|
|
|
|
|
raise ValueError("用户名只能包含字母、数字、下划线和连字符")
|
|
|
|
|
|
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("email")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_email(cls, v: Optional[str]) -> Optional[str]:
|
|
|
|
|
|
"""验证邮箱"""
|
|
|
|
|
|
if v is None:
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
# 简单的邮箱格式验证
|
|
|
|
|
|
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
|
|
|
|
|
if not re.match(email_pattern, v):
|
|
|
|
|
|
raise ValueError("邮箱格式不正确")
|
|
|
|
|
|
|
|
|
|
|
|
return v.lower()
|
|
|
|
|
|
|
|
|
|
|
|
@field_validator("role")
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
def validate_role(cls, v: Optional[str]) -> Optional[str]:
|
|
|
|
|
|
"""验证角色"""
|
|
|
|
|
|
if v is None:
|
|
|
|
|
|
return v
|
|
|
|
|
|
|
|
|
|
|
|
valid_roles = ["admin", "user"]
|
|
|
|
|
|
if v not in valid_roles:
|
|
|
|
|
|
raise ValueError(f"无效的角色,有效值为: {', '.join(valid_roles)}")
|
|
|
|
|
|
|
|
|
|
|
|
return v
|