# Source: Aether/src/models/admin_requests.py (Python, 365 lines, 14 KiB)
# Revision timestamp: 2025-12-10 20:52:44 +08:00
"""
管理接口的 Pydantic 请求模型
提供完整的输入验证和安全过滤
"""
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator, model_validator
from src.core.enums import APIFormat, ProviderBillingType
class ProxyConfig(BaseModel):
    """Proxy settings for outbound requests.

    Credentials belong in the dedicated ``username`` / ``password`` fields;
    the URL validator rejects URLs that embed them.
    """

    url: str = Field(..., description="代理 URL (http://, https://, socks5://)")
    username: Optional[str] = Field(None, max_length=255, description="代理用户名")
    password: Optional[str] = Field(None, max_length=500, description="代理密码")
    enabled: bool = Field(True, description="是否启用代理false 时保留配置但不使用)")

    @field_validator("url")
    @classmethod
    def validate_proxy_url(cls, v: str) -> str:
        """Validate and normalise the proxy URL.

        Args:
            v: Raw URL supplied by the client.

        Returns:
            The URL with surrounding whitespace removed.

        Raises:
            ValueError: If the URL contains CR/LF, does not use the http,
                https or socks5 scheme, has no host, has an invalid port, or
                embeds a username/password.
        """
        from urllib.parse import urlparse

        v = v.strip()

        # Reject line breaks so the value cannot be used for injection.
        if "\n" in v or "\r" in v:
            raise ValueError("代理 URL 包含非法字符")

        # Only http, https and socks5 are accepted (no SOCKS4).
        if not re.match(r"^(http|https|socks5)://", v, re.IGNORECASE):
            raise ValueError("代理 URL 必须以 http://, https:// 或 socks5:// 开头")

        try:
            parsed = urlparse(v)
        except ValueError as exc:
            # urlparse raises for malformed bracketed (IPv6) hosts.
            raise ValueError("代理 URL 必须包含有效的 host") from exc

        # Check hostname rather than netloc: "http://:8080" has a non-empty
        # netloc (":8080") but no host at all.
        if not parsed.hostname:
            raise ValueError("代理 URL 必须包含有效的 host")

        try:
            # Accessing .port raises ValueError for non-numeric or
            # out-of-range ports; the value itself is not needed here.
            parsed.port
        except ValueError as exc:
            raise ValueError("代理 URL 包含无效的端口") from exc

        # Credentials must not be embedded in the URL; use the dedicated fields.
        if parsed.username or parsed.password:
            raise ValueError("请勿在 URL 中包含用户名和密码,请使用独立的认证字段")

        return v
