Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

10
src/models/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""
统一的模型定义模块
"""
from .api import * # noqa: F401, F403
from .claude import * # noqa: F401, F403
from .database import * # noqa: F401, F403
from .openai import * # noqa: F401, F403
__all__ = ["claude", "database", "openai", "api"]

View File

@@ -0,0 +1,353 @@
"""
管理接口的 Pydantic 请求模型
提供完整的输入验证和安全过滤
"""
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator, model_validator
from src.core.enums import APIFormat, ProviderBillingType
class CreateProviderRequest(BaseModel):
"""创建 Provider 请求"""
name: str = Field(
...,
min_length=1,
max_length=100,
description="Provider 名称(英文字母、数字、下划线、连字符)",
)
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
description: Optional[str] = Field(None, max_length=1000, description="描述")
website: Optional[str] = Field(None, max_length=500, description="官网地址")
billing_type: Optional[str] = Field(
ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
)
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="周期配额(美元)")
quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
is_active: Optional[bool] = Field(True, description="是否启用")
rate_limit: Optional[int] = Field(None, ge=0, description="速率限制")
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""验证名称格式"""
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
# SQL 注入防护:检查危险关键字
dangerous_keywords = [
"SELECT",
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"CREATE",
"ALTER",
"EXEC",
"UNION",
"OR",
"AND",
"--",
";",
"'",
'"',
"<",
">",
]
upper_name = v.upper()
for keyword in dangerous_keywords:
if keyword in upper_name:
raise ValueError(f"名称包含禁止的字符或关键字: {keyword}")
return v
@field_validator("display_name", "description")
@classmethod
def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
"""清理文本输入,防止 XSS"""
if v is None:
return v
# 移除潜在的脚本标签
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE) # 移除事件处理器
# 移除危险的 HTML 标签
dangerous_tags = ["script", "iframe", "object", "embed", "link", "style"]
for tag in dangerous_tags:
v = re.sub(rf"<{tag}[^>]*>", "", v, flags=re.IGNORECASE)
v = re.sub(rf"</{tag}>", "", v, flags=re.IGNORECASE)
return v.strip()
@field_validator("website")
@classmethod
def validate_website(cls, v: Optional[str]) -> Optional[str]:
"""验证网站地址"""
if v is None or v.strip() == "":
return None
v = v.strip()
# 自动补全 https:// 前缀
if not re.match(r"^https?://", v, re.IGNORECASE):
v = f"https://{v}"
# 防止 SSRF 攻击:禁止内网地址
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v
@field_validator("billing_type")
@classmethod
def validate_billing_type(cls, v: Optional[str]) -> Optional[str]:
"""验证计费类型"""
if v is None:
return ProviderBillingType.PAY_AS_YOU_GO.value
try:
ProviderBillingType(v)
return v
except ValueError:
valid_types = [t.value for t in ProviderBillingType]
raise ValueError(f"无效的计费类型,有效值为: {', '.join(valid_types)}")
class UpdateProviderRequest(BaseModel):
"""更新 Provider 请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=1000)
website: Optional[str] = Field(None, max_length=500)
billing_type: Optional[str] = None
monthly_quota_usd: Optional[float] = Field(None, ge=0)
quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
quota_last_reset_at: Optional[datetime] = None
quota_expires_at: Optional[datetime] = None
rpm_limit: Optional[int] = Field(None, ge=0)
provider_priority: Optional[int] = Field(None, ge=0, le=1000)
is_active: Optional[bool] = None
rate_limit: Optional[int] = Field(None, ge=0)
concurrent_limit: Optional[int] = Field(None, ge=0)
config: Optional[Dict[str, Any]] = None
# 复用相同的验证器
_sanitize_text = field_validator("display_name", "description")(
CreateProviderRequest.sanitize_text.__func__
)
_validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)
_validate_billing_type = field_validator("billing_type")(
CreateProviderRequest.validate_billing_type.__func__
)
class CreateEndpointRequest(BaseModel):
"""创建 Endpoint 请求"""
provider_id: str = Field(..., description="Provider ID")
name: str = Field(..., min_length=1, max_length=100, description="Endpoint 名称")
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
api_format: str = Field(..., description="API 格式CLAUDE 或 OPENAI")
custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
is_active: Optional[bool] = Field(True, description="是否启用")
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""验证名称"""
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
return v
@field_validator("base_url")
@classmethod
def validate_base_url(cls, v: str) -> str:
"""验证 API URL"""
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# 防止 SSRF
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # 移除末尾斜杠
@field_validator("api_format")
@classmethod
def validate_api_format(cls, v: str) -> str:
"""验证 API 格式"""
try:
APIFormat(v)
return v
except ValueError:
valid_formats = [f.value for f in APIFormat]
raise ValueError(f"无效的 API 格式,有效值为: {', '.join(valid_formats)}")
@field_validator("custom_path")
@classmethod
def validate_custom_path(cls, v: Optional[str]) -> Optional[str]:
"""验证自定义路径"""
if v is None:
return v
# 确保路径不包含危险字符
if not re.match(r"^[/a-zA-Z0-9_-]+$", v):
raise ValueError("路径只能包含字母、数字、斜杠、下划线和连字符")
return v
class UpdateEndpointRequest(BaseModel):
"""更新 Endpoint 请求"""
name: Optional[str] = Field(None, min_length=1, max_length=100)
base_url: Optional[str] = Field(None, min_length=1, max_length=500)
api_format: Optional[str] = None
custom_path: Optional[str] = Field(None, max_length=200)
priority: Optional[int] = Field(None, ge=0, le=1000)
is_active: Optional[bool] = None
rpm_limit: Optional[int] = Field(None, ge=0)
concurrent_limit: Optional[int] = Field(None, ge=0)
config: Optional[Dict[str, Any]] = None
# 复用验证器
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
_validate_base_url = field_validator("base_url")(
CreateEndpointRequest.validate_base_url.__func__
)
_validate_api_format = field_validator("api_format")(
CreateEndpointRequest.validate_api_format.__func__
)
_validate_custom_path = field_validator("custom_path")(
CreateEndpointRequest.validate_custom_path.__func__
)
class CreateAPIKeyRequest(BaseModel):
"""创建 API Key 请求"""
endpoint_id: str = Field(..., description="Endpoint ID")
api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
is_active: Optional[bool] = Field(True, description="是否启用")
max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM")
max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
notes: Optional[str] = Field(None, max_length=500, description="备注")
@field_validator("api_key")
@classmethod
def validate_api_key(cls, v: str) -> str:
"""验证 API Key"""
# 移除首尾空白
v = v.strip()
# 检查最小长度
if len(v) < 10:
raise ValueError("API Key 长度不能少于 10 个字符")
# 检查危险字符(不应包含 SQL 注入字符)
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
for char in dangerous_chars:
if char in v:
raise ValueError(f"API Key 包含非法字符: {char}")
return v
@field_validator("notes")
@classmethod
def sanitize_notes(cls, v: Optional[str]) -> Optional[str]:
"""清理备注"""
if v is None:
return v
# 复用文本清理逻辑
return CreateProviderRequest.sanitize_text(v)
class UpdateUserRequest(BaseModel):
"""更新用户请求"""
username: Optional[str] = Field(None, min_length=1, max_length=50)
email: Optional[str] = Field(None, max_length=100)
quota_usd: Optional[float] = Field(None, ge=0)
is_active: Optional[bool] = None
role: Optional[str] = None
allowed_providers: Optional[List[str]] = Field(None, description="允许使用的提供商 ID 列表")
allowed_endpoints: Optional[List[str]] = Field(None, description="允许使用的端点 ID 列表")
allowed_models: Optional[List[str]] = Field(None, description="允许使用的模型名称列表")
@field_validator("username")
@classmethod
def validate_username(cls, v: Optional[str]) -> Optional[str]:
"""验证用户名"""
if v is None:
return v
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError("用户名只能包含字母、数字、下划线和连字符")
return v
@field_validator("email")
@classmethod
def validate_email(cls, v: Optional[str]) -> Optional[str]:
"""验证邮箱"""
if v is None:
return v
# 简单的邮箱格式验证
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
if not re.match(email_pattern, v):
raise ValueError("邮箱格式不正确")
return v.lower()
@field_validator("role")
@classmethod
def validate_role(cls, v: Optional[str]) -> Optional[str]:
"""验证角色"""
if v is None:
return v
valid_roles = ["admin", "user"]
if v not in valid_roles:
raise ValueError(f"无效的角色,有效值为: {', '.join(valid_roles)}")
return v

716
src/models/api.py Normal file
View File

