Files
Aether/src/models/admin_requests.py
2025-12-10 20:52:44 +08:00

354 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
管理接口的 Pydantic 请求模型
提供完整的输入验证和安全过滤
"""
import ipaddress
import re
import socket
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.parse import urlsplit

from pydantic import BaseModel, Field, field_validator, model_validator

from src.core.enums import APIFormat, ProviderBillingType
class CreateProviderRequest(BaseModel):
    """Request body for creating a Provider.

    Length/range constraints are declared on the fields; the validators below add
    format checks and best-effort input sanitization.
    """

    name: str = Field(
        ...,
        min_length=1,
        max_length=100,
        description="Provider 名称(英文字母、数字、下划线、连字符)",
    )
    display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
    description: Optional[str] = Field(None, max_length=1000, description="描述")
    website: Optional[str] = Field(None, max_length=500, description="官网地址")
    billing_type: Optional[str] = Field(
        ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
    )
    monthly_quota_usd: Optional[float] = Field(None, ge=0, description="周期配额(美元)")
    quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
    quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
    quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
    rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
    provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
    is_active: Optional[bool] = Field(True, description="是否启用")
    rate_limit: Optional[int] = Field(None, ge=0, description="速率限制")
    concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
    config: Optional[Dict[str, Any]] = Field(None, description="其他配置")

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Validate the provider name.

        The name must consist only of ASCII letters, digits, ``_`` and ``-``; it
        must not contain a SQL keyword as a ``_``/``-`` separated word, and it must
        not contain ``--``.

        Raises:
            ValueError: if any of the rules above is violated.
        """
        # fullmatch instead of re.match with "^...$": "$" also matches just before
        # a trailing newline, so "openai\n" used to be accepted.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
        # Defense in depth against SQL-looking names. Keywords are compared as
        # whole words: the former substring test rejected legitimate names such as
        # "fireworks"/"portkey" (contain "OR") or "sandbox" (contains "AND").
        # Quotes, ";" and angle brackets cannot get past the pattern above.
        sql_keywords = (
            "SELECT",
            "INSERT",
            "UPDATE",
            "DELETE",
            "DROP",
            "CREATE",
            "ALTER",
            "EXEC",
            "UNION",
            "OR",
            "AND",
        )
        words = set(re.split(r"[_-]+", v.upper()))
        for keyword in sql_keywords:
            if keyword in words:
                raise ValueError(f"名称包含禁止的字符或关键字: {keyword}")
        # "-" is allowed by the pattern, so the SQL comment marker can still occur.
        if "--" in v:
            raise ValueError("名称包含禁止的字符或关键字: --")
        return v

    @field_validator("display_name", "description")
    @classmethod
    def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
        """Remove common script-injection constructs from free text.

        Removes script/iframe blocks, ``javascript:``, ``on<word>=`` handler
        prefixes and script/iframe/object/embed/link/style tags, then strips
        surrounding whitespace. ``None`` is returned unchanged.

        NOTE(review): this is a best-effort blacklist; output must still be
        escaped wherever these values are rendered as HTML.
        """
        if v is None:
            return v
        dangerous_tags = ["script", "iframe", "object", "embed", "link", "style"]
        # Repeat until a pass changes nothing: a single pass can be bypassed by
        # nesting, e.g. "javajavascript:script:" collapses to "javascript:".
        # Terminates because every substitution only deletes characters.
        while True:
            cleaned = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
            cleaned = re.sub(r"<iframe.*?</iframe>", "", cleaned, flags=re.IGNORECASE | re.DOTALL)
            cleaned = re.sub(r"javascript:", "", cleaned, flags=re.IGNORECASE)
            cleaned = re.sub(r"on\w+\s*=", "", cleaned, flags=re.IGNORECASE)  # event handlers
            for tag in dangerous_tags:
                cleaned = re.sub(rf"<{tag}[^>]*>", "", cleaned, flags=re.IGNORECASE)
                cleaned = re.sub(rf"</{tag}>", "", cleaned, flags=re.IGNORECASE)
            if cleaned == v:
                break
            v = cleaned
        return v.strip()

    @field_validator("website")
    @classmethod
    def validate_website(cls, v: Optional[str]) -> Optional[str]:
        """Normalize the website URL and reject internal hosts.

        Blank input becomes ``None``; a URL without an http(s) scheme gets an
        ``https://`` prefix.

        Raises:
            ValueError: if the host is ``localhost`` (or a ``.localhost`` name) or a
                private, loopback, link-local or unspecified IP address, or if the
                URL cannot be parsed.
        """
        if v is None or v.strip() == "":
            return None
        v = v.strip()
        # Add https:// when no scheme is given.
        if not re.match(r"^https?://", v, re.IGNORECASE):
            v = f"https://{v}"
        # SSRF guard. Only the host is inspected: the former substring patterns also
        # matched paths and unrelated names (e.g. "api10.example.com" contains "10.")
        # while missing 127.0.0.2, [::1], "127.1" and similar.
        # NOTE: hostnames that merely resolve to an internal address are not caught
        # here; that can only be enforced when a connection is actually made.
        host = (urlsplit(v).hostname or "").rstrip(".")
        if host:
            if host == "localhost" or host.endswith(".localhost"):
                raise ValueError("不允许使用内网地址")
            try:
                ip = ipaddress.ip_address(host)
            except ValueError:
                try:
                    # Legacy numeric IPv4 forms ("127.1", "2130706433") that the
                    # system resolver also accepts.
                    ip = ipaddress.IPv4Address(socket.inet_aton(host))
                except OSError:
                    ip = None  # an ordinary DNS name
            if ip is not None and (
                ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_unspecified
            ):
                raise ValueError("不允许使用内网地址")
        return v

    @field_validator("billing_type")
    @classmethod
    def validate_billing_type(cls, v: Optional[str]) -> Optional[str]:
        """Check that the billing type is a ``ProviderBillingType`` value.

        ``None`` is replaced by ``ProviderBillingType.PAY_AS_YOU_GO.value``.

        Raises:
            ValueError: if ``v`` is not a valid ``ProviderBillingType`` value.
        """
        if v is None:
            return ProviderBillingType.PAY_AS_YOU_GO.value
        try:
            ProviderBillingType(v)
            return v
        except ValueError:
            valid_types = [t.value for t in ProviderBillingType]
            raise ValueError(f"无效的计费类型,有效值为: {', '.join(valid_types)}")