2025-12-10 20:52:44 +08:00
class CreateProviderRequest(BaseModel):
    """Request body for creating a Provider.

    The validators defined here are also reused by other request models in
    this module, so their names and signatures are part of the interface.
    """

    name: str = Field(
        ...,
        min_length=1,
        max_length=100,
        description="Provider 名称(英文字母、数字、下划线、连字符)",
    )
    display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
    description: Optional[str] = Field(None, max_length=1000, description="描述")
    website: Optional[str] = Field(None, max_length=500, description="官网地址")
    billing_type: Optional[str] = Field(
        ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
    )
    monthly_quota_usd: Optional[float] = Field(None, ge=0, description="周期配额(美元)")
    quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
    quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
    quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
    rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
    provider_priority: Optional[int] = Field(
        100, ge=0, le=1000, description="提供商优先级(数字越小越优先)"
    )
    is_active: Optional[bool] = Field(True, description="是否启用")
    rate_limit: Optional[int] = Field(None, ge=0, description="速率限制")
    concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
    config: Optional[Dict[str, Any]] = Field(None, description="其他配置")

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Validate the provider name.

        Args:
            v: Candidate name.

        Returns:
            The name unchanged.

        Raises:
            ValueError: If the name contains characters other than ASCII
                letters, digits, underscore and hyphen, or contains a
                reserved SQL word or forbidden character sequence.
        """
        # fullmatch, not match(...$): "$" also matches just before a trailing
        # newline, which would let "name\n" through.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            raise ValueError("名称只能包含英文字母、数字、下划线和连字符")

        reserved_words = (
            "SELECT",
            "INSERT",
            "UPDATE",
            "DELETE",
            "DROP",
            "CREATE",
            "ALTER",
            "EXEC",
            "UNION",
            "OR",
            "AND",
        )
        # Compare against whole "_" / "-" separated tokens. A plain substring
        # test rejects ordinary names such as "vendor" or "sandbox" merely
        # because they contain "OR" / "AND". The character whitelist above
        # already excludes spaces and quotes.
        tokens = set(re.split(r"[_-]+", v.upper()))
        for word in reserved_words:
            if word in tokens:
                raise ValueError(f"名称包含禁止的字符或关键字: {word}")

        # Only "--" can get past the whitelist; the remaining entries are kept
        # as defence in depth in case the whitelist is ever relaxed.
        forbidden_fragments = ("--", ";", "'", '"', "<", ">")
        for fragment in forbidden_fragments:
            if fragment in v:
                raise ValueError(f"名称包含禁止的字符或关键字: {fragment}")

        return v

    @field_validator("display_name", "description")
    @classmethod
    def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
        """Strip script-like markup from free text.

        Args:
            v: Text to clean, or ``None``.

        Returns:
            ``None`` if ``v`` is ``None``; otherwise the text with script and
            iframe elements, ``javascript:`` prefixes, inline event-handler
            attributes and a fixed set of dangerous tags removed, then
            stripped of surrounding whitespace.
        """
        if v is None:
            return v

        dangerous_tags = ("script", "iframe", "object", "embed", "link", "style")

        # Repeat until nothing changes: removing one fragment can splice the
        # surrounding text into a new dangerous fragment, e.g.
        # "<scr<script>ipt>" becomes "<script>". Every substitution only
        # deletes characters, so the loop terminates.
        previous = None
        while previous != v:
            previous = v
            # Whole script / iframe elements, including their content.
            v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
            v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
            v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
            # Inline event handlers (onclick=, onerror=, ...). The word
            # boundary keeps ordinary prose such as "condition = x" intact.
            v = re.sub(r"\bon\w+\s*=", "", v, flags=re.IGNORECASE)
            # Remaining opening / closing tags of dangerous elements.
            for tag in dangerous_tags:
                v = re.sub(rf"<{tag}[^>]*>", "", v, flags=re.IGNORECASE)
                v = re.sub(rf"</{tag}>", "", v, flags=re.IGNORECASE)

        return v.strip()

    @field_validator("website")
    @classmethod
    def validate_website(cls, v: Optional[str]) -> Optional[str]:
        """Normalise the website address.

        Args:
            v: Website string, or ``None``.

        Returns:
            ``None`` for ``None`` or blank input; otherwise the trimmed value,
            prefixed with ``https://`` when it has no http/https scheme.
        """
        if v is None or v.strip() == "":
            return None

        v = v.strip()
        # Default to https:// when no scheme was given.
        if not re.match(r"^https?://", v, re.IGNORECASE):
            v = f"https://{v}"
        return v

    @field_validator("billing_type")
    @classmethod
    def validate_billing_type(cls, v: Optional[str]) -> Optional[str]:
        """Validate the billing type against ``ProviderBillingType``.

        Args:
            v: Billing type value, or ``None``.

        Returns:
            The pay-as-you-go value when ``v`` is ``None``; otherwise ``v``.

        Raises:
            ValueError: If ``v`` is not a valid ``ProviderBillingType`` value.
        """
        if v is None:
            return ProviderBillingType.PAY_AS_YOU_GO.value

        try:
            ProviderBillingType(v)
        except ValueError as exc:
            valid_types = [t.value for t in ProviderBillingType]
            raise ValueError(f"无效的计费类型,有效值为: {', '.join(valid_types)}") from exc
        return v
class UpdateProviderRequest(BaseModel):
    """Request body for updating a Provider.

    Every field is optional and defaults to ``None``. Text, website and
    billing-type validation is delegated to the validators defined on
    ``CreateProviderRequest``.
    """

    display_name: Optional[str] = Field(default=None, min_length=1, max_length=100)
    description: Optional[str] = Field(default=None, max_length=1000)
    website: Optional[str] = Field(default=None, max_length=500)
    billing_type: Optional[str] = None
    monthly_quota_usd: Optional[float] = Field(default=None, ge=0)
    quota_reset_day: Optional[int] = Field(default=None, ge=1, le=365)
    quota_last_reset_at: Optional[datetime] = None
    quota_expires_at: Optional[datetime] = None
    rpm_limit: Optional[int] = Field(default=None, ge=0)
    provider_priority: Optional[int] = Field(default=None, ge=0, le=1000)
    is_active: Optional[bool] = None
    rate_limit: Optional[int] = Field(default=None, ge=0)
    concurrent_limit: Optional[int] = Field(default=None, ge=0)
    config: Optional[Dict[str, Any]] = None

    # Reuse the create-model validators by re-registering their underlying
    # functions on this class.
    _sanitize_text = field_validator("display_name", "description")(
        CreateProviderRequest.sanitize_text.__func__
    )
    _validate_website = field_validator("website")(
        CreateProviderRequest.validate_website.__func__
    )
    # NOTE(review): the reused billing-type validator maps None to the
    # pay-as-you-go value, so an explicit null in an update payload does not
    # stay None. Confirm that is intended for updates.
    _validate_billing_type = field_validator("billing_type")(
        CreateProviderRequest.validate_billing_type.__func__
    )
class CreateEndpointRequest(BaseModel):
    """Request body for creating an Endpoint."""

    provider_id: str = Field(..., description="Provider ID")
    name: str = Field(..., min_length=1, max_length=100, description="Endpoint 名称")
    base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
    api_format: str = Field(..., description="API 格式CLAUDE 或 OPENAI")
    custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
    priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
    is_active: Optional[bool] = Field(True, description="是否启用")
    rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
    concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
    config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
    proxy: Optional[ProxyConfig] = Field(None, description="代理配置")

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Validate the endpoint name.

        Args:
            v: Candidate name.

        Returns:
            The name unchanged.

        Raises:
            ValueError: If the name contains anything other than ASCII
                letters, digits, underscore or hyphen.
        """
        # fullmatch, not match(...$): "$" also matches just before a trailing
        # newline, which would let "name\n" through.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
        return v

    @field_validator("base_url")
    @classmethod
    def validate_base_url(cls, v: str) -> str:
        """Validate and normalise the API base URL.

        Args:
            v: Raw base URL.

        Returns:
            The URL with surrounding whitespace and trailing slashes removed.

        Raises:
            ValueError: If the URL does not start with http:// or https://,
                or has no host.
        """
        from urllib.parse import urlparse

        v = v.strip()
        if not re.match(r"^https?://", v, re.IGNORECASE):
            raise ValueError("URL 必须以 http:// 或 https:// 开头")

        # A bare scheme such as "https://" passes the prefix check and would
        # be reduced to "https:" by the rstrip below, so require a real host.
        try:
            hostname = urlparse(v).hostname
        except ValueError:
            # urlparse raises for malformed bracketed (IPv6) hosts.
            hostname = None
        if not hostname:
            raise ValueError("URL 必须包含有效的 host")

        return v.rstrip("/")  # drop trailing slashes

    @field_validator("api_format")
    @classmethod
    def validate_api_format(cls, v: str) -> str:
        """Validate the API format against ``APIFormat``.

        Args:
            v: API format value.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``v`` is not a valid ``APIFormat`` value.
        """
        try:
            APIFormat(v)
        except ValueError as exc:
            valid_formats = [f.value for f in APIFormat]
            raise ValueError(f"无效的 API 格式,有效值为: {', '.join(valid_formats)}") from exc
        return v

    @field_validator("custom_path")
    @classmethod
    def validate_custom_path(cls, v: Optional[str]) -> Optional[str]:
        """Validate the optional custom path.

        Args:
            v: Custom path, or ``None``.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If the path contains anything other than ASCII
                letters, digits, slash, underscore or hyphen.
        """
        if v is None:
            return v
        # fullmatch so that a trailing newline is rejected as well.
        if not re.fullmatch(r"[/a-zA-Z0-9_-]+", v):
            raise ValueError("路径只能包含字母、数字、斜杠、下划线和连字符")
        return v
