mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 16:22:27 +08:00
Initial commit
This commit is contained in:
353
src/models/admin_requests.py
Normal file
353
src/models/admin_requests.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
管理接口的 Pydantic 请求模型
|
||||
|
||||
提供完整的输入验证和安全过滤
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from src.core.enums import APIFormat, ProviderBillingType
|
||||
|
||||
|
||||
class CreateProviderRequest(BaseModel):
|
||||
"""创建 Provider 请求"""
|
||||
|
||||
name: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="Provider 名称(英文字母、数字、下划线、连字符)",
|
||||
)
|
||||
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
||||
description: Optional[str] = Field(None, max_length=1000, description="描述")
|
||||
website: Optional[str] = Field(None, max_length=500, description="官网地址")
|
||||
billing_type: Optional[str] = Field(
|
||||
ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
|
||||
)
|
||||
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="周期配额(美元)")
|
||||
quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
|
||||
quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
|
||||
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
|
||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
||||
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
|
||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||
rate_limit: Optional[int] = Field(None, ge=0, description="速率限制")
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""验证名称格式"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
|
||||
|
||||
# SQL 注入防护:检查危险关键字
|
||||
dangerous_keywords = [
|
||||
"SELECT",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"CREATE",
|
||||
"ALTER",
|
||||
"EXEC",
|
||||
"UNION",
|
||||
"OR",
|
||||
"AND",
|
||||
"--",
|
||||
";",
|
||||
"'",
|
||||
'"',
|
||||
"<",
|
||||
">",
|
||||
]
|
||||
upper_name = v.upper()
|
||||
for keyword in dangerous_keywords:
|
||||
if keyword in upper_name:
|
||||
raise ValueError(f"名称包含禁止的字符或关键字: {keyword}")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("display_name", "description")
|
||||
@classmethod
|
||||
def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""清理文本输入,防止 XSS"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# 移除潜在的脚本标签
|
||||
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
|
||||
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE) # 移除事件处理器
|
||||
|
||||
# 移除危险的 HTML 标签
|
||||
dangerous_tags = ["script", "iframe", "object", "embed", "link", "style"]
|
||||
for tag in dangerous_tags:
|
||||
v = re.sub(rf"<{tag}[^>]*>", "", v, flags=re.IGNORECASE)
|
||||
v = re.sub(rf"</{tag}>", "", v, flags=re.IGNORECASE)
|
||||
|
||||
return v.strip()
|
||||
|
||||
@field_validator("website")
|
||||
@classmethod
|
||||
def validate_website(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证网站地址"""
|
||||
if v is None or v.strip() == "":
|
||||
return None
|
||||
|
||||
v = v.strip()
|
||||
|
||||
# 自动补全 https:// 前缀
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
v = f"https://{v}"
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("billing_type")
|
||||
@classmethod
|
||||
def validate_billing_type(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证计费类型"""
|
||||
if v is None:
|
||||
return ProviderBillingType.PAY_AS_YOU_GO.value
|
||||
|
||||
try:
|
||||
ProviderBillingType(v)
|
||||
return v
|
||||
except ValueError:
|
||||
valid_types = [t.value for t in ProviderBillingType]
|
||||
raise ValueError(f"无效的计费类型,有效值为: {', '.join(valid_types)}")
|
||||
|
||||
|
||||
class UpdateProviderRequest(BaseModel):
|
||||
"""更新 Provider 请求"""
|
||||
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
website: Optional[str] = Field(None, max_length=500)
|
||||
billing_type: Optional[str] = None
|
||||
monthly_quota_usd: Optional[float] = Field(None, ge=0)
|
||||
quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
|
||||
quota_last_reset_at: Optional[datetime] = None
|
||||
quota_expires_at: Optional[datetime] = None
|
||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
||||
provider_priority: Optional[int] = Field(None, ge=0, le=1000)
|
||||
is_active: Optional[bool] = None
|
||||
rate_limit: Optional[int] = Field(None, ge=0)
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 复用相同的验证器
|
||||
_sanitize_text = field_validator("display_name", "description")(
|
||||
CreateProviderRequest.sanitize_text.__func__
|
||||
)
|
||||
_validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)
|
||||
_validate_billing_type = field_validator("billing_type")(
|
||||
CreateProviderRequest.validate_billing_type.__func__
|
||||
)
|
||||
|
||||
|
||||
class CreateEndpointRequest(BaseModel):
|
||||
"""创建 Endpoint 请求"""
|
||||
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Endpoint 名称")
|
||||
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
|
||||
api_format: str = Field(..., description="API 格式(CLAUDE 或 OPENAI)")
|
||||
custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
|
||||
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""验证名称"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
|
||||
return v
|
||||
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v: str) -> str:
|
||||
"""验证 API URL"""
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def validate_api_format(cls, v: str) -> str:
|
||||
"""验证 API 格式"""
|
||||
try:
|
||||
APIFormat(v)
|
||||
return v
|
||||
except ValueError:
|
||||
valid_formats = [f.value for f in APIFormat]
|
||||
raise ValueError(f"无效的 API 格式,有效值为: {', '.join(valid_formats)}")
|
||||
|
||||
@field_validator("custom_path")
|
||||
@classmethod
|
||||
def validate_custom_path(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证自定义路径"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# 确保路径不包含危险字符
|
||||
if not re.match(r"^[/a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("路径只能包含字母、数字、斜杠、下划线和连字符")
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class UpdateEndpointRequest(BaseModel):
|
||||
"""更新 Endpoint 请求"""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
base_url: Optional[str] = Field(None, min_length=1, max_length=500)
|
||||
api_format: Optional[str] = None
|
||||
custom_path: Optional[str] = Field(None, max_length=200)
|
||||
priority: Optional[int] = Field(None, ge=0, le=1000)
|
||||
is_active: Optional[bool] = None
|
||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 复用验证器
|
||||
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
|
||||
_validate_base_url = field_validator("base_url")(
|
||||
CreateEndpointRequest.validate_base_url.__func__
|
||||
)
|
||||
_validate_api_format = field_validator("api_format")(
|
||||
CreateEndpointRequest.validate_api_format.__func__
|
||||
)
|
||||
_validate_custom_path = field_validator("custom_path")(
|
||||
CreateEndpointRequest.validate_custom_path.__func__
|
||||
)
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
"""创建 API Key 请求"""
|
||||
|
||||
endpoint_id: str = Field(..., description="Endpoint ID")
|
||||
api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
|
||||
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||
max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM")
|
||||
max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
|
||||
notes: Optional[str] = Field(None, max_length=500, description="备注")
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: str) -> str:
|
||||
"""验证 API Key"""
|
||||
# 移除首尾空白
|
||||
v = v.strip()
|
||||
|
||||
# 检查最小长度
|
||||
if len(v) < 10:
|
||||
raise ValueError("API Key 长度不能少于 10 个字符")
|
||||
|
||||
# 检查危险字符(不应包含 SQL 注入字符)
|
||||
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
|
||||
for char in dangerous_chars:
|
||||
if char in v:
|
||||
raise ValueError(f"API Key 包含非法字符: {char}")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("notes")
|
||||
@classmethod
|
||||
def sanitize_notes(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""清理备注"""
|
||||
if v is None:
|
||||
return v
|
||||
# 复用文本清理逻辑
|
||||
return CreateProviderRequest.sanitize_text(v)
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
"""更新用户请求"""
|
||||
|
||||
username: Optional[str] = Field(None, min_length=1, max_length=50)
|
||||
email: Optional[str] = Field(None, max_length=100)
|
||||
quota_usd: Optional[float] = Field(None, ge=0)
|
||||
is_active: Optional[bool] = None
|
||||
role: Optional[str] = None
|
||||
allowed_providers: Optional[List[str]] = Field(None, description="允许使用的提供商 ID 列表")
|
||||
allowed_endpoints: Optional[List[str]] = Field(None, description="允许使用的端点 ID 列表")
|
||||
allowed_models: Optional[List[str]] = Field(None, description="允许使用的模型名称列表")
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def validate_username(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证用户名"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("用户名只能包含字母、数字、下划线和连字符")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证邮箱"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# 简单的邮箱格式验证
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, v):
|
||||
raise ValueError("邮箱格式不正确")
|
||||
|
||||
return v.lower()
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证角色"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
valid_roles = ["admin", "user"]
|
||||
if v not in valid_roles:
|
||||
raise ValueError(f"无效的角色,有效值为: {', '.join(valid_roles)}")
|
||||
|
||||
return v
|
||||
Reference in New Issue
Block a user