class UpdateProviderRequest(BaseModel):
    """Request body for updating a Provider.

    Every field defaults to ``None``. Text, website and billing-type values are
    validated with the same rules as ``CreateProviderRequest``.
    """

    display_name: Optional[str] = Field(None, min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=1000)
    website: Optional[str] = Field(None, max_length=500)
    billing_type: Optional[str] = None
    monthly_quota_usd: Optional[float] = Field(None, ge=0)
    quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
    quota_last_reset_at: Optional[datetime] = None
    quota_expires_at: Optional[datetime] = None
    rpm_limit: Optional[int] = Field(None, ge=0)
    provider_priority: Optional[int] = Field(None, ge=0, le=1000)
    is_active: Optional[bool] = None
    rate_limit: Optional[int] = Field(None, ge=0)
    concurrent_limit: Optional[int] = Field(None, ge=0)
    config: Optional[Dict[str, Any]] = None

    # Reuse the create-request validators by re-registering their underlying functions.
    _sanitize_text = field_validator("display_name", "description")(
        CreateProviderRequest.sanitize_text.__func__
    )
    _validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)

    @field_validator("billing_type")
    @classmethod
    def _validate_billing_type(cls, v: Optional[str]) -> Optional[str]:
        """Validate ``billing_type`` like the create request, but keep an explicit null.

        The create-side validator turns ``None`` into the pay-as-you-go default.
        Reusing it here meant an explicit ``null`` in an update became
        ``"pay_as_you_go"``, which could overwrite the stored billing type. Every
        other optional field keeps ``None``.
        NOTE(review): confirm how the update handler treats explicitly set fields.

        Raises:
            ValueError: if ``v`` is not a valid ``ProviderBillingType`` value.
        """
        if v is None:
            return None
        return CreateProviderRequest.validate_billing_type(v)