class UpdateEndpointRequest(BaseModel):
    """Request body for updating an Endpoint.

    Every field is optional. Validation rules are delegated to the
    validators defined on ``CreateEndpointRequest``.
    """

    name: Optional[str] = Field(None, min_length=1, max_length=100)
    base_url: Optional[str] = Field(None, min_length=1, max_length=500)
    api_format: Optional[str] = None
    custom_path: Optional[str] = Field(None, max_length=200)
    priority: Optional[int] = Field(None, ge=0, le=1000)
    is_active: Optional[bool] = None
    rpm_limit: Optional[int] = Field(None, ge=0)
    concurrent_limit: Optional[int] = Field(None, ge=0)
    config: Optional[Dict[str, Any]] = None
    proxy: Optional[ProxyConfig] = Field(None, description="代理配置")

    # The create-model validators for name, base_url and api_format expect a
    # string. Validators run on an explicitly supplied null as well, so the
    # wrappers below pass None through instead of handing it to code that
    # would raise TypeError (name, base_url) or reject it (api_format).

    @field_validator("name")
    @classmethod
    def _validate_name(cls, v: Optional[str]) -> Optional[str]:
        """Validate ``name`` with the create-model rule; ``None`` passes through."""
        if v is None:
            return v
        return CreateEndpointRequest.validate_name(v)

    @field_validator("base_url")
    @classmethod
    def _validate_base_url(cls, v: Optional[str]) -> Optional[str]:
        """Validate ``base_url`` with the create-model rule; ``None`` passes through."""
        if v is None:
            return v
        return CreateEndpointRequest.validate_base_url(v)

    @field_validator("api_format")
    @classmethod
    def _validate_api_format(cls, v: Optional[str]) -> Optional[str]:
        """Validate ``api_format`` with the create-model rule; ``None`` passes through."""
        if v is None:
            return v
        return CreateEndpointRequest.validate_api_format(v)

    # validate_custom_path already returns None unchanged, so it is reused
    # directly.
    _validate_custom_path = field_validator("custom_path")(
        CreateEndpointRequest.validate_custom_path.__func__
    )