@@ -0,0 +1,716 @@
"""
API端点请求/响应模型定义
"""
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
from ..core.enums import UserRole
# ========== 认证相关 ==========
class LoginRequest(BaseModel):
"""登录请求"""
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
password: str = Field(..., min_length=1, max_length=128, description="密码")
@classmethod
@field_validator("email")
def validate_email(cls, v):
"""验证邮箱格式"""
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
if not re.match(email_pattern, v):
raise ValueError("邮箱格式无效")
return v.lower()
@classmethod
@field_validator("password")
def validate_password(cls, v):
"""验证密码不为空且去除前后空格"""
v = v.strip()
if not v:
raise ValueError("密码不能为空")
return v
class LoginResponse(BaseModel):
"""登录响应"""
access_token: str
refresh_token: str # 刷新令牌
token_type: str = "bearer"
expires_in: int = 86400 # Token有效期默认24小时
user_id: str
email: str
username: str
role: str
class RefreshTokenRequest(BaseModel):
"""刷新令牌请求"""
refresh_token: str = Field(..., description="刷新令牌")
class RefreshTokenResponse(BaseModel):
"""刷新令牌响应"""
access_token: str
refresh_token: str # 返回新的刷新令牌
token_type: str = "bearer"
expires_in: int = 86400 # Token有效期默认24小时
class RegisterRequest(BaseModel):
"""注册请求"""
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
username: str = Field(..., min_length=2, max_length=50, description="用户名")
password: str = Field(..., min_length=6, max_length=128, description="密码")
@classmethod
@field_validator("email")
def validate_email(cls, v):
"""验证邮箱格式"""
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
if not re.match(email_pattern, v):
raise ValueError("邮箱格式无效")
return v.lower()
@classmethod
@field_validator("username")
def validate_username(cls, v):
"""验证用户名格式"""
v = v.strip()
if not v:
raise ValueError("用户名不能为空")
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError("用户名只能包含字母、数字、下划线和短横线")
return v
@classmethod
@field_validator("password")
def validate_password(cls, v):
"""验证密码强度"""
if len(v) < 6:
raise ValueError("密码至少需要6个字符")
if not re.search(r"[A-Z]", v):
raise ValueError("密码必须包含至少一个大写字母")
if not re.search(r"[a-z]", v):
raise ValueError("密码必须包含至少一个小写字母")
if not re.search(r"\d", v):
raise ValueError("密码必须包含至少一个数字")
return v
class RegisterResponse(BaseModel):
"""注册响应"""
user_id: str
email: str
username: str
message: str
class LogoutResponse(BaseModel):
"""登出响应"""
message: str
success: bool
# ========== 用户管理 ==========
class CreateUserRequest(BaseModel):
"""创建用户请求"""
username: str = Field(..., min_length=2, max_length=50, description="用户名")
password: str = Field(..., min_length=6, max_length=128, description="密码")
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
role: Optional[UserRole] = Field(UserRole.USER, description="用户角色")
quota_usd: Optional[float] = Field(default=10.0, description="USD配额null表示无限制")
@field_validator("quota_usd", mode="before")
@classmethod
def validate_quota_usd(cls, v):
"""验证配额值允许null表示无限制"""
if v is None:
return None
if isinstance(v, (int, float)) and v >= 0 and v <= 10000:
return float(v)
if isinstance(v, (int, float)):
raise ValueError("配额必须在 0-10000 范围内")
return v
@classmethod
@field_validator("email")
def validate_email(cls, v):
"""验证邮箱格式"""
v = v.strip()
if not v:
raise ValueError("邮箱不能为空")
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
if not re.match(email_pattern, v):
raise ValueError("邮箱格式无效")
return v.lower()
@classmethod
@field_validator("username")
def validate_username(cls, v):
"""验证用户名格式"""
v = v.strip()
if not v:
raise ValueError("用户名不能为空")
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError("用户名只能包含字母、数字、下划线和短横线")
return v
@classmethod
@field_validator("password")
def validate_password(cls, v):
"""验证密码强度"""
if len(v) < 6:
raise ValueError("密码至少需要6个字符")
if not re.search(r"[A-Z]", v):
raise ValueError("密码必须包含至少一个大写字母")
if not re.search(r"[a-z]", v):
raise ValueError("密码必须包含至少一个小写字母")
if not re.search(r"\d", v):
raise ValueError("密码必须包含至少一个数字")
return v
class UpdateUserRequest(BaseModel):
"""更新用户请求"""
email: Optional[str] = None
username: Optional[str] = None
password: Optional[str] = None
role: Optional[UserRole] = None
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
quota_usd: Optional[float] = None
is_active: Optional[bool] = None
@field_validator("quota_usd", mode="before")
@classmethod
def validate_quota_usd(cls, v):
"""验证配额值允许null表示无限制"""
if v is None:
return None
if isinstance(v, (int, float)) and v >= 0 and v <= 10000:
return float(v)
if isinstance(v, (int, float)):
raise ValueError("配额必须在 0-10000 范围内")
return v
class CreateApiKeyRequest(BaseModel):
"""创建API密钥请求"""
name: Optional[str] = None
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
rate_limit: Optional[int] = 100
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期
initial_balance_usd: Optional[float] = Field(
None, description="初始余额USD仅用于独立KeyNone = 无限制"
)
is_standalone: bool = Field(False, description="是否为独立余额Key给非注册用户使用")
auto_delete_on_expiry: bool = Field(
False, description="过期后是否自动删除True=物理删除False=仅禁用)"
)
class UserResponse(BaseModel):
"""用户响应"""
id: str
email: str
username: str
role: UserRole
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
quota_usd: float
used_usd: float
is_active: bool
created_at: datetime
updated_at: datetime
last_login_at: Optional[datetime]
class ApiKeyResponse(BaseModel):
"""API密钥响应"""
id: str
user_id: str
key: Optional[str] = None # 仅在创建时返回完整密钥
key_display: Optional[str] = None # 脱敏后的密钥显示
name: Optional[str]
total_requests: int
total_tokens: int
total_cost_usd: float
allowed_providers: Optional[List[str]]
allowed_models: Optional[List[str]]
rate_limit: int
is_active: bool
expires_at: Optional[datetime] = None
balance_used_usd: float = 0.0
current_balance_usd: Optional[float] = None # NULL = 无限制
is_standalone: bool = False
force_capabilities: Optional[Dict[str, bool]] = None # 强制开启的能力
created_at: datetime
last_used_at: Optional[datetime]
# ========== 提供商管理 ==========
class ProviderCreate(BaseModel):
"""创建提供商请求
新架构说明:
- Provider 仅包含提供商的元数据和计费配置
- API格式、URL、认证等配置应在 ProviderEndpoint 中设置
- API密钥应在 ProviderAPIKey 中设置
"""
name: str = Field(..., min_length=1, max_length=100, description="提供商唯一标识")
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
description: Optional[str] = Field(None, description="提供商描述")
website: Optional[str] = Field(None, max_length=500, description="主站网站")
# Provider 级别的配置
rate_limit: Optional[int] = Field(None, description="每分钟请求限制")
concurrent_limit: Optional[int] = Field(None, description="并发请求限制")
config: Optional[dict] = Field(None, description="额外配置")
is_active: bool = Field(False, description="是否启用默认false需要配置API密钥后才能启用")
class ProviderUpdate(BaseModel):
"""更新提供商请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None
website: Optional[str] = Field(None, max_length=500)
api_format: Optional[str] = None
base_url: Optional[str] = None
headers: Optional[dict] = None
timeout: Optional[int] = Field(None, ge=1, le=600)
max_retries: Optional[int] = Field(None, ge=0, le=10)
priority: Optional[int] = None
weight: Optional[float] = Field(None, gt=0)
rate_limit: Optional[int] = None
concurrent_limit: Optional[int] = None
config: Optional[dict] = None
is_active: Optional[bool] = None
class ProviderResponse(BaseModel):
"""提供商响应"""
id: str
name: str
display_name: str
description: Optional[str]
website: Optional[str]
api_format: str
base_url: str
headers: Optional[dict]
timeout: int
max_retries: int
priority: int
weight: float
rate_limit: Optional[int]
concurrent_limit: Optional[int]
config: Optional[dict]
is_active: bool
created_at: datetime
updated_at: datetime
models_count: int = 0
active_models_count: int = 0
model_mappings_count: int = 0
api_keys_count: int = 0
class Config:
from_attributes = True
# ========== 模型管理 ==========
class ModelCreate(BaseModel):
"""创建模型请求 - 价格和能力字段可选,为空时使用 GlobalModel 默认值"""
provider_model_name: str = Field(
..., min_length=1, max_length=200, description="Provider 侧的模型名称"
)
global_model_id: str = Field(..., description="关联的 GlobalModel ID必填")
# 按次计费配置 - 可选,为空时使用 GlobalModel 默认值
price_per_request: Optional[float] = Field(
None, ge=0, description="每次请求固定费用,为空使用默认值"
)
# 阶梯计费配置 - 可选,为空时使用 GlobalModel 默认值
tiered_pricing: Optional[dict] = Field(
None, description="阶梯计费配置,为空使用 GlobalModel 默认值"
)
# 能力配置 - 可选,为空时使用 GlobalModel 默认值
supports_vision: Optional[bool] = Field(None, description="是否支持图像输入,为空使用默认值")
supports_function_calling: Optional[bool] = Field(
None, description="是否支持函数调用,为空使用默认值"
)
supports_streaming: Optional[bool] = Field(None, description="是否支持流式输出,为空使用默认值")
supports_extended_thinking: Optional[bool] = Field(
None, description="是否支持扩展思考,为空使用默认值"
)
is_active: bool = Field(True, description="是否启用")
config: Optional[dict] = Field(None, description="额外配置")
class ModelUpdate(BaseModel):
"""更新模型请求"""
provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
global_model_id: Optional[str] = None
# 按次计费配置
price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
# 阶梯计费配置
tiered_pricing: Optional[dict] = Field(None, description="阶梯计费配置")
supports_vision: Optional[bool] = None
supports_function_calling: Optional[bool] = None
supports_streaming: Optional[bool] = None
supports_extended_thinking: Optional[bool] = None
is_active: Optional[bool] = None
is_available: Optional[bool] = None
config: Optional[dict] = None
class ModelResponse(BaseModel):
"""模型响应 - 包含 Model 配置和关联的 GlobalModel 信息
注意:价格和能力字段返回的是有效值(优先使用 Model 配置,否则使用 GlobalModel 默认值)
"""
id: str
provider_id: str
global_model_id: Optional[str]
provider_model_name: str
# 按次计费配置
price_per_request: Optional[float] = None
# 阶梯计费配置
tiered_pricing: Optional[dict] = None
# Provider 能力配置 - 可选,为空表示使用 GlobalModel 默认值
supports_vision: Optional[bool]
supports_function_calling: Optional[bool]
supports_streaming: Optional[bool]
supports_extended_thinking: Optional[bool]
supports_image_generation: Optional[bool]
# 有效值(合并 Model 配置和 GlobalModel 默认值后的结果)
effective_tiered_pricing: Optional[dict] = None
effective_input_price: Optional[float] = None
effective_output_price: Optional[float] = None
effective_price_per_request: Optional[float] = None
effective_supports_vision: Optional[bool] = None
effective_supports_function_calling: Optional[bool] = None
effective_supports_streaming: Optional[bool] = None
effective_supports_extended_thinking: Optional[bool] = None
effective_supports_image_generation: Optional[bool] = None
# 状态
is_active: bool
is_available: bool
# 时间戳
created_at: datetime
updated_at: datetime
# 关联的 GlobalModel 信息(如果有)
global_model_name: Optional[str] = None
global_model_display_name: Optional[str] = None
class Config:
from_attributes = True
class ModelDetailResponse(BaseModel):
"""模型详细响应 - 包含所有字段(用于需要完整信息的场景)"""
id: str
provider_id: str
name: str
display_name: str
description: Optional[str]
icon_url: Optional[str]
tags: Optional[List[str]]
input_price_per_1m: float
output_price_per_1m: float
cache_creation_price_per_1m: Optional[float]
cache_read_price_per_1m: Optional[float]
supports_vision: bool
supports_function_calling: bool
supports_streaming: bool
is_active: bool
is_available: bool
config: Optional[dict]
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ========== 模型映射 ==========
class ModelMappingCreate(BaseModel):
"""创建模型映射请求(源模型到目标模型的映射)"""
source_model: str = Field(..., min_length=1, max_length=200, description="源模型名或别名")
target_global_model_id: str = Field(..., description="目标 GlobalModel ID")
provider_id: Optional[str] = Field(None, description="Provider ID为空时表示全局别名")
mapping_type: str = Field(
"alias",
description="映射类型alias=按目标模型计费别名mapping=按源模型计费(降级映射)",
)
is_active: bool = Field(True, description="是否启用")
class ModelMappingUpdate(BaseModel):
"""更新模型映射请求"""
source_model: Optional[str] = Field(
None, min_length=1, max_length=200, description="源模型名或别名"
)
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
provider_id: Optional[str] = Field(None, description="Provider ID为空时表示全局别名")
mapping_type: Optional[str] = Field(
None, description="映射类型alias=按目标模型计费别名mapping=按源模型计费(降级映射)"
)
is_active: Optional[bool] = None
class ModelMappingResponse(BaseModel):
"""模型映射响应"""
id: str
source_model: str
target_global_model_id: str
target_global_model_name: Optional[str]
target_global_model_display_name: Optional[str]
provider_id: Optional[str]
provider_name: Optional[str]
scope: str = Field(..., description="global 或 provider")
mapping_type: str = Field(..., description="映射类型alias 或 mapping")
is_active: bool
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ========== 系统设置 ==========
class SystemSettingsRequest(BaseModel):
"""系统设置请求"""
default_provider: Optional[str] = None
default_model: Optional[str] = None
enable_usage_tracking: Optional[bool] = None
class SystemSettingsResponse(BaseModel):
"""系统设置响应"""
default_provider: Optional[str]
default_model: Optional[str]
enable_usage_tracking: bool
# ========== 使用统计 ==========
class UsageStatsResponse(BaseModel):
"""使用统计响应"""
total_requests: int
total_tokens: int
total_cost_usd: float
daily_requests: int
daily_tokens: int
daily_cost_usd: float
model_usage: Dict[str, Dict[str, Any]]
provider_usage: Dict[str, Dict[str, Any]]
# ========== 公开API响应模型 ==========
class PublicProviderResponse(BaseModel):
"""公开的提供商信息响应"""
id: str
name: str
display_name: str
description: Optional[str]
website: Optional[str]
is_active: bool
provider_priority: int # 提供商优先级(数字越小越优先)
# 统计信息
models_count: int
active_models_count: int
mappings_count: int
endpoints_count: int # 端点总数
active_endpoints_count: int # 活跃端点数
class PublicModelResponse(BaseModel):
"""公开的模型信息响应"""
id: str
provider_id: str
provider_name: str
provider_display_name: str
name: str
display_name: str
description: Optional[str] = None
tags: Optional[List[str]] = None
icon_url: Optional[str] = None
# 价格信息
input_price_per_1m: Optional[float] = None
output_price_per_1m: Optional[float] = None
cache_creation_price_per_1m: Optional[float] = None
cache_read_price_per_1m: Optional[float] = None
# 功能支持
supports_vision: Optional[bool] = None
supports_function_calling: Optional[bool] = None
supports_streaming: Optional[bool] = None
is_active: bool = True
class PublicModelMappingResponse(BaseModel):
"""公开的模型映射信息响应"""
id: str
source_model: str
target_global_model_id: str
target_global_model_name: Optional[str]
target_global_model_display_name: Optional[str]
provider_id: Optional[str] = None
scope: str = Field(..., description="global 或 provider")
is_active: bool
class ProviderStatsResponse(BaseModel):
"""提供商统计信息响应"""
total_providers: int
active_providers: int
total_models: int
active_models: int
total_mappings: int
supported_formats: List[str]
class PublicGlobalModelResponse(BaseModel):
"""公开的 GlobalModel 信息响应(用户可见)"""
id: str
name: str
display_name: Optional[str] = None
description: Optional[str] = None
icon_url: Optional[str] = None
is_active: bool = True
# 按次计费配置
default_price_per_request: Optional[float] = None
# 阶梯计费配置
default_tiered_pricing: Optional[dict] = None
# 默认能力
default_supports_vision: bool = False
default_supports_function_calling: bool = False
default_supports_streaming: bool = True
default_supports_extended_thinking: bool = False
# Key 能力配置
supported_capabilities: Optional[List[str]] = None
class PublicGlobalModelListResponse(BaseModel):
"""公开的 GlobalModel 列表响应"""
models: List[PublicGlobalModelResponse]
total: int
# ========== 个人中心相关模型 ==========
class UpdateProfileRequest(BaseModel):
"""更新个人信息请求"""
email: Optional[str] = None
username: Optional[str] = None
class UpdatePreferencesRequest(BaseModel):
"""更新偏好设置请求"""
avatar_url: Optional[str] = None
bio: Optional[str] = None
default_provider_id: Optional[int] = None
theme: Optional[str] = None
language: Optional[str] = None
timezone: Optional[str] = None
email_notifications: Optional[bool] = None
usage_alerts: Optional[bool] = None
announcement_notifications: Optional[bool] = None
class ChangePasswordRequest(BaseModel):
"""修改密码请求"""
old_password: str
new_password: str
class CreateMyApiKeyRequest(BaseModel):
"""创建我的API密钥请求"""
name: str
class ProviderConfig(BaseModel):
"""提供商配置"""
provider_id: str = Field(..., description="提供商ID")
priority: int = Field(100, description="优先级(越高越优先)")
weight: float = Field(1.0, description="负载均衡权重")
enabled: bool = Field(True, description="是否启用")
class UpdateApiKeyProvidersRequest(BaseModel):
"""更新API密钥可用提供商请求"""
allowed_providers: Optional[List[ProviderConfig]] = None # 提供商配置列表
# ========== 公告相关模型 ==========
class CreateAnnouncementRequest(BaseModel):
"""创建公告请求"""
title: str
content: str # 支持Markdown
type: str = "info" # info, warning, maintenance, important
priority: int = 0
is_pinned: bool = False
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
class UpdateAnnouncementRequest(BaseModel):
"""更新公告请求"""
title: Optional[str] = None
content: Optional[str] = None
type: Optional[str] = None
priority: Optional[int] = None
is_active: Optional[bool] = None
is_pinned: Optional[bool] = None
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None

72
src/models/api_key.py Normal file
View File

@@ -0,0 +1,72 @@
"""
Provider API Key相关的API模型
"""
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
class ProviderAPIKeyBase(BaseModel):
"""Provider API Key基础模型"""
name: Optional[str] = Field(None, description="密钥名称/备注")
api_key: str = Field(..., description="API密钥")
rate_limit: Optional[int] = Field(None, description="速率限制(每分钟请求数)")
daily_limit: Optional[int] = Field(None, description="每日请求限制")
monthly_limit: Optional[int] = Field(None, description="每月请求限制")
priority: int = Field(0, description="优先级(越高越优先使用)")
is_active: bool = Field(True, description="是否启用")
expires_at: Optional[datetime] = Field(None, description="过期时间")
class ProviderAPIKeyCreate(ProviderAPIKeyBase):
"""创建Provider API Key请求"""
pass
class ProviderAPIKeyUpdate(BaseModel):
"""更新Provider API Key请求"""
name: Optional[str] = None
api_key: Optional[str] = None
rate_limit: Optional[int] = None
daily_limit: Optional[int] = None
monthly_limit: Optional[int] = None
priority: Optional[int] = None
is_active: Optional[bool] = None
expires_at: Optional[datetime] = None
class ProviderAPIKeyResponse(ProviderAPIKeyBase):
"""Provider API Key响应"""
id: str
provider_id: str
request_count: Optional[int] = Field(0, description="请求次数")
error_count: Optional[int] = Field(0, description="错误次数")
last_used_at: Optional[datetime] = Field(None, description="最后使用时间")
last_error_at: Optional[datetime] = Field(None, description="最后错误时间")
last_error_msg: Optional[str] = Field(None, description="最后错误信息")
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ProviderAPIKeyStats(BaseModel):
"""Provider API Key统计信息"""
id: str
name: Optional[str]
request_count: int
error_count: int
success_rate: float
last_used_at: Optional[datetime]
is_active: bool
is_expired: bool
remaining_daily: Optional[int] = Field(None, description="今日剩余请求数")
remaining_monthly: Optional[int] = Field(None, description="本月剩余请求数")

118
src/models/claude.py Normal file
View File

@@ -0,0 +1,118 @@
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
# 配置允许额外字段以支持API的新特性
class BaseModelWithExtras(BaseModel):
model_config = ConfigDict(extra="allow")
class ClaudeContentBlockText(BaseModelWithExtras):
type: Literal["text"]
text: str
class ClaudeContentBlockImage(BaseModelWithExtras):
type: Literal["image"]
source: Dict[str, Any]
class ClaudeContentBlockToolUse(BaseModelWithExtras):
type: Literal["tool_use"]
id: str
name: str
input: Dict[str, Any]
class ClaudeContentBlockToolResult(BaseModelWithExtras):
type: Literal["tool_result"]
tool_use_id: str
content: Union[str, List[Dict[str, Any]], Dict[str, Any]]
class ClaudeContentBlockThinking(BaseModelWithExtras):
type: Literal["thinking"]
thinking: str
class ClaudeSystemContent(BaseModelWithExtras):
type: Literal["text"]
text: str
class ClaudeMessage(BaseModelWithExtras):
role: Literal["user", "assistant"]
# 宽松的内容类型定义 - 接受字符串或任意字典列表
# 作为转发代理,不应该严格限制内容块类型,以支持API的新特性
content: Union[str, List[Dict[str, Any]]]
class ClaudeTool(BaseModelWithExtras):
name: str
description: Optional[str] = None
input_schema: Dict[str, Any]
class ClaudeThinkingConfig(BaseModelWithExtras):
enabled: bool = True
class ClaudeMessagesRequest(BaseModelWithExtras):
model: str
max_tokens: int
messages: List[ClaudeMessage]
# 宽松的system类型 - 接受字符串、字典列表或任意字典
system: Optional[Union[str, List[Dict[str, Any]], Dict[str, Any]]] = None
stop_sequences: Optional[List[str]] = None
stream: Optional[bool] = False
temperature: Optional[float] = 1.0
top_p: Optional[float] = None
top_k: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
tools: Optional[List[Dict[str, Any]]] = None # 改为更宽松的类型
tool_choice: Optional[Dict[str, Any]] = None
thinking: Optional[Dict[str, Any]] = None # 改为更宽松的类型
class ClaudeTokenCountRequest(BaseModelWithExtras):
model: str
messages: List[ClaudeMessage]
# 宽松的类型定义以支持API新特性
system: Optional[Union[str, List[Dict[str, Any]], Dict[str, Any]]] = None
tools: Optional[List[Dict[str, Any]]] = None
thinking: Optional[Dict[str, Any]] = None
tool_choice: Optional[Dict[str, Any]] = None
# ---------------------------------------------------------------------------
# 响应模型
# ---------------------------------------------------------------------------
class ClaudeResponseUsage(BaseModelWithExtras):
"""Claude 响应 token 使用量"""
input_tokens: int = 0
output_tokens: int = 0
cache_creation_input_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
class ClaudeResponse(BaseModelWithExtras):
"""
Claude Messages API 响应模型
对应 POST /v1/messages 端点的响应体。
"""
id: str
model: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: List[Dict[str, Any]]
stop_reason: Optional[str] = None
stop_sequence: Optional[str] = None
usage: Optional[ClaudeResponseUsage] = None
context_management: Optional[Dict[str, Any]] = None
container: Optional[Dict[str, Any]] = None

1341
src/models/database.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,119 @@
"""
数据库模型扩展 - 新增的提供商策略相关表
"""
import uuid
from datetime import datetime, timezone
from sqlalchemy import (
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Index,
Integer,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import relationship
from .database import Base
class ApiKeyProviderMapping(Base):
"""
API Key 和 Provider 的关联映射表
用途:管理员为特定的 API Key 指定提供商
- 如果存在映射:该 API Key 只能使用指定的提供商(无负载均衡和故障转移)
- 如果不存在映射:该 API Key 使用所有可用提供商(系统默认优先级,有负载均衡和故障转移)
注意priority_adjustment 和 weight_multiplier 字段保留但在当前版本不使用
"""
__tablename__ = "api_key_provider_mappings"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
api_key_id = Column(
String(36), ForeignKey("api_keys.id", ondelete="CASCADE"), nullable=False, index=True
)
provider_id = Column(
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
)
# 管理员设置的优先级调整(非用户自己设置)
priority_adjustment = Column(Integer, default=0) # 优先级调整值(可正可负)
weight_multiplier = Column(Float, default=1.0) # 权重乘数(>0
# 是否启用
is_enabled = Column(Boolean, default=True, nullable=False)
# 时间戳
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
nullable=False,
)
# 关系
api_key = relationship("ApiKey", back_populates="provider_mappings")
provider = relationship("Provider", back_populates="api_key_mappings")
# 唯一约束
__table_args__ = (
UniqueConstraint("api_key_id", "provider_id", name="uq_apikey_provider"),
Index("idx_apikey_provider_enabled", "api_key_id", "is_enabled"),
)
class ProviderUsageTracking(Base):
"""提供商使用追踪 (用于RPM限流和健康检测)"""
__tablename__ = "provider_usage_tracking"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
provider_id = Column(
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
)
# 时间窗口
window_start = Column(DateTime(timezone=True), nullable=False, index=True)
window_end = Column(DateTime(timezone=True), nullable=False)
# 统计数据
total_requests = Column(Integer, default=0)
successful_requests = Column(Integer, default=0)
failed_requests = Column(Integer, default=0)
# 性能数据
avg_response_time_ms = Column(Float, default=0.0)
total_response_time_ms = Column(Float, default=0.0) # 用于计算平均值
# 成本数据
total_cost_usd = Column(Float, default=0.0)
# 时间戳
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
nullable=False,
)
# 关系
provider = relationship("Provider", back_populates="usage_tracking")
# 索引
__table_args__ = (
Index("idx_provider_window", "provider_id", "window_start"),
Index("idx_window_time", "window_start", "window_end"),
)

View File

@@ -0,0 +1,653 @@
"""
ProviderEndpoint 相关的 API 模型定义
"""
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
# ========== ProviderEndpoint CRUD ==========
class ProviderEndpointCreate(BaseModel):
"""创建 Endpoint 请求"""
provider_id: str = Field(..., description="Provider ID")
api_format: str = Field(..., description="API 格式 (CLAUDE, OPENAI, CLAUDE_CLI, OPENAI_CLI)")
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
# 请求配置
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
timeout: int = Field(default=300, ge=10, le=600, description="超时时间(秒)")
max_retries: int = Field(default=3, ge=0, le=10, description="最大重试次数")
# 限制
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制(请求/秒)")
# 额外配置
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置JSON")
@field_validator("api_format")
@classmethod
def validate_api_format(cls, v: str) -> str:
"""验证 API 格式"""
from src.core.enums import APIFormat
allowed = [fmt.value for fmt in APIFormat]
v_upper = v.upper()
if v_upper not in allowed:
raise ValueError(f"API 格式必须是 {allowed} 之一")
return v_upper
@field_validator("base_url")
@classmethod
def validate_base_url(cls, v: str) -> str:
"""验证 API URLSSRF 防护)"""
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# 防止 SSRF 攻击:禁止内网地址
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # 移除末尾斜杠
class ProviderEndpointUpdate(BaseModel):
"""更新 Endpoint 请求"""
base_url: Optional[str] = Field(
default=None, min_length=1, max_length=500, description="API 基础 URL"
)
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
is_active: Optional[bool] = Field(default=None, description="是否启用")
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
@field_validator("base_url")
@classmethod
def validate_base_url(cls, v: Optional[str]) -> Optional[str]:
"""验证 API URLSSRF 防护)"""
if v is None:
return v
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# 防止 SSRF 攻击:禁止内网地址
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # 移除末尾斜杠
class ProviderEndpointResponse(BaseModel):
"""Endpoint 响应"""
id: str
provider_id: str
provider_name: str # 冗余字段,方便前端显示
# API 配置
api_format: str
base_url: str
# 请求配置
headers: Optional[Dict[str, str]] = None
timeout: int
max_retries: int
# 限制
max_concurrent: Optional[int] = None
rate_limit: Optional[int] = None
# 状态
is_active: bool
# 额外配置
config: Optional[Dict[str, Any]] = None
# 统计(从 Keys 聚合)
total_keys: int = Field(default=0, description="总 Key 数量")
active_keys: int = Field(default=0, description="活跃 Key 数量")
# 时间戳
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ========== ProviderAPIKey 相关(新架构) ==========
class EndpointAPIKeyCreate(BaseModel):
"""为 Endpoint 添加 API Key"""
endpoint_id: str = Field(..., description="Endpoint ID")
api_key: str = Field(..., min_length=10, max_length=500, description="API Key将自动加密")
name: str = Field(..., min_length=1, max_length=100, description="密钥名称(必填,用于识别)")
# 成本计算
rate_multiplier: float = Field(
default=1.0, ge=0.01, description="成本倍率(真实成本 = 表面成本 × 倍率)"
)
# 优先级和限制(数字越小越优先)
internal_priority: int = Field(default=50, description="Endpoint 内部优先级(提供商优先模式)")
# max_concurrent: NULL=自适应模式(系统自动学习),数字=固定限制模式
max_concurrent: Optional[int] = Field(
default=None, ge=1, description="最大并发数NULL=自适应模式)"
)
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
allowed_models: Optional[List[str]] = Field(
default=None, description="允许使用的模型列表null = 支持所有模型)"
)
# 能力标签
capabilities: Optional[Dict[str, bool]] = Field(
default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}"
)
# 缓存与熔断配置
cache_ttl_minutes: int = Field(
default=5, ge=0, le=60, description="缓存 TTL分钟0=禁用默认5分钟"
)
max_probe_interval_minutes: int = Field(
default=32, ge=2, le=32, description="熔断探测间隔(分钟),范围 2-32"
)
# 备注
note: Optional[str] = Field(default=None, max_length=500, description="备注说明(可选)")
@field_validator("api_key")
@classmethod
def validate_api_key(cls, v: str) -> str:
"""验证 API Key 安全性"""
# 移除首尾空白
v = v.strip()
# 检查最小长度
if len(v) < 10:
raise ValueError("API Key 长度不能少于 10 个字符")
# 检查危险字符SQL 注入防护)
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
for char in dangerous_chars:
if char in v:
raise ValueError(f"API Key 包含非法字符: {char}")
return v
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""验证名称(防止 XSS"""
# 移除危险的 HTML 标签
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE)
return v.strip()
@field_validator("note")
@classmethod
def validate_note(cls, v: Optional[str]) -> Optional[str]:
"""验证备注(防止 XSS"""
if v is None:
return v
# 移除危险的 HTML 标签
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE)
return v.strip()
class EndpointAPIKeyUpdate(BaseModel):
"""更新 Endpoint API Key"""
api_key: Optional[str] = Field(
default=None, min_length=10, max_length=500, description="API Key将自动加密"
)
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="成本倍率")
internal_priority: Optional[int] = Field(
default=None, description="Endpoint 内部优先级(提供商优先模式,数字越小越优先)"
)
global_priority: Optional[int] = Field(
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
)
# 注意max_concurrent=None 表示不更新,要切换为自适应模式请使用专用 API
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
allowed_models: Optional[List[str]] = Field(default=None, description="允许使用的模型列表")
capabilities: Optional[Dict[str, bool]] = Field(
default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}"
)
cache_ttl_minutes: Optional[int] = Field(
default=None, ge=0, le=60, description="缓存 TTL分钟0=禁用"
)
max_probe_interval_minutes: Optional[int] = Field(
default=None, ge=2, le=32, description="熔断探测间隔(分钟),范围 2-32"
)
is_active: Optional[bool] = Field(default=None, description="是否启用")
note: Optional[str] = Field(default=None, max_length=500, description="备注说明")
@field_validator("api_key")
@classmethod
def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
"""验证 API Key 安全性"""
if v is None:
return v
v = v.strip()
if len(v) < 10:
raise ValueError("API Key 长度不能少于 10 个字符")
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
for char in dangerous_chars:
if char in v:
raise ValueError(f"API Key 包含非法字符: {char}")
return v
@field_validator("name")
@classmethod
def validate_name(cls, v: Optional[str]) -> Optional[str]:
"""验证名称(防止 XSS"""
if v is None:
return v
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE)
return v.strip()
@field_validator("note")
@classmethod
def validate_note(cls, v: Optional[str]) -> Optional[str]:
"""验证备注(防止 XSS"""
if v is None:
return v
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE)
return v.strip()
class EndpointAPIKeyResponse(BaseModel):
"""Endpoint API Key 响应"""
id: str
endpoint_id: str
# Key 信息(脱敏)
api_key_masked: str = Field(..., description="脱敏后的 Key")
api_key_plain: Optional[str] = Field(default=None, description="完整的 Key")
name: str = Field(..., description="密钥名称")
# 成本计算
rate_multiplier: float = Field(default=1.0, description="成本倍率")
# 优先级和限制
internal_priority: int = Field(default=50, description="Endpoint 内部优先级")
global_priority: Optional[int] = Field(default=None, description="全局 Key 优先级")
max_concurrent: Optional[int] = None
rate_limit: Optional[int] = None
daily_limit: Optional[int] = None
monthly_limit: Optional[int] = None
allowed_models: Optional[List[str]] = None
capabilities: Optional[Dict[str, bool]] = Field(
default=None, description="Key 能力标签"
)
# 缓存与熔断配置
cache_ttl_minutes: int = Field(default=5, description="缓存 TTL分钟0=禁用")
max_probe_interval_minutes: int = Field(default=32, description="熔断探测间隔(分钟)")
# 健康度
health_score: float
consecutive_failures: int
last_failure_at: Optional[datetime] = None
# 熔断器状态(滑动窗口 + 半开模式)
circuit_breaker_open: bool = Field(default=False, description="熔断器是否打开")
circuit_breaker_open_at: Optional[datetime] = Field(default=None, description="熔断器打开时间")
next_probe_at: Optional[datetime] = Field(default=None, description="下次进入半开状态时间")
half_open_until: Optional[datetime] = Field(default=None, description="半开状态结束时间")
half_open_successes: Optional[int] = Field(default=0, description="半开状态成功次数")
half_open_failures: Optional[int] = Field(default=0, description="半开状态失败次数")
request_results_window: Optional[List[dict]] = Field(None, description="请求结果滑动窗口")
# 使用统计
request_count: int
success_count: int
error_count: int
success_rate: float = Field(default=0.0, description="成功率")
avg_response_time_ms: float = Field(default=0.0, description="平均响应时间(毫秒)")
# 状态
is_active: bool
# 自适应并发信息
is_adaptive: bool = Field(default=False, description="是否为自适应模式max_concurrent=NULL")
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
effective_limit: Optional[int] = Field(None, description="当前有效限制")
# 滑动窗口利用率采样
utilization_samples: Optional[List[dict]] = Field(None, description="利用率采样窗口")
last_probe_increase_at: Optional[datetime] = Field(None, description="上次探测性扩容时间")
concurrent_429_count: Optional[int] = None
rpm_429_count: Optional[int] = None
last_429_at: Optional[datetime] = None
last_429_type: Optional[str] = None
# 备注
note: Optional[str] = None
# 时间戳
last_used_at: Optional[datetime] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ========== 健康监控相关 ==========
class HealthStatusResponse(BaseModel):
"""健康状态响应(仅 Key 级别)"""
# Key 健康状态
key_id: str
key_health_score: float
key_consecutive_failures: int
key_last_failure_at: Optional[datetime] = None
key_is_active: bool
key_statistics: Optional[Dict[str, Any]] = None
# 熔断器状态(滑动窗口 + 半开模式)
circuit_breaker_open: bool = False
circuit_breaker_open_at: Optional[datetime] = None
next_probe_at: Optional[datetime] = None
half_open_until: Optional[datetime] = None
half_open_successes: int = 0
half_open_failures: int = 0
class HealthSummaryResponse(BaseModel):
"""健康状态摘要"""
endpoints: Dict[str, int] = Field(..., description="Endpoint 统计 (total, active, unhealthy)")
keys: Dict[str, int] = Field(..., description="Key 统计 (total, active, unhealthy)")
# ========== 并发控制相关 ==========
class ConcurrencyStatusResponse(BaseModel):
"""并发状态响应"""
endpoint_id: Optional[str] = None
endpoint_current_concurrency: int = Field(default=0, description="Endpoint 当前并发数")
endpoint_max_concurrent: Optional[int] = Field(default=None, description="Endpoint 最大并发数")
key_id: Optional[str] = None
key_current_concurrency: int = Field(default=0, description="Key 当前并发数")
key_max_concurrent: Optional[int] = Field(default=None, description="Key 最大并发数")
class ResetConcurrencyRequest(BaseModel):
"""重置并发计数请求"""
endpoint_id: Optional[str] = Field(default=None, description="Endpoint ID可选")
key_id: Optional[str] = Field(default=None, description="Key ID可选")
class KeyPriorityItem(BaseModel):
"""单个 Key 优先级项"""
key_id: str = Field(..., description="Key ID")
internal_priority: int = Field(..., ge=0, description="Endpoint 内部优先级(数字越小越优先)")
class BatchUpdateKeyPriorityRequest(BaseModel):
"""批量更新 Key 优先级请求"""
priorities: List[KeyPriorityItem] = Field(..., min_length=1, description="Key 优先级列表")
# ========== 提供商摘要(增强版) ==========
class ProviderUpdateRequest(BaseModel):
"""Provider 基础配置更新请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None
website: Optional[str] = Field(None, max_length=500, description="主站网站")
priority: Optional[int] = None
weight: Optional[float] = Field(None, gt=0)
provider_priority: Optional[int] = Field(None, description="提供商优先级(数字越小越优先)")
is_active: Optional[bool] = None
billing_type: Optional[str] = Field(
None, description="计费类型monthly_quota/pay_as_you_go/free_tier"
)
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="订阅配额(美元)")
quota_reset_day: Optional[int] = Field(None, ge=1, le=31, description="配额重置日1-31")
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
rpm_limit: Optional[int] = Field(
None, ge=0, description="每分钟请求数限制NULL=无限制0=禁止请求)"
)
class ProviderWithEndpointsSummary(BaseModel):
"""Provider 和 Endpoints 摘要"""
# Provider 基本信息
id: str
name: str
display_name: str
description: Optional[str] = None
website: Optional[str] = None
provider_priority: int = Field(default=100, description="提供商优先级(数字越小越优先)")
is_active: bool
# 计费相关字段
billing_type: Optional[str] = None
monthly_quota_usd: Optional[float] = None
monthly_used_usd: Optional[float] = None
quota_reset_day: Optional[int] = Field(default=None, description="配额重置周期(天数)")
quota_last_reset_at: Optional[datetime] = Field(default=None, description="当前周期开始时间")
quota_expires_at: Optional[datetime] = Field(default=None, description="配额过期时间")
# RPM 限制
rpm_limit: Optional[int] = Field(
default=None, description="每分钟请求数限制NULL=无限制0=禁止请求)"
)
rpm_used: Optional[int] = Field(default=None, description="当前分钟已用请求数")
rpm_reset_at: Optional[datetime] = Field(default=None, description="RPM 重置时间")
# Endpoint 统计
total_endpoints: int = Field(default=0, description="总 Endpoint 数量")
active_endpoints: int = Field(default=0, description="活跃 Endpoint 数量")
# Key 统计(所有 Endpoints 的 Keys
total_keys: int = Field(default=0, description="总 Key 数量")
active_keys: int = Field(default=0, description="活跃 Key 数量")
# Model 统计
total_models: int = Field(default=0, description="总模型数量")
active_models: int = Field(default=0, description="活跃模型数量")
# API 格式列表
api_formats: List[str] = Field(default=[], description="支持的 API 格式列表")
# Endpoint 健康度详情
endpoint_health_details: List[Dict[str, Any]] = Field(
default=[],
description="每个 Endpoint 的健康度详情 [{api_format: str, health_score: float, is_active: bool}]",
)
# 健康度统计
avg_health_score: float = Field(default=1.0, description="平均健康度")
unhealthy_endpoints: int = Field(
default=0, description="不健康的端点数量health_score < 0.5"
)
# 时间戳
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ========== 健康监控可视化模型 ==========
class EndpointHealthEvent(BaseModel):
"""单个端点的请求事件"""
timestamp: datetime
status: str
status_code: Optional[int] = None
latency_ms: Optional[int] = None
error_type: Optional[str] = None
error_message: Optional[str] = None
class EndpointHealthMonitor(BaseModel):
"""端点健康监控信息"""
endpoint_id: str
api_format: str
is_active: bool
total_attempts: int
success_count: int
failed_count: int
skipped_count: int
success_rate: float = Field(default=1.0, description="最近事件窗口的成功率")
last_event_at: Optional[datetime] = None
events: List[EndpointHealthEvent] = Field(default_factory=list)
class ProviderEndpointHealthMonitorResponse(BaseModel):
"""Provider 下所有端点的健康监控"""
provider_id: str
provider_name: str
generated_at: datetime
endpoints: List[EndpointHealthMonitor] = Field(default_factory=list)
class ApiFormatHealthMonitor(BaseModel):
"""按 API 格式聚合的健康监控信息"""
api_format: str
total_attempts: int
success_count: int
failed_count: int
skipped_count: int
success_rate: float = Field(default=1.0, description="最近事件窗口的成功率")
provider_count: int = Field(default=0, description="参与统计的 Provider 数量")
key_count: int = Field(default=0, description="参与统计的 API Key 数量")
last_event_at: Optional[datetime] = None
events: List[EndpointHealthEvent] = Field(default_factory=list)
timeline: List[str] = Field(
default_factory=list,
description="Usage 表生成的健康时间线healthy/warning/unhealthy/unknown",
)
time_range_start: Optional[datetime] = Field(
default=None, description="时间线所覆盖区间的开始时间"
)
time_range_end: Optional[datetime] = Field(
default=None, description="时间线所覆盖区间的结束时间"
)
class ApiFormatHealthMonitorResponse(BaseModel):
"""所有 API 格式的健康监控汇总"""
generated_at: datetime
formats: List[ApiFormatHealthMonitor] = Field(default_factory=list)
# ========== 公开健康监控模型(不含敏感信息) ==========
class PublicHealthEvent(BaseModel):
"""公开版单个请求事件(不含敏感信息如 provider_id、key_id"""
timestamp: datetime
status: str
status_code: Optional[int] = None
latency_ms: Optional[int] = None
error_type: Optional[str] = None
class PublicApiFormatHealthMonitor(BaseModel):
"""公开版 API 格式健康监控信息(不含敏感信息)"""
api_format: str
api_path: str = Field(default="/", description="该 API 格式的本站请求路径")
total_attempts: int = Field(default=0, description="总请求次数")
success_count: int = Field(default=0, description="成功次数")
failed_count: int = Field(default=0, description="失败次数")
skipped_count: int = Field(default=0, description="跳过次数")
success_rate: float = Field(default=1.0, description="成功率")
last_event_at: Optional[datetime] = None
events: List[PublicHealthEvent] = Field(default_factory=list, description="事件列表")
timeline: List[str] = Field(
default_factory=list,
description="Usage 表生成的健康时间线healthy/warning/unhealthy/unknown",
)
time_range_start: Optional[datetime] = Field(
default=None, description="时间线覆盖区间开始时间"
)
time_range_end: Optional[datetime] = Field(
default=None, description="时间线覆盖区间结束时间"
)
class PublicApiFormatHealthMonitorResponse(BaseModel):
"""公开版健康监控汇总(不含敏感信息)"""
generated_at: datetime
formats: List[PublicApiFormatHealthMonitor] = Field(default_factory=list)

468
src/models/gemini.py Normal file
View File

@@ -0,0 +1,468 @@
"""
Google Gemini API 请求/响应模型
支持 Gemini 3 Pro 及之前版本的 API 格式
参考文档: https://ai.google.dev/gemini-api/docs/gemini-3
"""
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
class BaseModelWithExtras(BaseModel):
"""允许额外字段的基础模型"""
model_config = ConfigDict(extra="allow")
# ---------------------------------------------------------------------------
# 内容块定义
# ---------------------------------------------------------------------------
class GeminiTextPart(BaseModelWithExtras):
"""文本内容块"""
text: str
thought_signature: Optional[str] = Field(
default=None,
alias="thoughtSignature",
description="Gemini 3 思维签名,用于维护多轮对话中的推理上下文",
)
class GeminiInlineData(BaseModelWithExtras):
"""内联数据(图片等)"""
mime_type: str = Field(alias="mimeType")
data: str # base64 encoded
class GeminiMediaResolution(BaseModelWithExtras):
"""
媒体分辨率配置 (Gemini 3 新增)
控制图片/视频的处理分辨率:
- media_resolution_low: 图片 280 tokens, 视频 70 tokens/帧
- media_resolution_medium: 图片 560 tokens, 视频 70 tokens/帧
- media_resolution_high: 图片 1120 tokens, 视频 280 tokens/帧
"""
level: Literal["media_resolution_low", "media_resolution_medium", "media_resolution_high"]
class GeminiFileData(BaseModelWithExtras):
"""文件引用"""
mime_type: Optional[str] = Field(default=None, alias="mimeType")
file_uri: str = Field(alias="fileUri")
class GeminiFunctionCall(BaseModelWithExtras):
"""函数调用"""
name: str
args: Dict[str, Any]
class GeminiFunctionResponse(BaseModelWithExtras):
"""函数响应"""
name: str
response: Dict[str, Any]
class GeminiPart(BaseModelWithExtras):
"""
Gemini 内容部分 - 支持多种类型
可以是以下类型之一:
- text: 文本内容
- inline_data: 内联数据(图片等)
- file_data: 文件引用
- function_call: 函数调用
- function_response: 函数响应
Gemini 3 新增:
- thought_signature: 思维签名,用于维护推理上下文
- media_resolution: 媒体分辨率配置
"""
text: Optional[str] = None
inline_data: Optional[GeminiInlineData] = Field(default=None, alias="inlineData")
file_data: Optional[GeminiFileData] = Field(default=None, alias="fileData")
function_call: Optional[GeminiFunctionCall] = Field(default=None, alias="functionCall")
function_response: Optional[GeminiFunctionResponse] = Field(
default=None, alias="functionResponse"
)
# Gemini 3 新增
thought_signature: Optional[str] = Field(
default=None,
alias="thoughtSignature",
description="思维签名,用于函数调用和图片生成的上下文保持",
)
media_resolution: Optional[GeminiMediaResolution] = Field(
default=None, alias="mediaResolution", description="媒体分辨率配置"
)
class GeminiContent(BaseModelWithExtras):
"""
Gemini 消息内容
对应 Gemini API 的 Content 对象
"""
role: Optional[Literal["user", "model"]] = None
parts: List[Union[GeminiPart, Dict[str, Any]]]
# ---------------------------------------------------------------------------
# 配置定义
# ---------------------------------------------------------------------------
class GeminiImageConfig(BaseModelWithExtras):
"""
图片生成配置 (Gemini 3 Pro Image)
用于 gemini-3-pro-image-preview 模型
"""
aspect_ratio: Optional[str] = Field(
default=None, alias="aspectRatio", description="图片宽高比,如 '16:9', '1:1', '4:3'"
)
image_size: Optional[Literal["2K", "4K"]] = Field(
default=None, alias="imageSize", description="图片尺寸: 2K 或 4K"
)
class GeminiGenerationConfig(BaseModelWithExtras):
"""
生成配置
Gemini 3 新增:
- thinking_level: 思考深度 (low/medium/high)
- response_json_schema: 结构化输出的 JSON Schema
- image_config: 图片生成配置
"""
temperature: Optional[float] = Field(
default=None, description="采样温度Gemini 3 建议保持默认值 1.0"
)
top_p: Optional[float] = Field(default=None, alias="topP")
top_k: Optional[int] = Field(default=None, alias="topK")
max_output_tokens: Optional[int] = Field(default=None, alias="maxOutputTokens")
stop_sequences: Optional[List[str]] = Field(default=None, alias="stopSequences")
candidate_count: Optional[int] = Field(default=None, alias="candidateCount")
response_mime_type: Optional[str] = Field(default=None, alias="responseMimeType")
response_schema: Optional[Dict[str, Any]] = Field(default=None, alias="responseSchema")
# Gemini 3 新增
response_json_schema: Optional[Dict[str, Any]] = Field(
default=None, alias="responseJsonSchema", description="结构化输出的 JSON Schema"
)
thinking_level: Optional[Literal["low", "medium", "high"]] = Field(
default=None,
alias="thinkingLevel",
description="Gemini 3 思考深度: low(快速), medium(平衡), high(深度推理,默认)",
)
image_config: Optional[GeminiImageConfig] = Field(
default=None, alias="imageConfig", description="图片生成配置"
)
class GeminiSafetySettings(BaseModelWithExtras):
"""安全设置"""
category: str
threshold: str
class GeminiFunctionDeclaration(BaseModelWithExtras):
"""函数声明"""
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
class GeminiGoogleSearchTool(BaseModelWithExtras):
"""Google Search 工具 (Gemini 3)"""
pass # 空对象表示启用
class GeminiUrlContextTool(BaseModelWithExtras):
"""URL Context 工具 (Gemini 3)"""
pass # 空对象表示启用
class GeminiCodeExecutionTool(BaseModelWithExtras):
"""代码执行工具"""
pass # 空对象表示启用
class GeminiTool(BaseModelWithExtras):
"""
工具定义
支持的工具类型:
- function_declarations: 自定义函数
- code_execution: 代码执行
- google_search: Google 搜索 (Gemini 3)
- url_context: URL 上下文 (Gemini 3)
"""
function_declarations: Optional[List[GeminiFunctionDeclaration]] = Field(
default=None, alias="functionDeclarations"
)
code_execution: Optional[Dict[str, Any]] = Field(default=None, alias="codeExecution")
# Gemini 3 内置工具
google_search: Optional[Dict[str, Any]] = Field(
default=None, alias="googleSearch", description="启用 Google 搜索工具"
)
url_context: Optional[Dict[str, Any]] = Field(
default=None, alias="urlContext", description="启用 URL 上下文工具"
)
class GeminiToolConfig(BaseModelWithExtras):
"""工具配置"""
function_calling_config: Optional[Dict[str, Any]] = Field(
default=None, alias="functionCallingConfig"
)
class GeminiSystemInstruction(BaseModelWithExtras):
"""系统指令"""
parts: List[Union[GeminiPart, Dict[str, Any]]]
# ---------------------------------------------------------------------------
# 请求模型
# ---------------------------------------------------------------------------
class GeminiGenerateContentRequest(BaseModelWithExtras):
"""
Gemini generateContent 请求模型
对应 POST /v1beta/models/{model}:generateContent 端点
"""
contents: List[GeminiContent]
system_instruction: Optional[GeminiSystemInstruction] = Field(
default=None, alias="systemInstruction"
)
tools: Optional[List[GeminiTool]] = None
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
default=None, alias="safetySettings"
)
generation_config: Optional[GeminiGenerationConfig] = Field(
default=None, alias="generationConfig"
)
class GeminiStreamGenerateContentRequest(BaseModelWithExtras):
"""
Gemini streamGenerateContent 请求模型
对应 POST /v1beta/models/{model}:streamGenerateContent 端点
与 generateContent 相同,但返回流式响应
"""
contents: List[GeminiContent]
system_instruction: Optional[GeminiSystemInstruction] = Field(
default=None, alias="systemInstruction"
)
tools: Optional[List[GeminiTool]] = None
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
default=None, alias="safetySettings"
)
generation_config: Optional[GeminiGenerationConfig] = Field(
default=None, alias="generationConfig"
)
# ---------------------------------------------------------------------------
# 统一请求模型(用于内部处理)
# ---------------------------------------------------------------------------
class GeminiRequest(BaseModelWithExtras):
"""
Gemini 统一请求模型
内部使用,统一处理 generateContent 和 streamGenerateContent
注意: Gemini API 通过 URL 端点区分流式/非流式请求:
- generateContent - 非流式
- streamGenerateContent - 流式
请求体中不应包含 stream 字段
"""
model: Optional[str] = Field(default=None, description="模型名称,从 URL 路径提取(内部使用)")
contents: List[GeminiContent]
system_instruction: Optional[GeminiSystemInstruction] = Field(
default=None, alias="systemInstruction"
)
tools: Optional[List[GeminiTool]] = None
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
default=None, alias="safetySettings"
)
generation_config: Optional[GeminiGenerationConfig] = Field(
default=None, alias="generationConfig"
)
# ---------------------------------------------------------------------------
# 响应模型
# ---------------------------------------------------------------------------
class GeminiUsageMetadata(BaseModelWithExtras):
"""Token 使用量"""
prompt_token_count: int = Field(default=0, alias="promptTokenCount")
candidates_token_count: int = Field(default=0, alias="candidatesTokenCount")
total_token_count: int = Field(default=0, alias="totalTokenCount")
cached_content_token_count: Optional[int] = Field(default=None, alias="cachedContentTokenCount")
class GeminiSafetyRating(BaseModelWithExtras):
"""安全评级"""
category: str
probability: str
blocked: Optional[bool] = None
class GeminiCitationSource(BaseModelWithExtras):
"""引用来源"""
start_index: Optional[int] = Field(default=None, alias="startIndex")
end_index: Optional[int] = Field(default=None, alias="endIndex")
uri: Optional[str] = None
license: Optional[str] = None
class GeminiCitationMetadata(BaseModelWithExtras):
"""引用元数据"""
citation_sources: Optional[List[GeminiCitationSource]] = Field(
default=None, alias="citationSources"
)
class GeminiGroundingMetadata(BaseModelWithExtras):
"""
Grounding 元数据 (Gemini 3)
当使用 Google Search 工具时返回
"""
search_entry_point: Optional[Dict[str, Any]] = Field(default=None, alias="searchEntryPoint")
grounding_chunks: Optional[List[Dict[str, Any]]] = Field(default=None, alias="groundingChunks")
grounding_supports: Optional[List[Dict[str, Any]]] = Field(
default=None, alias="groundingSupports"
)
web_search_queries: Optional[List[str]] = Field(default=None, alias="webSearchQueries")
class GeminiCandidate(BaseModelWithExtras):
"""候选响应"""
content: Optional[GeminiContent] = None
finish_reason: Optional[str] = Field(default=None, alias="finishReason")
safety_ratings: Optional[List[GeminiSafetyRating]] = Field(default=None, alias="safetyRatings")
citation_metadata: Optional[GeminiCitationMetadata] = Field(
default=None, alias="citationMetadata"
)
grounding_metadata: Optional[GeminiGroundingMetadata] = Field(
default=None, alias="groundingMetadata"
)
token_count: Optional[int] = Field(default=None, alias="tokenCount")
index: Optional[int] = None
class GeminiPromptFeedback(BaseModelWithExtras):
"""提示反馈"""
block_reason: Optional[str] = Field(default=None, alias="blockReason")
safety_ratings: Optional[List[GeminiSafetyRating]] = Field(default=None, alias="safetyRatings")
class GeminiGenerateContentResponse(BaseModelWithExtras):
"""
Gemini generateContent 响应模型
对应 generateContent 端点的响应体
"""
candidates: Optional[List[GeminiCandidate]] = None
prompt_feedback: Optional[GeminiPromptFeedback] = Field(default=None, alias="promptFeedback")
usage_metadata: Optional[GeminiUsageMetadata] = Field(default=None, alias="usageMetadata")
model_version: Optional[str] = Field(default=None, alias="modelVersion")
# ---------------------------------------------------------------------------
# 流式响应模型
# ---------------------------------------------------------------------------
class GeminiStreamChunk(BaseModelWithExtras):
"""
Gemini 流式响应块
流式响应中的单个数据块,结构与完整响应相同
"""
candidates: Optional[List[GeminiCandidate]] = None
prompt_feedback: Optional[GeminiPromptFeedback] = Field(default=None, alias="promptFeedback")
usage_metadata: Optional[GeminiUsageMetadata] = Field(default=None, alias="usageMetadata")
model_version: Optional[str] = Field(default=None, alias="modelVersion")
# ---------------------------------------------------------------------------
# 错误响应
# ---------------------------------------------------------------------------
class GeminiErrorDetail(BaseModelWithExtras):
"""错误详情"""
type: Optional[str] = Field(default=None, alias="@type")
reason: Optional[str] = None
domain: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
class GeminiError(BaseModelWithExtras):
"""错误信息"""
code: int
message: str
status: str
details: Optional[List[GeminiErrorDetail]] = None
class GeminiErrorResponse(BaseModelWithExtras):
"""错误响应"""
error: GeminiError
# ---------------------------------------------------------------------------
# Thought Signature 常量
# ---------------------------------------------------------------------------
# 用于从其他模型迁移对话时绕过签名验证
DUMMY_THOUGHT_SIGNATURE = "context_engineering_is_the_way_to_go"

153
src/models/openai.py Normal file
View File

@@ -0,0 +1,153 @@
"""
OpenAI API 数据模型定义
"""
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict
# 配置允许额外字段,以支持 API 的新特性
class BaseModelWithExtras(BaseModel):
model_config = ConfigDict(extra="allow")
class OpenAIMessage(BaseModelWithExtras):
"""OpenAI消息模型"""
role: str
content: Optional[Union[str, List[Dict[str, Any]]]] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_call_id: Optional[str] = None
name: Optional[str] = None
class OpenAIFunction(BaseModelWithExtras):
"""OpenAI函数定义"""
name: str
description: Optional[str] = None
parameters: Dict[str, Any]
class OpenAITool(BaseModelWithExtras):
"""OpenAI工具定义"""
type: str = "function"
function: OpenAIFunction
class OpenAIRequest(BaseModelWithExtras):
"""OpenAI请求模型"""
model: str
messages: List[OpenAIMessage]
max_tokens: Optional[int] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = None
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
tools: Optional[List[OpenAITool]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
n: Optional[int] = None
seed: Optional[int] = None
response_format: Optional[Dict[str, Any]] = None
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
user: Optional[str] = None
class ResponsesInputMessage(BaseModelWithExtras):
"""Responses API 输入消息"""
type: str = "message"
role: str
content: List[Dict[str, Any]]
class ResponsesReasoningConfig(BaseModelWithExtras):
"""Responses API 推理配置"""
effort: str = "high" # low, medium, high
summary: str = "auto" # auto, off
class ResponsesRequest(BaseModelWithExtras):
"""OpenAI Responses API 请求模型(用于 Claude Code 等客户端)"""
model: str
instructions: Optional[str] = None
input: List[ResponsesInputMessage]
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
parallel_tool_calls: Optional[bool] = False
reasoning: Optional[ResponsesReasoningConfig] = None
store: Optional[bool] = False
stream: Optional[bool] = True
include: Optional[List[str]] = None
prompt_cache_key: Optional[str] = None
# 其他参数
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
stop: Optional[Union[str, List[str]]] = None
class OpenAIUsage(BaseModelWithExtras):
"""OpenAI使用统计"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
class OpenAIChoice(BaseModelWithExtras):
"""OpenAI选择结果"""
index: int
message: OpenAIMessage
finish_reason: Optional[str] = None
logprobs: Optional[Dict[str, Any]] = None
class OpenAIResponse(BaseModelWithExtras):
"""OpenAI响应模型"""
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[OpenAIChoice]
usage: Optional[OpenAIUsage] = None
system_fingerprint: Optional[str] = None
class OpenAIStreamDelta(BaseModelWithExtras):
"""OpenAI流式响应增量"""
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
class OpenAIStreamChoice(BaseModelWithExtras):
"""OpenAI流式响应选择"""
index: int
delta: OpenAIStreamDelta
finish_reason: Optional[str] = None
logprobs: Optional[Dict[str, Any]] = None
class OpenAIStreamResponse(BaseModelWithExtras):
"""OpenAI流式响应模型"""
id: str
object: str = "chat.completion.chunk"
created: int
model: str
choices: List[OpenAIStreamChoice]
system_fingerprint: Optional[str] = None

View File

@@ -0,0 +1,435 @@
"""
Pydantic 数据模型(阶段一统一模型管理)
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
from .api import ModelCreate
# ========== 阶梯计费相关模型 ==========
class CacheTTLPricing(BaseModel):
"""缓存时长定价配置"""
ttl_minutes: int = Field(..., ge=1, description="缓存时长(分钟)")
cache_creation_price_per_1m: float = Field(..., ge=0, description="该时长的缓存创建价格/M tokens")
class PricingTier(BaseModel):
"""单个价格阶梯配置"""
up_to: Optional[int] = Field(
None,
ge=1,
description="阶梯上限tokensnull 表示无上限(最后一个阶梯)"
)
input_price_per_1m: float = Field(..., ge=0, description="输入价格/M tokens")
output_price_per_1m: float = Field(..., ge=0, description="输出价格/M tokens")
cache_creation_price_per_1m: Optional[float] = Field(
None, ge=0, description="缓存创建价格/M tokens"
)
cache_read_price_per_1m: Optional[float] = Field(
None, ge=0, description="缓存读取价格/M tokens"
)
cache_ttl_pricing: Optional[List[CacheTTLPricing]] = Field(
None, description="按缓存时长分价格(可选)"
)
class TieredPricingConfig(BaseModel):
"""阶梯计费配置"""
tiers: List[PricingTier] = Field(
...,
min_length=1,
description="价格阶梯列表,按 up_to 升序排列"
)
@model_validator(mode="after")
def validate_tiers(self) -> "TieredPricingConfig":
"""验证阶梯配置的合法性"""
tiers = self.tiers
if not tiers:
raise ValueError("至少需要一个价格阶梯")
# 检查阶梯顺序和唯一性
prev_up_to = 0
has_unlimited = False
for i, tier in enumerate(tiers):
if has_unlimited:
raise ValueError("无上限阶梯up_to=null必须是最后一个")
if tier.up_to is None:
has_unlimited = True
else:
if tier.up_to <= prev_up_to:
raise ValueError(
f"阶梯 {i+1} 的 up_to ({tier.up_to}) 必须大于前一个阶梯 ({prev_up_to})"
)
prev_up_to = tier.up_to
# 验证缓存时长定价顺序
if tier.cache_ttl_pricing:
prev_ttl = 0
for ttl_pricing in tier.cache_ttl_pricing:
if ttl_pricing.ttl_minutes <= prev_ttl:
raise ValueError(
f"cache_ttl_pricing 必须按 ttl_minutes 升序排列"
)
prev_ttl = ttl_pricing.ttl_minutes
# 最后一个阶梯必须是无上限的
if not has_unlimited:
raise ValueError("最后一个阶梯必须设置 up_to=null无上限")
return self
# ========== 其他模型 ==========
class ModelCapabilities(BaseModel):
"""模型能力聚合"""
supports_vision: bool = False
supports_function_calling: bool = False
supports_streaming: bool = False
class ModelPriceRange(BaseModel):
"""统一模型价格区间"""
min_input: Optional[float] = None
max_input: Optional[float] = None
min_output: Optional[float] = None
max_output: Optional[float] = None
class ModelCatalogProviderDetail(BaseModel):
"""统一模型目录中的关联提供商信息"""
provider_id: str
provider_name: str
provider_display_name: Optional[str]
model_id: Optional[str]
target_model: str
input_price_per_1m: Optional[float]
output_price_per_1m: Optional[float]
cache_creation_price_per_1m: Optional[float]
cache_read_price_per_1m: Optional[float]
cache_1h_creation_price_per_1m: Optional[float] = None # 1h 缓存创建价格
price_per_request: Optional[float] = None # 按次计费价格
effective_tiered_pricing: Optional[Dict[str, Any]] = None # 有效阶梯计费配置(含继承)
tier_count: int = 1 # 阶梯数量
supports_vision: Optional[bool] = None
supports_function_calling: Optional[bool] = None
supports_streaming: Optional[bool] = None
is_active: bool
mapping_id: Optional[str]
class OrphanedModel(BaseModel):
"""孤立的统一模型Mapping 存在但 GlobalModel 缺失)"""
alias: str # 别名
global_model_name: Optional[str] # 关联的 GlobalModel 名称(如果有)
mapping_count: int
class ModelCatalogItem(BaseModel):
"""统一模型目录条目(方案 A基于 GlobalModel"""
global_model_name: str # GlobalModel.name
display_name: str # GlobalModel.display_name
description: Optional[str] # GlobalModel.description
aliases: List[str] # 所有指向该 GlobalModel 的别名列表
providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表
price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合)
total_providers: int
capabilities: ModelCapabilities # 能力聚合(从所有 Provider 的 Model 中聚合)
class ModelCatalogResponse(BaseModel):
"""统一模型目录响应"""
models: List[ModelCatalogItem]
total: int
orphaned_models: List[OrphanedModel]
class ProviderModelPriceInfo(BaseModel):
"""Provider 维度的模型价格信息"""
input_price_per_1m: Optional[float]
output_price_per_1m: Optional[float]
cache_creation_price_per_1m: Optional[float]
cache_read_price_per_1m: Optional[float]
price_per_request: Optional[float] = None # 按次计费价格
class ProviderAvailableSourceModel(BaseModel):
"""Provider 支持的统一模型条目(方案 A"""
global_model_name: str # GlobalModel.name
display_name: str # GlobalModel.display_name
provider_model_name: str # Model.provider_model_name (Provider 侧的模型名)
has_alias: bool # 是否有别名指向该 GlobalModel
aliases: List[str] # 别名列表
model_id: Optional[str] # Model.id
price: ProviderModelPriceInfo
capabilities: ModelCapabilities
is_active: bool
class ProviderAvailableSourceModelsResponse(BaseModel):
"""Provider 可用统一模型响应"""
models: List[ProviderAvailableSourceModel]
total: int
class BatchAssignProviderConfig(BaseModel):
"""批量添加映射的 Provider 配置"""
provider_id: str
create_model: bool = Field(False, description="是否需要创建新的 Model")
model_data: Optional[ModelCreate] = Field(
None, description="create_model=true 时需要提供的模型配置", alias="model_config"
)
model_id: Optional[str] = Field(None, description="create_model=false 时需要提供的现有模型 ID")
class BatchAssignModelMappingRequest(BaseModel):
"""批量添加模型映射请求(方案 A暂不支持需要重构"""
global_model_id: str # 要分配的 GlobalModel ID
providers: List[BatchAssignProviderConfig]
class BatchAssignProviderResult(BaseModel):
"""批量映射结果条目"""
provider_id: str
mapping_id: Optional[str]
created_model: bool
model_id: Optional[str]
updated: bool = False
class BatchAssignError(BaseModel):
"""批量映射错误信息"""
provider_id: str
error: str
class BatchAssignModelMappingResponse(BaseModel):
"""批量映射响应"""
success: bool
created_mappings: List[BatchAssignProviderResult]
errors: List[BatchAssignError]
# ========== 阶段二GlobalModel 相关模型 ==========
class GlobalModelCreate(BaseModel):
"""创建 GlobalModel 请求"""
name: str = Field(..., min_length=1, max_length=100, description="统一模型名(唯一)")
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
description: Optional[str] = Field(None, description="模型描述")
official_url: Optional[str] = Field(None, max_length=500, description="官方文档链接")
icon_url: Optional[str] = Field(None, max_length=500, description="图标 URL")
# 按次计费配置(可选,与阶梯计费叠加)
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
# 统一阶梯计费配置(必填)
# 固定价格也用单阶梯表示: {"tiers": [{"up_to": null, "input_price_per_1m": X, ...}]}
default_tiered_pricing: TieredPricingConfig = Field(
..., description="阶梯计费配置(固定价格用单阶梯表示)"
)
# 默认能力配置
default_supports_vision: Optional[bool] = Field(False, description="默认是否支持视觉")
default_supports_function_calling: Optional[bool] = Field(
False, description="默认是否支持函数调用"
)
default_supports_streaming: Optional[bool] = Field(True, description="默认是否支持流式输出")
default_supports_extended_thinking: Optional[bool] = Field(
False, description="默认是否支持扩展思考"
)
default_supports_image_generation: Optional[bool] = Field(
False, description="默认是否支持图像生成"
)
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"]
supported_capabilities: Optional[List[str]] = Field(
None, description="支持的 Key 能力列表"
)
is_active: Optional[bool] = Field(True, description="是否激活")
class GlobalModelUpdate(BaseModel):
"""更新 GlobalModel 请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None
official_url: Optional[str] = Field(None, max_length=500)
icon_url: Optional[str] = Field(None, max_length=500)
is_active: Optional[bool] = None
# 按次计费配置
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
# 阶梯计费配置
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
None, description="阶梯计费配置"
)
# 默认能力配置
default_supports_vision: Optional[bool] = None
default_supports_function_calling: Optional[bool] = None
default_supports_streaming: Optional[bool] = None
default_supports_extended_thinking: Optional[bool] = None
default_supports_image_generation: Optional[bool] = None
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"]
supported_capabilities: Optional[List[str]] = Field(
None, description="支持的 Key 能力列表"
)
class GlobalModelResponse(BaseModel):
"""GlobalModel 响应"""
id: str
name: str
display_name: str
description: Optional[str]
official_url: Optional[str]
icon_url: Optional[str]
is_active: bool
# 按次计费配置
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
# 阶梯计费配置
default_tiered_pricing: TieredPricingConfig = Field(
..., description="阶梯计费配置"
)
# 默认能力配置
default_supports_vision: Optional[bool]
default_supports_function_calling: Optional[bool]
default_supports_streaming: Optional[bool]
default_supports_extended_thinking: Optional[bool]
default_supports_image_generation: Optional[bool]
# Key 能力配置 - 模型支持的能力列表
supported_capabilities: Optional[List[str]] = Field(
default=None, description="支持的 Key 能力列表"
)
# 统计数据(可选)
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
alias_count: Optional[int] = Field(default=0, description="别名数量")
usage_count: Optional[int] = Field(default=0, description="调用次数")
created_at: datetime
updated_at: Optional[datetime]
class Config:
from_attributes = True
class GlobalModelWithStats(GlobalModelResponse):
"""带统计信息的 GlobalModel"""
total_models: int = Field(..., description="关联的 Model 数量")
total_providers: int = Field(..., description="支持的 Provider 数量")
price_range: ModelPriceRange
class GlobalModelListResponse(BaseModel):
"""GlobalModel 列表响应"""
models: List[GlobalModelResponse]
total: int
class BatchAssignToProvidersRequest(BaseModel):
"""批量为 Provider 添加 GlobalModel 实现"""
provider_ids: List[str] = Field(..., min_items=1, description="Provider ID 列表")
create_models: bool = Field(default=False, description="是否自动创建 Model 记录")
class BatchAssignToProvidersResponse(BaseModel):
"""批量分配响应"""
success: List[dict]
errors: List[dict]
class BatchAssignModelsToProviderRequest(BaseModel):
"""批量为 Provider 关联 GlobalModel"""
global_model_ids: List[str] = Field(..., min_length=1, description="GlobalModel ID 列表")
class BatchAssignModelsToProviderResponse(BaseModel):
"""批量关联 GlobalModel 到 Provider 的响应"""
success: List[dict]
errors: List[dict]
class UpdateModelMappingRequest(BaseModel):
"""更新模型映射请求"""
source_model: Optional[str] = Field(
None, min_length=1, max_length=200, description="源模型名或别名"
)
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
provider_id: Optional[str] = Field(None, description="Provider ID为空时为全局别名")
is_active: Optional[bool] = Field(None, description="是否启用")
class UpdateModelMappingResponse(BaseModel):
"""更新模型映射响应"""
success: bool
mapping_id: str
message: Optional[str] = None
class DeleteModelMappingResponse(BaseModel):
"""删除模型映射响应"""
success: bool
message: Optional[str] = None
__all__ = [
"BatchAssignError",
"BatchAssignModelMappingRequest",
"BatchAssignModelMappingResponse",
"BatchAssignModelsToProviderRequest",
"BatchAssignModelsToProviderResponse",
"BatchAssignProviderConfig",
"BatchAssignProviderResult",
"BatchAssignToProvidersRequest",
"BatchAssignToProvidersResponse",
"DeleteModelMappingResponse",
"GlobalModelCreate",
"GlobalModelListResponse",
"GlobalModelResponse",
"GlobalModelUpdate",
"GlobalModelWithStats",
"ModelCapabilities",
"ModelCatalogItem",
"ModelCatalogProviderDetail",
"ModelCatalogResponse",
"ModelPriceRange",
"OrphanedModel",
"ProviderAvailableSourceModel",
"ProviderAvailableSourceModelsResponse",
"ProviderModelPriceInfo",
"UpdateModelMappingRequest",
"UpdateModelMappingResponse",
]