class CreateEndpointRequest(BaseModel):
    """Request body for creating an Endpoint."""

    provider_id: str = Field(..., description="Provider ID")
    name: str = Field(..., min_length=1, max_length=100, description="Endpoint 名称")
    base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
    api_format: str = Field(..., description="API 格式CLAUDE 或 OPENAI")
    custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
    priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
    is_active: Optional[bool] = Field(True, description="是否启用")
    rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
    concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
    config: Optional[Dict[str, Any]] = Field(None, description="其他配置")

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Require the name to consist only of ASCII letters, digits, ``_`` and ``-``.

        Raises:
            ValueError: if the name contains any other character.
        """
        # fullmatch: re.match with "^...$" also accepted a trailing newline.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
        return v

    @field_validator("base_url")
    @classmethod
    def validate_base_url(cls, v: str) -> str:
        """Validate the upstream base URL and return it without trailing slashes.

        Raises:
            ValueError: if the URL does not start with http:// or https://, if its
                host is ``localhost`` (or a ``.localhost`` name) or a private,
                loopback, link-local or unspecified IP address, or if the URL cannot
                be parsed.
        """
        if not re.match(r"^https?://", v, re.IGNORECASE):
            raise ValueError("URL 必须以 http:// 或 https:// 开头")
        # SSRF guard on the host only. The former substring patterns missed
        # link-local 169.254.0.0/16 (where the cloud metadata service lives),
        # 127.0.0.2, [::1] and "127.1", and also matched paths and unrelated
        # hostnames such as "api10.example.com".
        # NOTE: hostnames that merely resolve to an internal address are not caught
        # here; that can only be enforced when a connection is actually made.
        host = (urlsplit(v).hostname or "").rstrip(".")
        if host:
            if host == "localhost" or host.endswith(".localhost"):
                raise ValueError("不允许使用内网地址")
            try:
                ip = ipaddress.ip_address(host)
            except ValueError:
                try:
                    # Legacy numeric IPv4 forms ("127.1", "2130706433") that the
                    # system resolver also accepts.
                    ip = ipaddress.IPv4Address(socket.inet_aton(host))
                except OSError:
                    ip = None  # an ordinary DNS name
            if ip is not None and (
                ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_unspecified
            ):
                raise ValueError("不允许使用内网地址")
        return v.rstrip("/")  # drop trailing slashes

    @field_validator("api_format")
    @classmethod
    def validate_api_format(cls, v: str) -> str:
        """Check that ``v`` is a valid ``APIFormat`` value.

        Raises:
            ValueError: if ``v`` is not a valid ``APIFormat`` value.
        """
        try:
            APIFormat(v)
            return v
        except ValueError:
            valid_formats = [f.value for f in APIFormat]
            raise ValueError(f"无效的 API 格式,有效值为: {', '.join(valid_formats)}")

    @field_validator("custom_path")
    @classmethod
    def validate_custom_path(cls, v: Optional[str]) -> Optional[str]:
        """Allow only letters, digits, ``/``, ``_`` and ``-`` in the custom path.

        ``None`` is returned unchanged.

        Raises:
            ValueError: if the path is empty or contains any other character.
        """
        if v is None:
            return v
        # fullmatch: re.match with "^...$" also accepted a trailing newline.
        if not re.fullmatch(r"[/a-zA-Z0-9_-]+", v):
            raise ValueError("路径只能包含字母、数字、斜杠、下划线和连字符")
        return v
class UpdateEndpointRequest(BaseModel):
    """Request body for updating an Endpoint.

    Every field defaults to ``None``. Values are validated with the same rules as
    ``CreateEndpointRequest``.
    """

    name: Optional[str] = Field(None, min_length=1, max_length=100)
    base_url: Optional[str] = Field(None, min_length=1, max_length=500)
    api_format: Optional[str] = None
    custom_path: Optional[str] = Field(None, max_length=200)
    priority: Optional[int] = Field(None, ge=0, le=1000)
    is_active: Optional[bool] = None
    rpm_limit: Optional[int] = Field(None, ge=0)
    concurrent_limit: Optional[int] = Field(None, ge=0)
    config: Optional[Dict[str, Any]] = None

    # The create-side name/base_url validators run a regex on the value. Reusing them
    # directly made an explicit null raise TypeError, which pydantic v2 does not
    # convert into a validation error (it surfaced as a server error). These
    # wrappers pass None through, like the other optional fields.
    @field_validator("name")
    @classmethod
    def _validate_name(cls, v: Optional[str]) -> Optional[str]:
        """Validate ``name`` with the create-request rule; ``None`` is passed through."""
        if v is None:
            return None
        return CreateEndpointRequest.validate_name(v)

    @field_validator("base_url")
    @classmethod
    def _validate_base_url(cls, v: Optional[str]) -> Optional[str]:
        """Validate ``base_url`` with the create-request rule; ``None`` is passed through."""
        if v is None:
            return None
        return CreateEndpointRequest.validate_base_url(v)

    # These create-side validators are reused as-is.
    _validate_api_format = field_validator("api_format")(
        CreateEndpointRequest.validate_api_format.__func__
    )
    _validate_custom_path = field_validator("custom_path")(
        CreateEndpointRequest.validate_custom_path.__func__
    )
class CreateAPIKeyRequest(BaseModel):
    """Request body for adding an API key to an Endpoint."""

    endpoint_id: str = Field(..., description="Endpoint ID")
    api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
    priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
    is_active: Optional[bool] = Field(True, description="是否启用")
    max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM")
    max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
    notes: Optional[str] = Field(None, max_length=500, description="备注")

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v: str) -> str:
        """Strip surrounding whitespace and reject implausible keys.

        Returns:
            The stripped key.

        Raises:
            ValueError: if the stripped key is shorter than 10 characters or
                contains a quote, ``;``, ``/*``, ``*/``, ``<`` or ``>``.
        """
        v = v.strip()
        if len(v) < 10:
            raise ValueError("API Key 长度不能少于 10 个字符")
        # "--" is deliberately NOT rejected: provider keys commonly use "-" in their
        # alphabet (e.g. "sk-..." prefixes and base64url-style bodies), so two
        # consecutive hyphens can occur in a genuine key.
        dangerous_chars = ["'", '"', ";", "/*", "*/", "<", ">"]
        for char in dangerous_chars:
            if char in v:
                raise ValueError(f"API Key 包含非法字符: {char}")
        return v

    @field_validator("notes")
    @classmethod
    def sanitize_notes(cls, v: Optional[str]) -> Optional[str]:
        """Sanitize the notes with the same rule as provider text fields; ``None`` passes through."""
        if v is None:
            return v
        return CreateProviderRequest.sanitize_text(v)
class UpdateUserRequest(BaseModel):
    """Request body for updating a user; every field defaults to ``None``."""

    username: Optional[str] = Field(None, min_length=1, max_length=50)
    email: Optional[str] = Field(None, max_length=100)
    quota_usd: Optional[float] = Field(None, ge=0)
    is_active: Optional[bool] = None
    role: Optional[str] = None
    allowed_providers: Optional[List[str]] = Field(None, description="允许使用的提供商 ID 列表")
    allowed_endpoints: Optional[List[str]] = Field(None, description="允许使用的端点 ID 列表")
    allowed_models: Optional[List[str]] = Field(None, description="允许使用的模型名称列表")

    @field_validator("username")
    @classmethod
    def validate_username(cls, v: Optional[str]) -> Optional[str]:
        """Allow only ASCII letters, digits, ``_`` and ``-``; ``None`` passes through.

        Raises:
            ValueError: if the username contains any other character.
        """
        if v is None:
            return v
        # fullmatch: re.match with "^...$" also accepted a trailing newline
        # ("bob\n"), which was then stored verbatim.
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            raise ValueError("用户名只能包含字母、数字、下划线和连字符")
        return v

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: Optional[str]) -> Optional[str]:
        """Check a basic ``local@domain.tld`` shape and return the address lower-cased.

        ``None`` passes through.

        Raises:
            ValueError: if the address does not match the pattern.
        """
        if v is None:
            return v
        # Deliberately simple pattern; fullmatch so a trailing newline is rejected.
        email_pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
        if not re.fullmatch(email_pattern, v):
            raise ValueError("邮箱格式不正确")
        return v.lower()

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: Optional[str]) -> Optional[str]:
        """Accept only ``"admin"`` or ``"user"``; ``None`` passes through.

        Raises:
            ValueError: for any other role string.
        """
        if v is None:
            return v
        valid_roles = ["admin", "user"]
        if v not in valid_roles:
            raise ValueError(f"无效的角色,有效值为: {', '.join(valid_roles)}")
        return v