class CreateAPIKeyRequest(BaseModel):
    """Request body for creating an API key on an endpoint."""

    endpoint_id: str = Field(..., description="Endpoint ID")
    api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
    priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
    is_active: Optional[bool] = Field(True, description="是否启用")
    max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM")
    max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
    notes: Optional[str] = Field(None, max_length=500, description="备注")

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v: str) -> str:
        """Validate and trim the API key.

        Args:
            v: Raw API key.

        Returns:
            The key with surrounding whitespace removed.

        Raises:
            ValueError: If the trimmed key is shorter than 10 characters,
                contains whitespace or control characters, or contains one of
                the forbidden character sequences.
        """
        v = v.strip()

        if len(v) < 10:
            raise ValueError("API Key 长度不能少于 10 个字符")

        # strip() only trims the edges; an embedded space, CR/LF or other
        # control character would otherwise be accepted.
        if any(ch.isspace() or not ch.isprintable() for ch in v):
            raise ValueError("API Key 不能包含空白或控制字符")

        # Characters / sequences that should never appear in a key.
        # NOTE(review): "--" can occur in keys that use "-" in their alphabet;
        # confirm that rejecting it is intended.
        dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
        for char in dangerous_chars:
            if char in v:
                raise ValueError(f"API Key 包含非法字符: {char}")

        return v

    @field_validator("notes")
    @classmethod
    def sanitize_notes(cls, v: Optional[str]) -> Optional[str]:
        """Clean the notes text with ``CreateProviderRequest.sanitize_text``.

        Args:
            v: Notes text, or ``None``.

        Returns:
            ``None`` if ``v`` is ``None``; otherwise the sanitised text.
        """
        if v is None:
            return v
        return CreateProviderRequest.sanitize_text(v)
class UpdateUserRequest(BaseModel):
    """Request body for updating a user. Every field is optional."""

    username: Optional[str] = Field(None, min_length=1, max_length=50)
    email: Optional[str] = Field(None, max_length=100)
    quota_usd: Optional[float] = Field(None, ge=0)
    is_active: Optional[bool] = None
    role: Optional[str] = None
    allowed_providers: Optional[List[str]] = Field(None, description="允许使用的提供商 ID 列表")
    allowed_endpoints: Optional[List[str]] = Field(None, description="允许使用的端点 ID 列表")
    allowed_models: Optional[List[str]] = Field(None, description="允许使用的模型名称列表")

    @field_validator("username")
    @classmethod
    def validate_username(cls, v: Optional[str]) -> Optional[str]:
        """Validate the username.

        Args:
            v: Candidate username, or ``None``.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If the username contains anything other than ASCII
                letters, digits, underscore or hyphen.
        """
        if v is None:
            return v
        # fullmatch, not match(...$): "$" also matches just before a trailing
        # newline, which would let "admin\n" through.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            raise ValueError("用户名只能包含字母、数字、下划线和连字符")
        return v

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: Optional[str]) -> Optional[str]:
        """Validate the e-mail address with a simple pattern and lowercase it.

        Args:
            v: Candidate e-mail address, or ``None``.

        Returns:
            ``None`` if ``v`` is ``None``; otherwise the lowercased address.

        Raises:
            ValueError: If the address does not match the pattern.
        """
        if v is None:
            return v
        # Deliberately simple format check; fullmatch rejects a trailing
        # newline that "$" would have tolerated.
        email_pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
        if not re.fullmatch(email_pattern, v):
            raise ValueError("邮箱格式不正确")
        return v.lower()

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: Optional[str]) -> Optional[str]:
        """Validate the role name.

        Args:
            v: Role name, or ``None``.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``v`` is neither ``"admin"`` nor ``"user"``.
        """
        if v is None:
            return v
        valid_roles = ["admin", "user"]
        if v not in valid_roles:
            raise ValueError(f"无效的角色,有效值为: {', '.join(valid_roles)}")
        return v