mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 18:52:28 +08:00
Initial commit
This commit is contained in:
10
src/models/__init__.py
Normal file
10
src/models/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
统一的模型定义模块
|
||||
"""
|
||||
|
||||
from .api import * # noqa: F401, F403
|
||||
from .claude import * # noqa: F401, F403
|
||||
from .database import * # noqa: F401, F403
|
||||
from .openai import * # noqa: F401, F403
|
||||
|
||||
__all__ = ["claude", "database", "openai", "api"]
|
||||
353
src/models/admin_requests.py
Normal file
353
src/models/admin_requests.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
管理接口的 Pydantic 请求模型
|
||||
|
||||
提供完整的输入验证和安全过滤
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from src.core.enums import APIFormat, ProviderBillingType
|
||||
|
||||
|
||||
class CreateProviderRequest(BaseModel):
|
||||
"""创建 Provider 请求"""
|
||||
|
||||
name: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="Provider 名称(英文字母、数字、下划线、连字符)",
|
||||
)
|
||||
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
||||
description: Optional[str] = Field(None, max_length=1000, description="描述")
|
||||
website: Optional[str] = Field(None, max_length=500, description="官网地址")
|
||||
billing_type: Optional[str] = Field(
|
||||
ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
|
||||
)
|
||||
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="周期配额(美元)")
|
||||
quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
|
||||
quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
|
||||
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
|
||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
||||
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
|
||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||
rate_limit: Optional[int] = Field(None, ge=0, description="速率限制")
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""验证名称格式"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
|
||||
|
||||
# SQL 注入防护:检查危险关键字
|
||||
dangerous_keywords = [
|
||||
"SELECT",
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"CREATE",
|
||||
"ALTER",
|
||||
"EXEC",
|
||||
"UNION",
|
||||
"OR",
|
||||
"AND",
|
||||
"--",
|
||||
";",
|
||||
"'",
|
||||
'"',
|
||||
"<",
|
||||
">",
|
||||
]
|
||||
upper_name = v.upper()
|
||||
for keyword in dangerous_keywords:
|
||||
if keyword in upper_name:
|
||||
raise ValueError(f"名称包含禁止的字符或关键字: {keyword}")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("display_name", "description")
|
||||
@classmethod
|
||||
def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""清理文本输入,防止 XSS"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# 移除潜在的脚本标签
|
||||
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
|
||||
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE) # 移除事件处理器
|
||||
|
||||
# 移除危险的 HTML 标签
|
||||
dangerous_tags = ["script", "iframe", "object", "embed", "link", "style"]
|
||||
for tag in dangerous_tags:
|
||||
v = re.sub(rf"<{tag}[^>]*>", "", v, flags=re.IGNORECASE)
|
||||
v = re.sub(rf"</{tag}>", "", v, flags=re.IGNORECASE)
|
||||
|
||||
return v.strip()
|
||||
|
||||
@field_validator("website")
|
||||
@classmethod
|
||||
def validate_website(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证网站地址"""
|
||||
if v is None or v.strip() == "":
|
||||
return None
|
||||
|
||||
v = v.strip()
|
||||
|
||||
# 自动补全 https:// 前缀
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
v = f"https://{v}"
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("billing_type")
|
||||
@classmethod
|
||||
def validate_billing_type(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证计费类型"""
|
||||
if v is None:
|
||||
return ProviderBillingType.PAY_AS_YOU_GO.value
|
||||
|
||||
try:
|
||||
ProviderBillingType(v)
|
||||
return v
|
||||
except ValueError:
|
||||
valid_types = [t.value for t in ProviderBillingType]
|
||||
raise ValueError(f"无效的计费类型,有效值为: {', '.join(valid_types)}")
|
||||
|
||||
|
||||
class UpdateProviderRequest(BaseModel):
|
||||
"""更新 Provider 请求"""
|
||||
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
website: Optional[str] = Field(None, max_length=500)
|
||||
billing_type: Optional[str] = None
|
||||
monthly_quota_usd: Optional[float] = Field(None, ge=0)
|
||||
quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
|
||||
quota_last_reset_at: Optional[datetime] = None
|
||||
quota_expires_at: Optional[datetime] = None
|
||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
||||
provider_priority: Optional[int] = Field(None, ge=0, le=1000)
|
||||
is_active: Optional[bool] = None
|
||||
rate_limit: Optional[int] = Field(None, ge=0)
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 复用相同的验证器
|
||||
_sanitize_text = field_validator("display_name", "description")(
|
||||
CreateProviderRequest.sanitize_text.__func__
|
||||
)
|
||||
_validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)
|
||||
_validate_billing_type = field_validator("billing_type")(
|
||||
CreateProviderRequest.validate_billing_type.__func__
|
||||
)
|
||||
|
||||
|
||||
class CreateEndpointRequest(BaseModel):
|
||||
"""创建 Endpoint 请求"""
|
||||
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Endpoint 名称")
|
||||
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
|
||||
api_format: str = Field(..., description="API 格式(CLAUDE 或 OPENAI)")
|
||||
custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
|
||||
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""验证名称"""
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
|
||||
return v
|
||||
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v: str) -> str:
|
||||
"""验证 API URL"""
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def validate_api_format(cls, v: str) -> str:
|
||||
"""验证 API 格式"""
|
||||
try:
|
||||
APIFormat(v)
|
||||
return v
|
||||
except ValueError:
|
||||
valid_formats = [f.value for f in APIFormat]
|
||||
raise ValueError(f"无效的 API 格式,有效值为: {', '.join(valid_formats)}")
|
||||
|
||||
@field_validator("custom_path")
|
||||
@classmethod
|
||||
def validate_custom_path(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证自定义路径"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# 确保路径不包含危险字符
|
||||
if not re.match(r"^[/a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("路径只能包含字母、数字、斜杠、下划线和连字符")
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class UpdateEndpointRequest(BaseModel):
|
||||
"""更新 Endpoint 请求"""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
base_url: Optional[str] = Field(None, min_length=1, max_length=500)
|
||||
api_format: Optional[str] = None
|
||||
custom_path: Optional[str] = Field(None, max_length=200)
|
||||
priority: Optional[int] = Field(None, ge=0, le=1000)
|
||||
is_active: Optional[bool] = None
|
||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 复用验证器
|
||||
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
|
||||
_validate_base_url = field_validator("base_url")(
|
||||
CreateEndpointRequest.validate_base_url.__func__
|
||||
)
|
||||
_validate_api_format = field_validator("api_format")(
|
||||
CreateEndpointRequest.validate_api_format.__func__
|
||||
)
|
||||
_validate_custom_path = field_validator("custom_path")(
|
||||
CreateEndpointRequest.validate_custom_path.__func__
|
||||
)
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
"""创建 API Key 请求"""
|
||||
|
||||
endpoint_id: str = Field(..., description="Endpoint ID")
|
||||
api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
|
||||
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||
max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM")
|
||||
max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
|
||||
notes: Optional[str] = Field(None, max_length=500, description="备注")
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: str) -> str:
|
||||
"""验证 API Key"""
|
||||
# 移除首尾空白
|
||||
v = v.strip()
|
||||
|
||||
# 检查最小长度
|
||||
if len(v) < 10:
|
||||
raise ValueError("API Key 长度不能少于 10 个字符")
|
||||
|
||||
# 检查危险字符(不应包含 SQL 注入字符)
|
||||
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
|
||||
for char in dangerous_chars:
|
||||
if char in v:
|
||||
raise ValueError(f"API Key 包含非法字符: {char}")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("notes")
|
||||
@classmethod
|
||||
def sanitize_notes(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""清理备注"""
|
||||
if v is None:
|
||||
return v
|
||||
# 复用文本清理逻辑
|
||||
return CreateProviderRequest.sanitize_text(v)
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
"""更新用户请求"""
|
||||
|
||||
username: Optional[str] = Field(None, min_length=1, max_length=50)
|
||||
email: Optional[str] = Field(None, max_length=100)
|
||||
quota_usd: Optional[float] = Field(None, ge=0)
|
||||
is_active: Optional[bool] = None
|
||||
role: Optional[str] = None
|
||||
allowed_providers: Optional[List[str]] = Field(None, description="允许使用的提供商 ID 列表")
|
||||
allowed_endpoints: Optional[List[str]] = Field(None, description="允许使用的端点 ID 列表")
|
||||
allowed_models: Optional[List[str]] = Field(None, description="允许使用的模型名称列表")
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def validate_username(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证用户名"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("用户名只能包含字母、数字、下划线和连字符")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证邮箱"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# 简单的邮箱格式验证
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, v):
|
||||
raise ValueError("邮箱格式不正确")
|
||||
|
||||
return v.lower()
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证角色"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
valid_roles = ["admin", "user"]
|
||||
if v not in valid_roles:
|
||||
raise ValueError(f"无效的角色,有效值为: {', '.join(valid_roles)}")
|
||||
|
||||
return v
|
||||
716
src/models/api.py
Normal file
716
src/models/api.py
Normal file
@@ -0,0 +1,716 @@
|
||||
"""
|
||||
API端点请求/响应模型定义
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from ..core.enums import UserRole
|
||||
|
||||
|
||||
# ========== 认证相关 ==========
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求"""
|
||||
|
||||
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
|
||||
password: str = Field(..., min_length=1, max_length=128, description="密码")
|
||||
|
||||
@classmethod
|
||||
@field_validator("email")
|
||||
def validate_email(cls, v):
|
||||
"""验证邮箱格式"""
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, v):
|
||||
raise ValueError("邮箱格式无效")
|
||||
return v.lower()
|
||||
|
||||
@classmethod
|
||||
@field_validator("password")
|
||||
def validate_password(cls, v):
|
||||
"""验证密码不为空且去除前后空格"""
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("密码不能为空")
|
||||
return v
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""登录响应"""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str # 刷新令牌
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = 86400 # Token有效期(秒),默认24小时
|
||||
user_id: str
|
||||
email: str
|
||||
username: str
|
||||
role: str
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""刷新令牌请求"""
|
||||
|
||||
refresh_token: str = Field(..., description="刷新令牌")
|
||||
|
||||
|
||||
class RefreshTokenResponse(BaseModel):
|
||||
"""刷新令牌响应"""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str # 返回新的刷新令牌
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = 86400 # Token有效期(秒),默认24小时
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""注册请求"""
|
||||
|
||||
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
|
||||
username: str = Field(..., min_length=2, max_length=50, description="用户名")
|
||||
password: str = Field(..., min_length=6, max_length=128, description="密码")
|
||||
|
||||
@classmethod
|
||||
@field_validator("email")
|
||||
def validate_email(cls, v):
|
||||
"""验证邮箱格式"""
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, v):
|
||||
raise ValueError("邮箱格式无效")
|
||||
return v.lower()
|
||||
|
||||
@classmethod
|
||||
@field_validator("username")
|
||||
def validate_username(cls, v):
|
||||
"""验证用户名格式"""
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("用户名不能为空")
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("用户名只能包含字母、数字、下划线和短横线")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
@field_validator("password")
|
||||
def validate_password(cls, v):
|
||||
"""验证密码强度"""
|
||||
if len(v) < 6:
|
||||
raise ValueError("密码至少需要6个字符")
|
||||
if not re.search(r"[A-Z]", v):
|
||||
raise ValueError("密码必须包含至少一个大写字母")
|
||||
if not re.search(r"[a-z]", v):
|
||||
raise ValueError("密码必须包含至少一个小写字母")
|
||||
if not re.search(r"\d", v):
|
||||
raise ValueError("密码必须包含至少一个数字")
|
||||
return v
|
||||
|
||||
|
||||
class RegisterResponse(BaseModel):
|
||||
"""注册响应"""
|
||||
|
||||
user_id: str
|
||||
email: str
|
||||
username: str
|
||||
message: str
|
||||
|
||||
|
||||
class LogoutResponse(BaseModel):
|
||||
"""登出响应"""
|
||||
|
||||
message: str
|
||||
success: bool
|
||||
|
||||
|
||||
# ========== 用户管理 ==========
|
||||
class CreateUserRequest(BaseModel):
|
||||
"""创建用户请求"""
|
||||
|
||||
username: str = Field(..., min_length=2, max_length=50, description="用户名")
|
||||
password: str = Field(..., min_length=6, max_length=128, description="密码")
|
||||
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
|
||||
role: Optional[UserRole] = Field(UserRole.USER, description="用户角色")
|
||||
quota_usd: Optional[float] = Field(default=10.0, description="USD配额,null表示无限制")
|
||||
|
||||
@field_validator("quota_usd", mode="before")
|
||||
@classmethod
|
||||
def validate_quota_usd(cls, v):
|
||||
"""验证配额值,允许null表示无限制"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (int, float)) and v >= 0 and v <= 10000:
|
||||
return float(v)
|
||||
if isinstance(v, (int, float)):
|
||||
raise ValueError("配额必须在 0-10000 范围内")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
@field_validator("email")
|
||||
def validate_email(cls, v):
|
||||
"""验证邮箱格式"""
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("邮箱不能为空")
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, v):
|
||||
raise ValueError("邮箱格式无效")
|
||||
return v.lower()
|
||||
|
||||
@classmethod
|
||||
@field_validator("username")
|
||||
def validate_username(cls, v):
|
||||
"""验证用户名格式"""
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("用户名不能为空")
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("用户名只能包含字母、数字、下划线和短横线")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
@field_validator("password")
|
||||
def validate_password(cls, v):
|
||||
"""验证密码强度"""
|
||||
if len(v) < 6:
|
||||
raise ValueError("密码至少需要6个字符")
|
||||
if not re.search(r"[A-Z]", v):
|
||||
raise ValueError("密码必须包含至少一个大写字母")
|
||||
if not re.search(r"[a-z]", v):
|
||||
raise ValueError("密码必须包含至少一个小写字母")
|
||||
if not re.search(r"\d", v):
|
||||
raise ValueError("密码必须包含至少一个数字")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
"""更新用户请求"""
|
||||
|
||||
email: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
role: Optional[UserRole] = None
|
||||
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
|
||||
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
||||
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
||||
quota_usd: Optional[float] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
@field_validator("quota_usd", mode="before")
|
||||
@classmethod
|
||||
def validate_quota_usd(cls, v):
|
||||
"""验证配额值,允许null表示无限制"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, (int, float)) and v >= 0 and v <= 10000:
|
||||
return float(v)
|
||||
if isinstance(v, (int, float)):
|
||||
raise ValueError("配额必须在 0-10000 范围内")
|
||||
return v
|
||||
|
||||
|
||||
class CreateApiKeyRequest(BaseModel):
|
||||
"""创建API密钥请求"""
|
||||
|
||||
name: Optional[str] = None
|
||||
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
|
||||
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
||||
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
|
||||
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
||||
rate_limit: Optional[int] = 100
|
||||
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期
|
||||
initial_balance_usd: Optional[float] = Field(
|
||||
None, description="初始余额(USD),仅用于独立Key,None = 无限制"
|
||||
)
|
||||
is_standalone: bool = Field(False, description="是否为独立余额Key(给非注册用户使用)")
|
||||
auto_delete_on_expiry: bool = Field(
|
||||
False, description="过期后是否自动删除(True=物理删除,False=仅禁用)"
|
||||
)
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""用户响应"""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
username: str
|
||||
role: UserRole
|
||||
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
|
||||
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
||||
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
||||
quota_usd: float
|
||||
used_usd: float
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_login_at: Optional[datetime]
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
|
||||
"""API密钥响应"""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
key: Optional[str] = None # 仅在创建时返回完整密钥
|
||||
key_display: Optional[str] = None # 脱敏后的密钥显示
|
||||
name: Optional[str]
|
||||
total_requests: int
|
||||
total_tokens: int
|
||||
total_cost_usd: float
|
||||
allowed_providers: Optional[List[str]]
|
||||
allowed_models: Optional[List[str]]
|
||||
rate_limit: int
|
||||
is_active: bool
|
||||
expires_at: Optional[datetime] = None
|
||||
balance_used_usd: float = 0.0
|
||||
current_balance_usd: Optional[float] = None # NULL = 无限制
|
||||
is_standalone: bool = False
|
||||
force_capabilities: Optional[Dict[str, bool]] = None # 强制开启的能力
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
|
||||
|
||||
# ========== 提供商管理 ==========
|
||||
class ProviderCreate(BaseModel):
|
||||
"""创建提供商请求
|
||||
|
||||
新架构说明:
|
||||
- Provider 仅包含提供商的元数据和计费配置
|
||||
- API格式、URL、认证等配置应在 ProviderEndpoint 中设置
|
||||
- API密钥应在 ProviderAPIKey 中设置
|
||||
"""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="提供商唯一标识")
|
||||
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
||||
description: Optional[str] = Field(None, description="提供商描述")
|
||||
website: Optional[str] = Field(None, max_length=500, description="主站网站")
|
||||
|
||||
# Provider 级别的配置
|
||||
rate_limit: Optional[int] = Field(None, description="每分钟请求限制")
|
||||
concurrent_limit: Optional[int] = Field(None, description="并发请求限制")
|
||||
config: Optional[dict] = Field(None, description="额外配置")
|
||||
is_active: bool = Field(False, description="是否启用(默认false,需要配置API密钥后才能启用)")
|
||||
|
||||
|
||||
class ProviderUpdate(BaseModel):
|
||||
"""更新提供商请求"""
|
||||
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = Field(None, max_length=500)
|
||||
api_format: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
headers: Optional[dict] = None
|
||||
timeout: Optional[int] = Field(None, ge=1, le=600)
|
||||
max_retries: Optional[int] = Field(None, ge=0, le=10)
|
||||
priority: Optional[int] = None
|
||||
weight: Optional[float] = Field(None, gt=0)
|
||||
rate_limit: Optional[int] = None
|
||||
concurrent_limit: Optional[int] = None
|
||||
config: Optional[dict] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
|
||||
"""提供商响应"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
website: Optional[str]
|
||||
api_format: str
|
||||
base_url: str
|
||||
headers: Optional[dict]
|
||||
timeout: int
|
||||
max_retries: int
|
||||
priority: int
|
||||
weight: float
|
||||
rate_limit: Optional[int]
|
||||
concurrent_limit: Optional[int]
|
||||
config: Optional[dict]
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
models_count: int = 0
|
||||
active_models_count: int = 0
|
||||
model_mappings_count: int = 0
|
||||
api_keys_count: int = 0
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== 模型管理 ==========
|
||||
class ModelCreate(BaseModel):
|
||||
"""创建模型请求 - 价格和能力字段可选,为空时使用 GlobalModel 默认值"""
|
||||
|
||||
provider_model_name: str = Field(
|
||||
..., min_length=1, max_length=200, description="Provider 侧的模型名称"
|
||||
)
|
||||
global_model_id: str = Field(..., description="关联的 GlobalModel ID(必填)")
|
||||
# 按次计费配置 - 可选,为空时使用 GlobalModel 默认值
|
||||
price_per_request: Optional[float] = Field(
|
||||
None, ge=0, description="每次请求固定费用,为空使用默认值"
|
||||
)
|
||||
# 阶梯计费配置 - 可选,为空时使用 GlobalModel 默认值
|
||||
tiered_pricing: Optional[dict] = Field(
|
||||
None, description="阶梯计费配置,为空使用 GlobalModel 默认值"
|
||||
)
|
||||
# 能力配置 - 可选,为空时使用 GlobalModel 默认值
|
||||
supports_vision: Optional[bool] = Field(None, description="是否支持图像输入,为空使用默认值")
|
||||
supports_function_calling: Optional[bool] = Field(
|
||||
None, description="是否支持函数调用,为空使用默认值"
|
||||
)
|
||||
supports_streaming: Optional[bool] = Field(None, description="是否支持流式输出,为空使用默认值")
|
||||
supports_extended_thinking: Optional[bool] = Field(
|
||||
None, description="是否支持扩展思考,为空使用默认值"
|
||||
)
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
config: Optional[dict] = Field(None, description="额外配置")
|
||||
|
||||
|
||||
class ModelUpdate(BaseModel):
|
||||
"""更新模型请求"""
|
||||
|
||||
provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
global_model_id: Optional[str] = None
|
||||
# 按次计费配置
|
||||
price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
# 阶梯计费配置
|
||||
tiered_pricing: Optional[dict] = Field(None, description="阶梯计费配置")
|
||||
supports_vision: Optional[bool] = None
|
||||
supports_function_calling: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
supports_extended_thinking: Optional[bool] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_available: Optional[bool] = None
|
||||
config: Optional[dict] = None
|
||||
|
||||
|
||||
class ModelResponse(BaseModel):
|
||||
"""模型响应 - 包含 Model 配置和关联的 GlobalModel 信息
|
||||
|
||||
注意:价格和能力字段返回的是有效值(优先使用 Model 配置,否则使用 GlobalModel 默认值)
|
||||
"""
|
||||
|
||||
id: str
|
||||
provider_id: str
|
||||
global_model_id: Optional[str]
|
||||
provider_model_name: str
|
||||
|
||||
# 按次计费配置
|
||||
price_per_request: Optional[float] = None
|
||||
# 阶梯计费配置
|
||||
tiered_pricing: Optional[dict] = None
|
||||
|
||||
# Provider 能力配置 - 可选,为空表示使用 GlobalModel 默认值
|
||||
supports_vision: Optional[bool]
|
||||
supports_function_calling: Optional[bool]
|
||||
supports_streaming: Optional[bool]
|
||||
supports_extended_thinking: Optional[bool]
|
||||
supports_image_generation: Optional[bool]
|
||||
|
||||
# 有效值(合并 Model 配置和 GlobalModel 默认值后的结果)
|
||||
effective_tiered_pricing: Optional[dict] = None
|
||||
effective_input_price: Optional[float] = None
|
||||
effective_output_price: Optional[float] = None
|
||||
effective_price_per_request: Optional[float] = None
|
||||
effective_supports_vision: Optional[bool] = None
|
||||
effective_supports_function_calling: Optional[bool] = None
|
||||
effective_supports_streaming: Optional[bool] = None
|
||||
effective_supports_extended_thinking: Optional[bool] = None
|
||||
effective_supports_image_generation: Optional[bool] = None
|
||||
|
||||
# 状态
|
||||
is_active: bool
|
||||
is_available: bool
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# 关联的 GlobalModel 信息(如果有)
|
||||
global_model_name: Optional[str] = None
|
||||
global_model_display_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ModelDetailResponse(BaseModel):
|
||||
"""模型详细响应 - 包含所有字段(用于需要完整信息的场景)"""
|
||||
|
||||
id: str
|
||||
provider_id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
icon_url: Optional[str]
|
||||
tags: Optional[List[str]]
|
||||
input_price_per_1m: float
|
||||
output_price_per_1m: float
|
||||
cache_creation_price_per_1m: Optional[float]
|
||||
cache_read_price_per_1m: Optional[float]
|
||||
supports_vision: bool
|
||||
supports_function_calling: bool
|
||||
supports_streaming: bool
|
||||
is_active: bool
|
||||
is_available: bool
|
||||
config: Optional[dict]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== 模型映射 ==========
|
||||
class ModelMappingCreate(BaseModel):
|
||||
"""创建模型映射请求(源模型到目标模型的映射)"""
|
||||
|
||||
source_model: str = Field(..., min_length=1, max_length=200, description="源模型名或别名")
|
||||
target_global_model_id: str = Field(..., description="目标 GlobalModel ID")
|
||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时表示全局别名)")
|
||||
mapping_type: str = Field(
|
||||
"alias",
|
||||
description="映射类型:alias=按目标模型计费(别名),mapping=按源模型计费(降级映射)",
|
||||
)
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class ModelMappingUpdate(BaseModel):
|
||||
"""更新模型映射请求"""
|
||||
|
||||
source_model: Optional[str] = Field(
|
||||
None, min_length=1, max_length=200, description="源模型名或别名"
|
||||
)
|
||||
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
|
||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时表示全局别名)")
|
||||
mapping_type: Optional[str] = Field(
|
||||
None, description="映射类型:alias=按目标模型计费(别名),mapping=按源模型计费(降级映射)"
|
||||
)
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class ModelMappingResponse(BaseModel):
|
||||
"""模型映射响应"""
|
||||
|
||||
id: str
|
||||
source_model: str
|
||||
target_global_model_id: str
|
||||
target_global_model_name: Optional[str]
|
||||
target_global_model_display_name: Optional[str]
|
||||
provider_id: Optional[str]
|
||||
provider_name: Optional[str]
|
||||
scope: str = Field(..., description="global 或 provider")
|
||||
mapping_type: str = Field(..., description="映射类型:alias 或 mapping")
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== 系统设置 ==========
|
||||
class SystemSettingsRequest(BaseModel):
|
||||
"""系统设置请求"""
|
||||
|
||||
default_provider: Optional[str] = None
|
||||
default_model: Optional[str] = None
|
||||
enable_usage_tracking: Optional[bool] = None
|
||||
|
||||
|
||||
class SystemSettingsResponse(BaseModel):
|
||||
"""系统设置响应"""
|
||||
|
||||
default_provider: Optional[str]
|
||||
default_model: Optional[str]
|
||||
enable_usage_tracking: bool
|
||||
|
||||
|
||||
# ========== 使用统计 ==========
|
||||
class UsageStatsResponse(BaseModel):
|
||||
"""使用统计响应"""
|
||||
|
||||
total_requests: int
|
||||
total_tokens: int
|
||||
total_cost_usd: float
|
||||
daily_requests: int
|
||||
daily_tokens: int
|
||||
daily_cost_usd: float
|
||||
model_usage: Dict[str, Dict[str, Any]]
|
||||
provider_usage: Dict[str, Dict[str, Any]]
|
||||
|
||||
|
||||
# ========== 公开API响应模型 ==========
|
||||
class PublicProviderResponse(BaseModel):
|
||||
"""公开的提供商信息响应"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
website: Optional[str]
|
||||
is_active: bool
|
||||
provider_priority: int # 提供商优先级(数字越小越优先)
|
||||
# 统计信息
|
||||
models_count: int
|
||||
active_models_count: int
|
||||
mappings_count: int
|
||||
endpoints_count: int # 端点总数
|
||||
active_endpoints_count: int # 活跃端点数
|
||||
|
||||
|
||||
class PublicModelResponse(BaseModel):
|
||||
"""公开的模型信息响应"""
|
||||
|
||||
id: str
|
||||
provider_id: str
|
||||
provider_name: str
|
||||
provider_display_name: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
icon_url: Optional[str] = None
|
||||
# 价格信息
|
||||
input_price_per_1m: Optional[float] = None
|
||||
output_price_per_1m: Optional[float] = None
|
||||
cache_creation_price_per_1m: Optional[float] = None
|
||||
cache_read_price_per_1m: Optional[float] = None
|
||||
# 功能支持
|
||||
supports_vision: Optional[bool] = None
|
||||
supports_function_calling: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class PublicModelMappingResponse(BaseModel):
|
||||
"""公开的模型映射信息响应"""
|
||||
|
||||
id: str
|
||||
source_model: str
|
||||
target_global_model_id: str
|
||||
target_global_model_name: Optional[str]
|
||||
target_global_model_display_name: Optional[str]
|
||||
provider_id: Optional[str] = None
|
||||
scope: str = Field(..., description="global 或 provider")
|
||||
is_active: bool
|
||||
|
||||
|
||||
class ProviderStatsResponse(BaseModel):
|
||||
"""提供商统计信息响应"""
|
||||
|
||||
total_providers: int
|
||||
active_providers: int
|
||||
total_models: int
|
||||
active_models: int
|
||||
total_mappings: int
|
||||
supported_formats: List[str]
|
||||
|
||||
|
||||
class PublicGlobalModelResponse(BaseModel):
|
||||
"""公开的 GlobalModel 信息响应(用户可见)"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
icon_url: Optional[str] = None
|
||||
is_active: bool = True
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = None
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing: Optional[dict] = None
|
||||
# 默认能力
|
||||
default_supports_vision: bool = False
|
||||
default_supports_function_calling: bool = False
|
||||
default_supports_streaming: bool = True
|
||||
default_supports_extended_thinking: bool = False
|
||||
# Key 能力配置
|
||||
supported_capabilities: Optional[List[str]] = None
|
||||
|
||||
|
||||
class PublicGlobalModelListResponse(BaseModel):
|
||||
"""公开的 GlobalModel 列表响应"""
|
||||
|
||||
models: List[PublicGlobalModelResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# ========== 个人中心相关模型 ==========
|
||||
class UpdateProfileRequest(BaseModel):
|
||||
"""更新个人信息请求"""
|
||||
|
||||
email: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class UpdatePreferencesRequest(BaseModel):
|
||||
"""更新偏好设置请求"""
|
||||
|
||||
avatar_url: Optional[str] = None
|
||||
bio: Optional[str] = None
|
||||
default_provider_id: Optional[int] = None
|
||||
theme: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
email_notifications: Optional[bool] = None
|
||||
usage_alerts: Optional[bool] = None
|
||||
announcement_notifications: Optional[bool] = None
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
"""修改密码请求"""
|
||||
|
||||
old_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class CreateMyApiKeyRequest(BaseModel):
|
||||
"""创建我的API密钥请求"""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
|
||||
"""提供商配置"""
|
||||
|
||||
provider_id: str = Field(..., description="提供商ID")
|
||||
priority: int = Field(100, description="优先级(越高越优先)")
|
||||
weight: float = Field(1.0, description="负载均衡权重")
|
||||
enabled: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class UpdateApiKeyProvidersRequest(BaseModel):
|
||||
"""更新API密钥可用提供商请求"""
|
||||
|
||||
allowed_providers: Optional[List[ProviderConfig]] = None # 提供商配置列表
|
||||
|
||||
|
||||
# ========== 公告相关模型 ==========
|
||||
class CreateAnnouncementRequest(BaseModel):
|
||||
"""创建公告请求"""
|
||||
|
||||
title: str
|
||||
content: str # 支持Markdown
|
||||
type: str = "info" # info, warning, maintenance, important
|
||||
priority: int = 0
|
||||
is_pinned: bool = False
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
|
||||
|
||||
class UpdateAnnouncementRequest(BaseModel):
|
||||
"""更新公告请求"""
|
||||
|
||||
title: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
priority: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_pinned: Optional[bool] = None
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
72
src/models/api_key.py
Normal file
72
src/models/api_key.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Provider API Key相关的API模型
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProviderAPIKeyBase(BaseModel):
|
||||
"""Provider API Key基础模型"""
|
||||
|
||||
name: Optional[str] = Field(None, description="密钥名称/备注")
|
||||
api_key: str = Field(..., description="API密钥")
|
||||
rate_limit: Optional[int] = Field(None, description="速率限制(每分钟请求数)")
|
||||
daily_limit: Optional[int] = Field(None, description="每日请求限制")
|
||||
monthly_limit: Optional[int] = Field(None, description="每月请求限制")
|
||||
priority: int = Field(0, description="优先级(越高越优先使用)")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
expires_at: Optional[datetime] = Field(None, description="过期时间")
|
||||
|
||||
|
||||
class ProviderAPIKeyCreate(ProviderAPIKeyBase):
|
||||
"""创建Provider API Key请求"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ProviderAPIKeyUpdate(BaseModel):
|
||||
"""更新Provider API Key请求"""
|
||||
|
||||
name: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
rate_limit: Optional[int] = None
|
||||
daily_limit: Optional[int] = None
|
||||
monthly_limit: Optional[int] = None
|
||||
priority: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class ProviderAPIKeyResponse(ProviderAPIKeyBase):
|
||||
"""Provider API Key响应"""
|
||||
|
||||
id: str
|
||||
provider_id: str
|
||||
request_count: Optional[int] = Field(0, description="请求次数")
|
||||
error_count: Optional[int] = Field(0, description="错误次数")
|
||||
last_used_at: Optional[datetime] = Field(None, description="最后使用时间")
|
||||
last_error_at: Optional[datetime] = Field(None, description="最后错误时间")
|
||||
last_error_msg: Optional[str] = Field(None, description="最后错误信息")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProviderAPIKeyStats(BaseModel):
|
||||
"""Provider API Key统计信息"""
|
||||
|
||||
id: str
|
||||
name: Optional[str]
|
||||
request_count: int
|
||||
error_count: int
|
||||
success_rate: float
|
||||
last_used_at: Optional[datetime]
|
||||
is_active: bool
|
||||
is_expired: bool
|
||||
remaining_daily: Optional[int] = Field(None, description="今日剩余请求数")
|
||||
remaining_monthly: Optional[int] = Field(None, description="本月剩余请求数")
|
||||
118
src/models/claude.py
Normal file
118
src/models/claude.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
# 配置允许额外字段,以支持API的新特性
|
||||
class BaseModelWithExtras(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ClaudeContentBlockText(BaseModelWithExtras):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeContentBlockImage(BaseModelWithExtras):
|
||||
type: Literal["image"]
|
||||
source: Dict[str, Any]
|
||||
|
||||
|
||||
class ClaudeContentBlockToolUse(BaseModelWithExtras):
|
||||
type: Literal["tool_use"]
|
||||
id: str
|
||||
name: str
|
||||
input: Dict[str, Any]
|
||||
|
||||
|
||||
class ClaudeContentBlockToolResult(BaseModelWithExtras):
|
||||
type: Literal["tool_result"]
|
||||
tool_use_id: str
|
||||
content: Union[str, List[Dict[str, Any]], Dict[str, Any]]
|
||||
|
||||
|
||||
class ClaudeContentBlockThinking(BaseModelWithExtras):
|
||||
type: Literal["thinking"]
|
||||
thinking: str
|
||||
|
||||
|
||||
class ClaudeSystemContent(BaseModelWithExtras):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeMessage(BaseModelWithExtras):
|
||||
role: Literal["user", "assistant"]
|
||||
# 宽松的内容类型定义 - 接受字符串或任意字典列表
|
||||
# 作为转发代理,不应该严格限制内容块类型,以支持API的新特性
|
||||
content: Union[str, List[Dict[str, Any]]]
|
||||
|
||||
|
||||
class ClaudeTool(BaseModelWithExtras):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
input_schema: Dict[str, Any]
|
||||
|
||||
|
||||
class ClaudeThinkingConfig(BaseModelWithExtras):
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class ClaudeMessagesRequest(BaseModelWithExtras):
|
||||
model: str
|
||||
max_tokens: int
|
||||
messages: List[ClaudeMessage]
|
||||
# 宽松的system类型 - 接受字符串、字典列表或任意字典
|
||||
system: Optional[Union[str, List[Dict[str, Any]], Dict[str, Any]]] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
stream: Optional[bool] = False
|
||||
temperature: Optional[float] = 1.0
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tools: Optional[List[Dict[str, Any]]] = None # 改为更宽松的类型
|
||||
tool_choice: Optional[Dict[str, Any]] = None
|
||||
thinking: Optional[Dict[str, Any]] = None # 改为更宽松的类型
|
||||
|
||||
|
||||
class ClaudeTokenCountRequest(BaseModelWithExtras):
|
||||
model: str
|
||||
messages: List[ClaudeMessage]
|
||||
# 宽松的类型定义以支持API新特性
|
||||
system: Optional[Union[str, List[Dict[str, Any]], Dict[str, Any]]] = None
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
thinking: Optional[Dict[str, Any]] = None
|
||||
tool_choice: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 响应模型
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ClaudeResponseUsage(BaseModelWithExtras):
|
||||
"""Claude 响应 token 使用量"""
|
||||
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_creation_input_tokens: Optional[int] = None
|
||||
cache_read_input_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class ClaudeResponse(BaseModelWithExtras):
|
||||
"""
|
||||
Claude Messages API 响应模型
|
||||
|
||||
对应 POST /v1/messages 端点的响应体。
|
||||
"""
|
||||
|
||||
id: str
|
||||
model: str
|
||||
type: Literal["message"] = "message"
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: List[Dict[str, Any]]
|
||||
stop_reason: Optional[str] = None
|
||||
stop_sequence: Optional[str] = None
|
||||
usage: Optional[ClaudeResponseUsage] = None
|
||||
context_management: Optional[Dict[str, Any]] = None
|
||||
container: Optional[Dict[str, Any]] = None
|
||||
1341
src/models/database.py
Normal file
1341
src/models/database.py
Normal file
File diff suppressed because it is too large
Load Diff
119
src/models/database_extensions.py
Normal file
119
src/models/database_extensions.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
数据库模型扩展 - 新增的提供商策略相关表
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .database import Base
|
||||
|
||||
|
||||
class ApiKeyProviderMapping(Base):
|
||||
"""
|
||||
API Key 和 Provider 的关联映射表
|
||||
|
||||
用途:管理员为特定的 API Key 指定提供商
|
||||
- 如果存在映射:该 API Key 只能使用指定的提供商(无负载均衡和故障转移)
|
||||
- 如果不存在映射:该 API Key 使用所有可用提供商(系统默认优先级,有负载均衡和故障转移)
|
||||
|
||||
注意:priority_adjustment 和 weight_multiplier 字段保留但在当前版本不使用
|
||||
"""
|
||||
|
||||
__tablename__ = "api_key_provider_mappings"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||
api_key_id = Column(
|
||||
String(36), ForeignKey("api_keys.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
provider_id = Column(
|
||||
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# 管理员设置的优先级调整(非用户自己设置)
|
||||
priority_adjustment = Column(Integer, default=0) # 优先级调整值(可正可负)
|
||||
weight_multiplier = Column(Float, default=1.0) # 权重乘数(>0)
|
||||
|
||||
# 是否启用
|
||||
is_enabled = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# 关系
|
||||
api_key = relationship("ApiKey", back_populates="provider_mappings")
|
||||
provider = relationship("Provider", back_populates="api_key_mappings")
|
||||
|
||||
# 唯一约束
|
||||
__table_args__ = (
|
||||
UniqueConstraint("api_key_id", "provider_id", name="uq_apikey_provider"),
|
||||
Index("idx_apikey_provider_enabled", "api_key_id", "is_enabled"),
|
||||
)
|
||||
|
||||
|
||||
class ProviderUsageTracking(Base):
|
||||
"""提供商使用追踪 (用于RPM限流和健康检测)"""
|
||||
|
||||
__tablename__ = "provider_usage_tracking"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||
provider_id = Column(
|
||||
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# 时间窗口
|
||||
window_start = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
window_end = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# 统计数据
|
||||
total_requests = Column(Integer, default=0)
|
||||
successful_requests = Column(Integer, default=0)
|
||||
failed_requests = Column(Integer, default=0)
|
||||
|
||||
# 性能数据
|
||||
avg_response_time_ms = Column(Float, default=0.0)
|
||||
total_response_time_ms = Column(Float, default=0.0) # 用于计算平均值
|
||||
|
||||
# 成本数据
|
||||
total_cost_usd = Column(Float, default=0.0)
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# 关系
|
||||
provider = relationship("Provider", back_populates="usage_tracking")
|
||||
|
||||
# 索引
|
||||
__table_args__ = (
|
||||
Index("idx_provider_window", "provider_id", "window_start"),
|
||||
Index("idx_window_time", "window_start", "window_end"),
|
||||
)
|
||||
653
src/models/endpoint_models.py
Normal file
653
src/models/endpoint_models.py
Normal file
@@ -0,0 +1,653 @@
|
||||
"""
|
||||
ProviderEndpoint 相关的 API 模型定义
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
# ========== ProviderEndpoint CRUD ==========
|
||||
|
||||
|
||||
class ProviderEndpointCreate(BaseModel):
|
||||
"""创建 Endpoint 请求"""
|
||||
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
api_format: str = Field(..., description="API 格式 (CLAUDE, OPENAI, CLAUDE_CLI, OPENAI_CLI)")
|
||||
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
|
||||
|
||||
# 请求配置
|
||||
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
||||
timeout: int = Field(default=300, ge=10, le=600, description="超时时间(秒)")
|
||||
max_retries: int = Field(default=3, ge=0, le=10, description="最大重试次数")
|
||||
|
||||
# 限制
|
||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
|
||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制(请求/秒)")
|
||||
|
||||
# 额外配置
|
||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)")
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def validate_api_format(cls, v: str) -> str:
|
||||
"""验证 API 格式"""
|
||||
from src.core.enums import APIFormat
|
||||
|
||||
allowed = [fmt.value for fmt in APIFormat]
|
||||
v_upper = v.upper()
|
||||
if v_upper not in allowed:
|
||||
raise ValueError(f"API 格式必须是 {allowed} 之一")
|
||||
return v_upper
|
||||
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v: str) -> str:
|
||||
"""验证 API URL(SSRF 防护)"""
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
|
||||
class ProviderEndpointUpdate(BaseModel):
|
||||
"""更新 Endpoint 请求"""
|
||||
|
||||
base_url: Optional[str] = Field(
|
||||
default=None, min_length=1, max_length=500, description="API 基础 URL"
|
||||
)
|
||||
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
||||
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
|
||||
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
|
||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
|
||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
||||
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
||||
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证 API URL(SSRF 防护)"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
|
||||
class ProviderEndpointResponse(BaseModel):
|
||||
"""Endpoint 响应"""
|
||||
|
||||
id: str
|
||||
provider_id: str
|
||||
provider_name: str # 冗余字段,方便前端显示
|
||||
|
||||
# API 配置
|
||||
api_format: str
|
||||
base_url: str
|
||||
|
||||
# 请求配置
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
timeout: int
|
||||
max_retries: int
|
||||
|
||||
# 限制
|
||||
max_concurrent: Optional[int] = None
|
||||
rate_limit: Optional[int] = None
|
||||
|
||||
# 状态
|
||||
is_active: bool
|
||||
|
||||
# 额外配置
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 统计(从 Keys 聚合)
|
||||
total_keys: int = Field(default=0, description="总 Key 数量")
|
||||
active_keys: int = Field(default=0, description="活跃 Key 数量")
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== ProviderAPIKey 相关(新架构) ==========
|
||||
|
||||
|
||||
class EndpointAPIKeyCreate(BaseModel):
|
||||
"""为 Endpoint 添加 API Key"""
|
||||
|
||||
endpoint_id: str = Field(..., description="Endpoint ID")
|
||||
api_key: str = Field(..., min_length=10, max_length=500, description="API Key(将自动加密)")
|
||||
name: str = Field(..., min_length=1, max_length=100, description="密钥名称(必填,用于识别)")
|
||||
|
||||
# 成本计算
|
||||
rate_multiplier: float = Field(
|
||||
default=1.0, ge=0.01, description="成本倍率(真实成本 = 表面成本 × 倍率)"
|
||||
)
|
||||
|
||||
# 优先级和限制(数字越小越优先)
|
||||
internal_priority: int = Field(default=50, description="Endpoint 内部优先级(提供商优先模式)")
|
||||
# max_concurrent: NULL=自适应模式(系统自动学习),数字=固定限制模式
|
||||
max_concurrent: Optional[int] = Field(
|
||||
default=None, ge=1, description="最大并发数(NULL=自适应模式)"
|
||||
)
|
||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
||||
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
|
||||
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
|
||||
allowed_models: Optional[List[str]] = Field(
|
||||
default=None, description="允许使用的模型列表(null = 支持所有模型)"
|
||||
)
|
||||
|
||||
# 能力标签
|
||||
capabilities: Optional[Dict[str, bool]] = Field(
|
||||
default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}"
|
||||
)
|
||||
|
||||
# 缓存与熔断配置
|
||||
cache_ttl_minutes: int = Field(
|
||||
default=5, ge=0, le=60, description="缓存 TTL(分钟),0=禁用,默认5分钟"
|
||||
)
|
||||
max_probe_interval_minutes: int = Field(
|
||||
default=32, ge=2, le=32, description="熔断探测间隔(分钟),范围 2-32"
|
||||
)
|
||||
|
||||
# 备注
|
||||
note: Optional[str] = Field(default=None, max_length=500, description="备注说明(可选)")
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: str) -> str:
|
||||
"""验证 API Key 安全性"""
|
||||
# 移除首尾空白
|
||||
v = v.strip()
|
||||
|
||||
# 检查最小长度
|
||||
if len(v) < 10:
|
||||
raise ValueError("API Key 长度不能少于 10 个字符")
|
||||
|
||||
# 检查危险字符(SQL 注入防护)
|
||||
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
|
||||
for char in dangerous_chars:
|
||||
if char in v:
|
||||
raise ValueError(f"API Key 包含非法字符: {char}")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""验证名称(防止 XSS)"""
|
||||
# 移除危险的 HTML 标签
|
||||
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
|
||||
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE)
|
||||
return v.strip()
|
||||
|
||||
@field_validator("note")
|
||||
@classmethod
|
||||
def validate_note(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证备注(防止 XSS)"""
|
||||
if v is None:
|
||||
return v
|
||||
# 移除危险的 HTML 标签
|
||||
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
|
||||
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE)
|
||||
return v.strip()
|
||||
|
||||
|
||||
class EndpointAPIKeyUpdate(BaseModel):
|
||||
"""更新 Endpoint API Key"""
|
||||
|
||||
api_key: Optional[str] = Field(
|
||||
default=None, min_length=10, max_length=500, description="API Key(将自动加密)"
|
||||
)
|
||||
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
|
||||
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="成本倍率")
|
||||
internal_priority: Optional[int] = Field(
|
||||
default=None, description="Endpoint 内部优先级(提供商优先模式,数字越小越优先)"
|
||||
)
|
||||
global_priority: Optional[int] = Field(
|
||||
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
|
||||
)
|
||||
# 注意:max_concurrent=None 表示不更新,要切换为自适应模式请使用专用 API
|
||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
|
||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
||||
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
|
||||
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
|
||||
allowed_models: Optional[List[str]] = Field(default=None, description="允许使用的模型列表")
|
||||
capabilities: Optional[Dict[str, bool]] = Field(
|
||||
default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}"
|
||||
)
|
||||
cache_ttl_minutes: Optional[int] = Field(
|
||||
default=None, ge=0, le=60, description="缓存 TTL(分钟),0=禁用"
|
||||
)
|
||||
max_probe_interval_minutes: Optional[int] = Field(
|
||||
default=None, ge=2, le=32, description="熔断探测间隔(分钟),范围 2-32"
|
||||
)
|
||||
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
||||
note: Optional[str] = Field(default=None, max_length=500, description="备注说明")
|
||||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证 API Key 安全性"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
v = v.strip()
|
||||
if len(v) < 10:
|
||||
raise ValueError("API Key 长度不能少于 10 个字符")
|
||||
|
||||
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
|
||||
for char in dangerous_chars:
|
||||
if char in v:
|
||||
raise ValueError(f"API Key 包含非法字符: {char}")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证名称(防止 XSS)"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
|
||||
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE)
|
||||
return v.strip()
|
||||
|
||||
@field_validator("note")
|
||||
@classmethod
|
||||
def validate_note(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证备注(防止 XSS)"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
v = re.sub(r"<script.*?</script>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"<iframe.*?</iframe>", "", v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r"javascript:", "", v, flags=re.IGNORECASE)
|
||||
v = re.sub(r"on\w+\s*=", "", v, flags=re.IGNORECASE)
|
||||
return v.strip()
|
||||
|
||||
|
||||
class EndpointAPIKeyResponse(BaseModel):
|
||||
"""Endpoint API Key 响应"""
|
||||
|
||||
id: str
|
||||
endpoint_id: str
|
||||
|
||||
# Key 信息(脱敏)
|
||||
api_key_masked: str = Field(..., description="脱敏后的 Key")
|
||||
api_key_plain: Optional[str] = Field(default=None, description="完整的 Key")
|
||||
name: str = Field(..., description="密钥名称")
|
||||
|
||||
# 成本计算
|
||||
rate_multiplier: float = Field(default=1.0, description="成本倍率")
|
||||
|
||||
# 优先级和限制
|
||||
internal_priority: int = Field(default=50, description="Endpoint 内部优先级")
|
||||
global_priority: Optional[int] = Field(default=None, description="全局 Key 优先级")
|
||||
max_concurrent: Optional[int] = None
|
||||
rate_limit: Optional[int] = None
|
||||
daily_limit: Optional[int] = None
|
||||
monthly_limit: Optional[int] = None
|
||||
allowed_models: Optional[List[str]] = None
|
||||
capabilities: Optional[Dict[str, bool]] = Field(
|
||||
default=None, description="Key 能力标签"
|
||||
)
|
||||
|
||||
# 缓存与熔断配置
|
||||
cache_ttl_minutes: int = Field(default=5, description="缓存 TTL(分钟),0=禁用")
|
||||
max_probe_interval_minutes: int = Field(default=32, description="熔断探测间隔(分钟)")
|
||||
|
||||
# 健康度
|
||||
health_score: float
|
||||
consecutive_failures: int
|
||||
last_failure_at: Optional[datetime] = None
|
||||
|
||||
# 熔断器状态(滑动窗口 + 半开模式)
|
||||
circuit_breaker_open: bool = Field(default=False, description="熔断器是否打开")
|
||||
circuit_breaker_open_at: Optional[datetime] = Field(default=None, description="熔断器打开时间")
|
||||
next_probe_at: Optional[datetime] = Field(default=None, description="下次进入半开状态时间")
|
||||
half_open_until: Optional[datetime] = Field(default=None, description="半开状态结束时间")
|
||||
half_open_successes: Optional[int] = Field(default=0, description="半开状态成功次数")
|
||||
half_open_failures: Optional[int] = Field(default=0, description="半开状态失败次数")
|
||||
request_results_window: Optional[List[dict]] = Field(None, description="请求结果滑动窗口")
|
||||
|
||||
# 使用统计
|
||||
request_count: int
|
||||
success_count: int
|
||||
error_count: int
|
||||
success_rate: float = Field(default=0.0, description="成功率")
|
||||
avg_response_time_ms: float = Field(default=0.0, description="平均响应时间(毫秒)")
|
||||
|
||||
# 状态
|
||||
is_active: bool
|
||||
|
||||
# 自适应并发信息
|
||||
is_adaptive: bool = Field(default=False, description="是否为自适应模式(max_concurrent=NULL)")
|
||||
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
|
||||
effective_limit: Optional[int] = Field(None, description="当前有效限制")
|
||||
# 滑动窗口利用率采样
|
||||
utilization_samples: Optional[List[dict]] = Field(None, description="利用率采样窗口")
|
||||
last_probe_increase_at: Optional[datetime] = Field(None, description="上次探测性扩容时间")
|
||||
concurrent_429_count: Optional[int] = None
|
||||
rpm_429_count: Optional[int] = None
|
||||
last_429_at: Optional[datetime] = None
|
||||
last_429_type: Optional[str] = None
|
||||
|
||||
# 备注
|
||||
note: Optional[str] = None
|
||||
|
||||
# 时间戳
|
||||
last_used_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== 健康监控相关 ==========
|
||||
|
||||
|
||||
class HealthStatusResponse(BaseModel):
|
||||
"""健康状态响应(仅 Key 级别)"""
|
||||
|
||||
# Key 健康状态
|
||||
key_id: str
|
||||
key_health_score: float
|
||||
key_consecutive_failures: int
|
||||
key_last_failure_at: Optional[datetime] = None
|
||||
key_is_active: bool
|
||||
key_statistics: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 熔断器状态(滑动窗口 + 半开模式)
|
||||
circuit_breaker_open: bool = False
|
||||
circuit_breaker_open_at: Optional[datetime] = None
|
||||
next_probe_at: Optional[datetime] = None
|
||||
half_open_until: Optional[datetime] = None
|
||||
half_open_successes: int = 0
|
||||
half_open_failures: int = 0
|
||||
|
||||
|
||||
class HealthSummaryResponse(BaseModel):
|
||||
"""健康状态摘要"""
|
||||
|
||||
endpoints: Dict[str, int] = Field(..., description="Endpoint 统计 (total, active, unhealthy)")
|
||||
keys: Dict[str, int] = Field(..., description="Key 统计 (total, active, unhealthy)")
|
||||
|
||||
|
||||
# ========== 并发控制相关 ==========
|
||||
|
||||
|
||||
class ConcurrencyStatusResponse(BaseModel):
|
||||
"""并发状态响应"""
|
||||
|
||||
endpoint_id: Optional[str] = None
|
||||
endpoint_current_concurrency: int = Field(default=0, description="Endpoint 当前并发数")
|
||||
endpoint_max_concurrent: Optional[int] = Field(default=None, description="Endpoint 最大并发数")
|
||||
|
||||
key_id: Optional[str] = None
|
||||
key_current_concurrency: int = Field(default=0, description="Key 当前并发数")
|
||||
key_max_concurrent: Optional[int] = Field(default=None, description="Key 最大并发数")
|
||||
|
||||
|
||||
class ResetConcurrencyRequest(BaseModel):
|
||||
"""重置并发计数请求"""
|
||||
|
||||
endpoint_id: Optional[str] = Field(default=None, description="Endpoint ID(可选)")
|
||||
key_id: Optional[str] = Field(default=None, description="Key ID(可选)")
|
||||
|
||||
|
||||
class KeyPriorityItem(BaseModel):
|
||||
"""单个 Key 优先级项"""
|
||||
|
||||
key_id: str = Field(..., description="Key ID")
|
||||
internal_priority: int = Field(..., ge=0, description="Endpoint 内部优先级(数字越小越优先)")
|
||||
|
||||
|
||||
class BatchUpdateKeyPriorityRequest(BaseModel):
|
||||
"""批量更新 Key 优先级请求"""
|
||||
|
||||
priorities: List[KeyPriorityItem] = Field(..., min_length=1, description="Key 优先级列表")
|
||||
|
||||
|
||||
# ========== 提供商摘要(增强版) ==========
|
||||
|
||||
|
||||
class ProviderUpdateRequest(BaseModel):
|
||||
"""Provider 基础配置更新请求"""
|
||||
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = Field(None, max_length=500, description="主站网站")
|
||||
priority: Optional[int] = None
|
||||
weight: Optional[float] = Field(None, gt=0)
|
||||
provider_priority: Optional[int] = Field(None, description="提供商优先级(数字越小越优先)")
|
||||
is_active: Optional[bool] = None
|
||||
billing_type: Optional[str] = Field(
|
||||
None, description="计费类型:monthly_quota/pay_as_you_go/free_tier"
|
||||
)
|
||||
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="订阅配额(美元)")
|
||||
quota_reset_day: Optional[int] = Field(None, ge=1, le=31, description="配额重置日(1-31)")
|
||||
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
|
||||
rpm_limit: Optional[int] = Field(
|
||||
None, ge=0, description="每分钟请求数限制(NULL=无限制,0=禁止请求)"
|
||||
)
|
||||
|
||||
|
||||
class ProviderWithEndpointsSummary(BaseModel):
|
||||
"""Provider 和 Endpoints 摘要"""
|
||||
|
||||
# Provider 基本信息
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
provider_priority: int = Field(default=100, description="提供商优先级(数字越小越优先)")
|
||||
is_active: bool
|
||||
|
||||
# 计费相关字段
|
||||
billing_type: Optional[str] = None
|
||||
monthly_quota_usd: Optional[float] = None
|
||||
monthly_used_usd: Optional[float] = None
|
||||
quota_reset_day: Optional[int] = Field(default=None, description="配额重置周期(天数)")
|
||||
quota_last_reset_at: Optional[datetime] = Field(default=None, description="当前周期开始时间")
|
||||
quota_expires_at: Optional[datetime] = Field(default=None, description="配额过期时间")
|
||||
|
||||
# RPM 限制
|
||||
rpm_limit: Optional[int] = Field(
|
||||
default=None, description="每分钟请求数限制(NULL=无限制,0=禁止请求)"
|
||||
)
|
||||
rpm_used: Optional[int] = Field(default=None, description="当前分钟已用请求数")
|
||||
rpm_reset_at: Optional[datetime] = Field(default=None, description="RPM 重置时间")
|
||||
|
||||
# Endpoint 统计
|
||||
total_endpoints: int = Field(default=0, description="总 Endpoint 数量")
|
||||
active_endpoints: int = Field(default=0, description="活跃 Endpoint 数量")
|
||||
|
||||
# Key 统计(所有 Endpoints 的 Keys)
|
||||
total_keys: int = Field(default=0, description="总 Key 数量")
|
||||
active_keys: int = Field(default=0, description="活跃 Key 数量")
|
||||
|
||||
# Model 统计
|
||||
total_models: int = Field(default=0, description="总模型数量")
|
||||
active_models: int = Field(default=0, description="活跃模型数量")
|
||||
|
||||
# API 格式列表
|
||||
api_formats: List[str] = Field(default=[], description="支持的 API 格式列表")
|
||||
|
||||
# Endpoint 健康度详情
|
||||
endpoint_health_details: List[Dict[str, Any]] = Field(
|
||||
default=[],
|
||||
description="每个 Endpoint 的健康度详情 [{api_format: str, health_score: float, is_active: bool}]",
|
||||
)
|
||||
|
||||
# 健康度统计
|
||||
avg_health_score: float = Field(default=1.0, description="平均健康度")
|
||||
unhealthy_endpoints: int = Field(
|
||||
default=0, description="不健康的端点数量(health_score < 0.5)"
|
||||
)
|
||||
|
||||
# 时间戳
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== 健康监控可视化模型 ==========
|
||||
|
||||
|
||||
class EndpointHealthEvent(BaseModel):
|
||||
"""单个端点的请求事件"""
|
||||
|
||||
timestamp: datetime
|
||||
status: str
|
||||
status_code: Optional[int] = None
|
||||
latency_ms: Optional[int] = None
|
||||
error_type: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class EndpointHealthMonitor(BaseModel):
|
||||
"""端点健康监控信息"""
|
||||
|
||||
endpoint_id: str
|
||||
api_format: str
|
||||
is_active: bool
|
||||
total_attempts: int
|
||||
success_count: int
|
||||
failed_count: int
|
||||
skipped_count: int
|
||||
success_rate: float = Field(default=1.0, description="最近事件窗口的成功率")
|
||||
last_event_at: Optional[datetime] = None
|
||||
events: List[EndpointHealthEvent] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ProviderEndpointHealthMonitorResponse(BaseModel):
|
||||
"""Provider 下所有端点的健康监控"""
|
||||
|
||||
provider_id: str
|
||||
provider_name: str
|
||||
generated_at: datetime
|
||||
endpoints: List[EndpointHealthMonitor] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ApiFormatHealthMonitor(BaseModel):
|
||||
"""按 API 格式聚合的健康监控信息"""
|
||||
|
||||
api_format: str
|
||||
total_attempts: int
|
||||
success_count: int
|
||||
failed_count: int
|
||||
skipped_count: int
|
||||
success_rate: float = Field(default=1.0, description="最近事件窗口的成功率")
|
||||
provider_count: int = Field(default=0, description="参与统计的 Provider 数量")
|
||||
key_count: int = Field(default=0, description="参与统计的 API Key 数量")
|
||||
last_event_at: Optional[datetime] = None
|
||||
events: List[EndpointHealthEvent] = Field(default_factory=list)
|
||||
timeline: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Usage 表生成的健康时间线(healthy/warning/unhealthy/unknown)",
|
||||
)
|
||||
time_range_start: Optional[datetime] = Field(
|
||||
default=None, description="时间线所覆盖区间的开始时间"
|
||||
)
|
||||
time_range_end: Optional[datetime] = Field(
|
||||
default=None, description="时间线所覆盖区间的结束时间"
|
||||
)
|
||||
|
||||
|
||||
class ApiFormatHealthMonitorResponse(BaseModel):
|
||||
"""所有 API 格式的健康监控汇总"""
|
||||
|
||||
generated_at: datetime
|
||||
formats: List[ApiFormatHealthMonitor] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ========== 公开健康监控模型(不含敏感信息) ==========
|
||||
|
||||
|
||||
class PublicHealthEvent(BaseModel):
|
||||
"""公开版单个请求事件(不含敏感信息如 provider_id、key_id)"""
|
||||
|
||||
timestamp: datetime
|
||||
status: str
|
||||
status_code: Optional[int] = None
|
||||
latency_ms: Optional[int] = None
|
||||
error_type: Optional[str] = None
|
||||
|
||||
|
||||
class PublicApiFormatHealthMonitor(BaseModel):
|
||||
"""公开版 API 格式健康监控信息(不含敏感信息)"""
|
||||
|
||||
api_format: str
|
||||
api_path: str = Field(default="/", description="该 API 格式的本站请求路径")
|
||||
total_attempts: int = Field(default=0, description="总请求次数")
|
||||
success_count: int = Field(default=0, description="成功次数")
|
||||
failed_count: int = Field(default=0, description="失败次数")
|
||||
skipped_count: int = Field(default=0, description="跳过次数")
|
||||
success_rate: float = Field(default=1.0, description="成功率")
|
||||
last_event_at: Optional[datetime] = None
|
||||
events: List[PublicHealthEvent] = Field(default_factory=list, description="事件列表")
|
||||
timeline: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Usage 表生成的健康时间线(healthy/warning/unhealthy/unknown)",
|
||||
)
|
||||
time_range_start: Optional[datetime] = Field(
|
||||
default=None, description="时间线覆盖区间开始时间"
|
||||
)
|
||||
time_range_end: Optional[datetime] = Field(
|
||||
default=None, description="时间线覆盖区间结束时间"
|
||||
)
|
||||
|
||||
|
||||
class PublicApiFormatHealthMonitorResponse(BaseModel):
|
||||
"""公开版健康监控汇总(不含敏感信息)"""
|
||||
|
||||
generated_at: datetime
|
||||
formats: List[PublicApiFormatHealthMonitor] = Field(default_factory=list)
|
||||
468
src/models/gemini.py
Normal file
468
src/models/gemini.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""
|
||||
Google Gemini API 请求/响应模型
|
||||
|
||||
支持 Gemini 3 Pro 及之前版本的 API 格式
|
||||
参考文档: https://ai.google.dev/gemini-api/docs/gemini-3
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class BaseModelWithExtras(BaseModel):
|
||||
"""允许额外字段的基础模型"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 内容块定义
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModelWithExtras):
|
||||
"""文本内容块"""
|
||||
|
||||
text: str
|
||||
thought_signature: Optional[str] = Field(
|
||||
default=None,
|
||||
alias="thoughtSignature",
|
||||
description="Gemini 3 思维签名,用于维护多轮对话中的推理上下文",
|
||||
)
|
||||
|
||||
|
||||
class GeminiInlineData(BaseModelWithExtras):
|
||||
"""内联数据(图片等)"""
|
||||
|
||||
mime_type: str = Field(alias="mimeType")
|
||||
data: str # base64 encoded
|
||||
|
||||
|
||||
class GeminiMediaResolution(BaseModelWithExtras):
|
||||
"""
|
||||
媒体分辨率配置 (Gemini 3 新增)
|
||||
|
||||
控制图片/视频的处理分辨率:
|
||||
- media_resolution_low: 图片 280 tokens, 视频 70 tokens/帧
|
||||
- media_resolution_medium: 图片 560 tokens, 视频 70 tokens/帧
|
||||
- media_resolution_high: 图片 1120 tokens, 视频 280 tokens/帧
|
||||
"""
|
||||
|
||||
level: Literal["media_resolution_low", "media_resolution_medium", "media_resolution_high"]
|
||||
|
||||
|
||||
class GeminiFileData(BaseModelWithExtras):
|
||||
"""文件引用"""
|
||||
|
||||
mime_type: Optional[str] = Field(default=None, alias="mimeType")
|
||||
file_uri: str = Field(alias="fileUri")
|
||||
|
||||
|
||||
class GeminiFunctionCall(BaseModelWithExtras):
|
||||
"""函数调用"""
|
||||
|
||||
name: str
|
||||
args: Dict[str, Any]
|
||||
|
||||
|
||||
class GeminiFunctionResponse(BaseModelWithExtras):
|
||||
"""函数响应"""
|
||||
|
||||
name: str
|
||||
response: Dict[str, Any]
|
||||
|
||||
|
||||
class GeminiPart(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini 内容部分 - 支持多种类型
|
||||
|
||||
可以是以下类型之一:
|
||||
- text: 文本内容
|
||||
- inline_data: 内联数据(图片等)
|
||||
- file_data: 文件引用
|
||||
- function_call: 函数调用
|
||||
- function_response: 函数响应
|
||||
|
||||
Gemini 3 新增:
|
||||
- thought_signature: 思维签名,用于维护推理上下文
|
||||
- media_resolution: 媒体分辨率配置
|
||||
"""
|
||||
|
||||
text: Optional[str] = None
|
||||
inline_data: Optional[GeminiInlineData] = Field(default=None, alias="inlineData")
|
||||
file_data: Optional[GeminiFileData] = Field(default=None, alias="fileData")
|
||||
function_call: Optional[GeminiFunctionCall] = Field(default=None, alias="functionCall")
|
||||
function_response: Optional[GeminiFunctionResponse] = Field(
|
||||
default=None, alias="functionResponse"
|
||||
)
|
||||
# Gemini 3 新增
|
||||
thought_signature: Optional[str] = Field(
|
||||
default=None,
|
||||
alias="thoughtSignature",
|
||||
description="思维签名,用于函数调用和图片生成的上下文保持",
|
||||
)
|
||||
media_resolution: Optional[GeminiMediaResolution] = Field(
|
||||
default=None, alias="mediaResolution", description="媒体分辨率配置"
|
||||
)
|
||||
|
||||
|
||||
class GeminiContent(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini 消息内容
|
||||
|
||||
对应 Gemini API 的 Content 对象
|
||||
"""
|
||||
|
||||
role: Optional[Literal["user", "model"]] = None
|
||||
parts: List[Union[GeminiPart, Dict[str, Any]]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 配置定义
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModelWithExtras):
|
||||
"""
|
||||
图片生成配置 (Gemini 3 Pro Image)
|
||||
|
||||
用于 gemini-3-pro-image-preview 模型
|
||||
"""
|
||||
|
||||
aspect_ratio: Optional[str] = Field(
|
||||
default=None, alias="aspectRatio", description="图片宽高比,如 '16:9', '1:1', '4:3'"
|
||||
)
|
||||
image_size: Optional[Literal["2K", "4K"]] = Field(
|
||||
default=None, alias="imageSize", description="图片尺寸: 2K 或 4K"
|
||||
)
|
||||
|
||||
|
||||
class GeminiGenerationConfig(BaseModelWithExtras):
|
||||
"""
|
||||
生成配置
|
||||
|
||||
Gemini 3 新增:
|
||||
- thinking_level: 思考深度 (low/medium/high)
|
||||
- response_json_schema: 结构化输出的 JSON Schema
|
||||
- image_config: 图片生成配置
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = Field(
|
||||
default=None, description="采样温度,Gemini 3 建议保持默认值 1.0"
|
||||
)
|
||||
top_p: Optional[float] = Field(default=None, alias="topP")
|
||||
top_k: Optional[int] = Field(default=None, alias="topK")
|
||||
max_output_tokens: Optional[int] = Field(default=None, alias="maxOutputTokens")
|
||||
stop_sequences: Optional[List[str]] = Field(default=None, alias="stopSequences")
|
||||
candidate_count: Optional[int] = Field(default=None, alias="candidateCount")
|
||||
response_mime_type: Optional[str] = Field(default=None, alias="responseMimeType")
|
||||
response_schema: Optional[Dict[str, Any]] = Field(default=None, alias="responseSchema")
|
||||
# Gemini 3 新增
|
||||
response_json_schema: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="responseJsonSchema", description="结构化输出的 JSON Schema"
|
||||
)
|
||||
thinking_level: Optional[Literal["low", "medium", "high"]] = Field(
|
||||
default=None,
|
||||
alias="thinkingLevel",
|
||||
description="Gemini 3 思考深度: low(快速), medium(平衡), high(深度推理,默认)",
|
||||
)
|
||||
image_config: Optional[GeminiImageConfig] = Field(
|
||||
default=None, alias="imageConfig", description="图片生成配置"
|
||||
)
|
||||
|
||||
|
||||
class GeminiSafetySettings(BaseModelWithExtras):
|
||||
"""安全设置"""
|
||||
|
||||
category: str
|
||||
threshold: str
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModelWithExtras):
|
||||
"""函数声明"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GeminiGoogleSearchTool(BaseModelWithExtras):
|
||||
"""Google Search 工具 (Gemini 3)"""
|
||||
|
||||
pass # 空对象表示启用
|
||||
|
||||
|
||||
class GeminiUrlContextTool(BaseModelWithExtras):
|
||||
"""URL Context 工具 (Gemini 3)"""
|
||||
|
||||
pass # 空对象表示启用
|
||||
|
||||
|
||||
class GeminiCodeExecutionTool(BaseModelWithExtras):
|
||||
"""代码执行工具"""
|
||||
|
||||
pass # 空对象表示启用
|
||||
|
||||
|
||||
class GeminiTool(BaseModelWithExtras):
|
||||
"""
|
||||
工具定义
|
||||
|
||||
支持的工具类型:
|
||||
- function_declarations: 自定义函数
|
||||
- code_execution: 代码执行
|
||||
- google_search: Google 搜索 (Gemini 3)
|
||||
- url_context: URL 上下文 (Gemini 3)
|
||||
"""
|
||||
|
||||
function_declarations: Optional[List[GeminiFunctionDeclaration]] = Field(
|
||||
default=None, alias="functionDeclarations"
|
||||
)
|
||||
code_execution: Optional[Dict[str, Any]] = Field(default=None, alias="codeExecution")
|
||||
# Gemini 3 内置工具
|
||||
google_search: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="googleSearch", description="启用 Google 搜索工具"
|
||||
)
|
||||
url_context: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="urlContext", description="启用 URL 上下文工具"
|
||||
)
|
||||
|
||||
|
||||
class GeminiToolConfig(BaseModelWithExtras):
|
||||
"""工具配置"""
|
||||
|
||||
function_calling_config: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="functionCallingConfig"
|
||||
)
|
||||
|
||||
|
||||
class GeminiSystemInstruction(BaseModelWithExtras):
|
||||
"""系统指令"""
|
||||
|
||||
parts: List[Union[GeminiPart, Dict[str, Any]]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 请求模型
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiGenerateContentRequest(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini generateContent 请求模型
|
||||
|
||||
对应 POST /v1beta/models/{model}:generateContent 端点
|
||||
"""
|
||||
|
||||
contents: List[GeminiContent]
|
||||
system_instruction: Optional[GeminiSystemInstruction] = Field(
|
||||
default=None, alias="systemInstruction"
|
||||
)
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
|
||||
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
|
||||
default=None, alias="safetySettings"
|
||||
)
|
||||
generation_config: Optional[GeminiGenerationConfig] = Field(
|
||||
default=None, alias="generationConfig"
|
||||
)
|
||||
|
||||
|
||||
class GeminiStreamGenerateContentRequest(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini streamGenerateContent 请求模型
|
||||
|
||||
对应 POST /v1beta/models/{model}:streamGenerateContent 端点
|
||||
与 generateContent 相同,但返回流式响应
|
||||
"""
|
||||
|
||||
contents: List[GeminiContent]
|
||||
system_instruction: Optional[GeminiSystemInstruction] = Field(
|
||||
default=None, alias="systemInstruction"
|
||||
)
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
|
||||
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
|
||||
default=None, alias="safetySettings"
|
||||
)
|
||||
generation_config: Optional[GeminiGenerationConfig] = Field(
|
||||
default=None, alias="generationConfig"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 统一请求模型(用于内部处理)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiRequest(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini 统一请求模型
|
||||
|
||||
内部使用,统一处理 generateContent 和 streamGenerateContent
|
||||
|
||||
注意: Gemini API 通过 URL 端点区分流式/非流式请求:
|
||||
- generateContent - 非流式
|
||||
- streamGenerateContent - 流式
|
||||
请求体中不应包含 stream 字段
|
||||
"""
|
||||
|
||||
model: Optional[str] = Field(default=None, description="模型名称,从 URL 路径提取(内部使用)")
|
||||
contents: List[GeminiContent]
|
||||
system_instruction: Optional[GeminiSystemInstruction] = Field(
|
||||
default=None, alias="systemInstruction"
|
||||
)
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
|
||||
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
|
||||
default=None, alias="safetySettings"
|
||||
)
|
||||
generation_config: Optional[GeminiGenerationConfig] = Field(
|
||||
default=None, alias="generationConfig"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 响应模型
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiUsageMetadata(BaseModelWithExtras):
|
||||
"""Token 使用量"""
|
||||
|
||||
prompt_token_count: int = Field(default=0, alias="promptTokenCount")
|
||||
candidates_token_count: int = Field(default=0, alias="candidatesTokenCount")
|
||||
total_token_count: int = Field(default=0, alias="totalTokenCount")
|
||||
cached_content_token_count: Optional[int] = Field(default=None, alias="cachedContentTokenCount")
|
||||
|
||||
|
||||
class GeminiSafetyRating(BaseModelWithExtras):
|
||||
"""安全评级"""
|
||||
|
||||
category: str
|
||||
probability: str
|
||||
blocked: Optional[bool] = None
|
||||
|
||||
|
||||
class GeminiCitationSource(BaseModelWithExtras):
|
||||
"""引用来源"""
|
||||
|
||||
start_index: Optional[int] = Field(default=None, alias="startIndex")
|
||||
end_index: Optional[int] = Field(default=None, alias="endIndex")
|
||||
uri: Optional[str] = None
|
||||
license: Optional[str] = None
|
||||
|
||||
|
||||
class GeminiCitationMetadata(BaseModelWithExtras):
|
||||
"""引用元数据"""
|
||||
|
||||
citation_sources: Optional[List[GeminiCitationSource]] = Field(
|
||||
default=None, alias="citationSources"
|
||||
)
|
||||
|
||||
|
||||
class GeminiGroundingMetadata(BaseModelWithExtras):
|
||||
"""
|
||||
Grounding 元数据 (Gemini 3)
|
||||
|
||||
当使用 Google Search 工具时返回
|
||||
"""
|
||||
|
||||
search_entry_point: Optional[Dict[str, Any]] = Field(default=None, alias="searchEntryPoint")
|
||||
grounding_chunks: Optional[List[Dict[str, Any]]] = Field(default=None, alias="groundingChunks")
|
||||
grounding_supports: Optional[List[Dict[str, Any]]] = Field(
|
||||
default=None, alias="groundingSupports"
|
||||
)
|
||||
web_search_queries: Optional[List[str]] = Field(default=None, alias="webSearchQueries")
|
||||
|
||||
|
||||
class GeminiCandidate(BaseModelWithExtras):
|
||||
"""候选响应"""
|
||||
|
||||
content: Optional[GeminiContent] = None
|
||||
finish_reason: Optional[str] = Field(default=None, alias="finishReason")
|
||||
safety_ratings: Optional[List[GeminiSafetyRating]] = Field(default=None, alias="safetyRatings")
|
||||
citation_metadata: Optional[GeminiCitationMetadata] = Field(
|
||||
default=None, alias="citationMetadata"
|
||||
)
|
||||
grounding_metadata: Optional[GeminiGroundingMetadata] = Field(
|
||||
default=None, alias="groundingMetadata"
|
||||
)
|
||||
token_count: Optional[int] = Field(default=None, alias="tokenCount")
|
||||
index: Optional[int] = None
|
||||
|
||||
|
||||
class GeminiPromptFeedback(BaseModelWithExtras):
|
||||
"""提示反馈"""
|
||||
|
||||
block_reason: Optional[str] = Field(default=None, alias="blockReason")
|
||||
safety_ratings: Optional[List[GeminiSafetyRating]] = Field(default=None, alias="safetyRatings")
|
||||
|
||||
|
||||
class GeminiGenerateContentResponse(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini generateContent 响应模型
|
||||
|
||||
对应 generateContent 端点的响应体
|
||||
"""
|
||||
|
||||
candidates: Optional[List[GeminiCandidate]] = None
|
||||
prompt_feedback: Optional[GeminiPromptFeedback] = Field(default=None, alias="promptFeedback")
|
||||
usage_metadata: Optional[GeminiUsageMetadata] = Field(default=None, alias="usageMetadata")
|
||||
model_version: Optional[str] = Field(default=None, alias="modelVersion")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 流式响应模型
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiStreamChunk(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini 流式响应块
|
||||
|
||||
流式响应中的单个数据块,结构与完整响应相同
|
||||
"""
|
||||
|
||||
candidates: Optional[List[GeminiCandidate]] = None
|
||||
prompt_feedback: Optional[GeminiPromptFeedback] = Field(default=None, alias="promptFeedback")
|
||||
usage_metadata: Optional[GeminiUsageMetadata] = Field(default=None, alias="usageMetadata")
|
||||
model_version: Optional[str] = Field(default=None, alias="modelVersion")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 错误响应
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiErrorDetail(BaseModelWithExtras):
|
||||
"""错误详情"""
|
||||
|
||||
type: Optional[str] = Field(default=None, alias="@type")
|
||||
reason: Optional[str] = None
|
||||
domain: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GeminiError(BaseModelWithExtras):
|
||||
"""错误信息"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
status: str
|
||||
details: Optional[List[GeminiErrorDetail]] = None
|
||||
|
||||
|
||||
class GeminiErrorResponse(BaseModelWithExtras):
|
||||
"""错误响应"""
|
||||
|
||||
error: GeminiError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thought Signature 常量
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# 用于从其他模型迁移对话时绕过签名验证
|
||||
DUMMY_THOUGHT_SIGNATURE = "context_engineering_is_the_way_to_go"
|
||||
153
src/models/openai.py
Normal file
153
src/models/openai.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
OpenAI API 数据模型定义
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
# 配置允许额外字段,以支持 API 的新特性
|
||||
class BaseModelWithExtras(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class OpenAIMessage(BaseModelWithExtras):
|
||||
"""OpenAI消息模型"""
|
||||
|
||||
role: str
|
||||
content: Optional[Union[str, List[Dict[str, Any]]]] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIFunction(BaseModelWithExtras):
|
||||
"""OpenAI函数定义"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Dict[str, Any]
|
||||
|
||||
|
||||
class OpenAITool(BaseModelWithExtras):
|
||||
"""OpenAI工具定义"""
|
||||
|
||||
type: str = "function"
|
||||
function: OpenAIFunction
|
||||
|
||||
|
||||
class OpenAIRequest(BaseModelWithExtras):
|
||||
"""OpenAI请求模型"""
|
||||
|
||||
model: str
|
||||
messages: List[OpenAIMessage]
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = 1.0
|
||||
top_p: Optional[float] = None
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
tools: Optional[List[OpenAITool]] = None
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
response_format: Optional[Dict[str, Any]] = None
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: Optional[bool] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class ResponsesInputMessage(BaseModelWithExtras):
|
||||
"""Responses API 输入消息"""
|
||||
|
||||
type: str = "message"
|
||||
role: str
|
||||
content: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class ResponsesReasoningConfig(BaseModelWithExtras):
|
||||
"""Responses API 推理配置"""
|
||||
|
||||
effort: str = "high" # low, medium, high
|
||||
summary: str = "auto" # auto, off
|
||||
|
||||
|
||||
class ResponsesRequest(BaseModelWithExtras):
|
||||
"""OpenAI Responses API 请求模型(用于 Claude Code 等客户端)"""
|
||||
|
||||
model: str
|
||||
instructions: Optional[str] = None
|
||||
input: List[ResponsesInputMessage]
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
|
||||
parallel_tool_calls: Optional[bool] = False
|
||||
reasoning: Optional[ResponsesReasoningConfig] = None
|
||||
store: Optional[bool] = False
|
||||
stream: Optional[bool] = True
|
||||
include: Optional[List[str]] = None
|
||||
prompt_cache_key: Optional[str] = None
|
||||
# 其他参数
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
|
||||
|
||||
class OpenAIUsage(BaseModelWithExtras):
|
||||
"""OpenAI使用统计"""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class OpenAIChoice(BaseModelWithExtras):
|
||||
"""OpenAI选择结果"""
|
||||
|
||||
index: int
|
||||
message: OpenAIMessage
|
||||
finish_reason: Optional[str] = None
|
||||
logprobs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class OpenAIResponse(BaseModelWithExtras):
|
||||
"""OpenAI响应模型"""
|
||||
|
||||
id: str
|
||||
object: str = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
choices: List[OpenAIChoice]
|
||||
usage: Optional[OpenAIUsage] = None
|
||||
system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIStreamDelta(BaseModelWithExtras):
|
||||
"""OpenAI流式响应增量"""
|
||||
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
|
||||
class OpenAIStreamChoice(BaseModelWithExtras):
|
||||
"""OpenAI流式响应选择"""
|
||||
|
||||
index: int
|
||||
delta: OpenAIStreamDelta
|
||||
finish_reason: Optional[str] = None
|
||||
logprobs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class OpenAIStreamResponse(BaseModelWithExtras):
|
||||
"""OpenAI流式响应模型"""
|
||||
|
||||
id: str
|
||||
object: str = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
choices: List[OpenAIStreamChoice]
|
||||
system_fingerprint: Optional[str] = None
|
||||
435
src/models/pydantic_models.py
Normal file
435
src/models/pydantic_models.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Pydantic 数据模型(阶段一统一模型管理)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from .api import ModelCreate
|
||||
|
||||
|
||||
# ========== 阶梯计费相关模型 ==========
|
||||
|
||||
|
||||
class CacheTTLPricing(BaseModel):
|
||||
"""缓存时长定价配置"""
|
||||
|
||||
ttl_minutes: int = Field(..., ge=1, description="缓存时长(分钟)")
|
||||
cache_creation_price_per_1m: float = Field(..., ge=0, description="该时长的缓存创建价格/M tokens")
|
||||
|
||||
|
||||
class PricingTier(BaseModel):
|
||||
"""单个价格阶梯配置"""
|
||||
|
||||
up_to: Optional[int] = Field(
|
||||
None,
|
||||
ge=1,
|
||||
description="阶梯上限(tokens),null 表示无上限(最后一个阶梯)"
|
||||
)
|
||||
input_price_per_1m: float = Field(..., ge=0, description="输入价格/M tokens")
|
||||
output_price_per_1m: float = Field(..., ge=0, description="输出价格/M tokens")
|
||||
cache_creation_price_per_1m: Optional[float] = Field(
|
||||
None, ge=0, description="缓存创建价格/M tokens"
|
||||
)
|
||||
cache_read_price_per_1m: Optional[float] = Field(
|
||||
None, ge=0, description="缓存读取价格/M tokens"
|
||||
)
|
||||
cache_ttl_pricing: Optional[List[CacheTTLPricing]] = Field(
|
||||
None, description="按缓存时长分价格(可选)"
|
||||
)
|
||||
|
||||
|
||||
class TieredPricingConfig(BaseModel):
|
||||
"""阶梯计费配置"""
|
||||
|
||||
tiers: List[PricingTier] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="价格阶梯列表,按 up_to 升序排列"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tiers(self) -> "TieredPricingConfig":
|
||||
"""验证阶梯配置的合法性"""
|
||||
tiers = self.tiers
|
||||
if not tiers:
|
||||
raise ValueError("至少需要一个价格阶梯")
|
||||
|
||||
# 检查阶梯顺序和唯一性
|
||||
prev_up_to = 0
|
||||
has_unlimited = False
|
||||
|
||||
for i, tier in enumerate(tiers):
|
||||
if has_unlimited:
|
||||
raise ValueError("无上限阶梯(up_to=null)必须是最后一个")
|
||||
|
||||
if tier.up_to is None:
|
||||
has_unlimited = True
|
||||
else:
|
||||
if tier.up_to <= prev_up_to:
|
||||
raise ValueError(
|
||||
f"阶梯 {i+1} 的 up_to ({tier.up_to}) 必须大于前一个阶梯 ({prev_up_to})"
|
||||
)
|
||||
prev_up_to = tier.up_to
|
||||
|
||||
# 验证缓存时长定价顺序
|
||||
if tier.cache_ttl_pricing:
|
||||
prev_ttl = 0
|
||||
for ttl_pricing in tier.cache_ttl_pricing:
|
||||
if ttl_pricing.ttl_minutes <= prev_ttl:
|
||||
raise ValueError(
|
||||
f"cache_ttl_pricing 必须按 ttl_minutes 升序排列"
|
||||
)
|
||||
prev_ttl = ttl_pricing.ttl_minutes
|
||||
|
||||
# 最后一个阶梯必须是无上限的
|
||||
if not has_unlimited:
|
||||
raise ValueError("最后一个阶梯必须设置 up_to=null(无上限)")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
# ========== 其他模型 ==========
|
||||
|
||||
|
||||
class ModelCapabilities(BaseModel):
|
||||
"""模型能力聚合"""
|
||||
|
||||
supports_vision: bool = False
|
||||
supports_function_calling: bool = False
|
||||
supports_streaming: bool = False
|
||||
|
||||
|
||||
class ModelPriceRange(BaseModel):
|
||||
"""统一模型价格区间"""
|
||||
|
||||
min_input: Optional[float] = None
|
||||
max_input: Optional[float] = None
|
||||
min_output: Optional[float] = None
|
||||
max_output: Optional[float] = None
|
||||
|
||||
|
||||
class ModelCatalogProviderDetail(BaseModel):
|
||||
"""统一模型目录中的关联提供商信息"""
|
||||
|
||||
provider_id: str
|
||||
provider_name: str
|
||||
provider_display_name: Optional[str]
|
||||
model_id: Optional[str]
|
||||
target_model: str
|
||||
input_price_per_1m: Optional[float]
|
||||
output_price_per_1m: Optional[float]
|
||||
cache_creation_price_per_1m: Optional[float]
|
||||
cache_read_price_per_1m: Optional[float]
|
||||
cache_1h_creation_price_per_1m: Optional[float] = None # 1h 缓存创建价格
|
||||
price_per_request: Optional[float] = None # 按次计费价格
|
||||
effective_tiered_pricing: Optional[Dict[str, Any]] = None # 有效阶梯计费配置(含继承)
|
||||
tier_count: int = 1 # 阶梯数量
|
||||
supports_vision: Optional[bool] = None
|
||||
supports_function_calling: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
is_active: bool
|
||||
mapping_id: Optional[str]
|
||||
|
||||
|
||||
class OrphanedModel(BaseModel):
|
||||
"""孤立的统一模型(Mapping 存在但 GlobalModel 缺失)"""
|
||||
|
||||
alias: str # 别名
|
||||
global_model_name: Optional[str] # 关联的 GlobalModel 名称(如果有)
|
||||
mapping_count: int
|
||||
|
||||
|
||||
class ModelCatalogItem(BaseModel):
|
||||
"""统一模型目录条目(方案 A:基于 GlobalModel)"""
|
||||
|
||||
global_model_name: str # GlobalModel.name
|
||||
display_name: str # GlobalModel.display_name
|
||||
description: Optional[str] # GlobalModel.description
|
||||
aliases: List[str] # 所有指向该 GlobalModel 的别名列表
|
||||
providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表
|
||||
price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合)
|
||||
total_providers: int
|
||||
capabilities: ModelCapabilities # 能力聚合(从所有 Provider 的 Model 中聚合)
|
||||
|
||||
|
||||
class ModelCatalogResponse(BaseModel):
|
||||
"""统一模型目录响应"""
|
||||
|
||||
models: List[ModelCatalogItem]
|
||||
total: int
|
||||
orphaned_models: List[OrphanedModel]
|
||||
|
||||
|
||||
class ProviderModelPriceInfo(BaseModel):
|
||||
"""Provider 维度的模型价格信息"""
|
||||
|
||||
input_price_per_1m: Optional[float]
|
||||
output_price_per_1m: Optional[float]
|
||||
cache_creation_price_per_1m: Optional[float]
|
||||
cache_read_price_per_1m: Optional[float]
|
||||
price_per_request: Optional[float] = None # 按次计费价格
|
||||
|
||||
|
||||
class ProviderAvailableSourceModel(BaseModel):
|
||||
"""Provider 支持的统一模型条目(方案 A)"""
|
||||
|
||||
global_model_name: str # GlobalModel.name
|
||||
display_name: str # GlobalModel.display_name
|
||||
provider_model_name: str # Model.provider_model_name (Provider 侧的模型名)
|
||||
has_alias: bool # 是否有别名指向该 GlobalModel
|
||||
aliases: List[str] # 别名列表
|
||||
model_id: Optional[str] # Model.id
|
||||
price: ProviderModelPriceInfo
|
||||
capabilities: ModelCapabilities
|
||||
is_active: bool
|
||||
|
||||
|
||||
class ProviderAvailableSourceModelsResponse(BaseModel):
|
||||
"""Provider 可用统一模型响应"""
|
||||
|
||||
models: List[ProviderAvailableSourceModel]
|
||||
total: int
|
||||
|
||||
|
||||
class BatchAssignProviderConfig(BaseModel):
|
||||
"""批量添加映射的 Provider 配置"""
|
||||
|
||||
provider_id: str
|
||||
create_model: bool = Field(False, description="是否需要创建新的 Model")
|
||||
model_data: Optional[ModelCreate] = Field(
|
||||
None, description="create_model=true 时需要提供的模型配置", alias="model_config"
|
||||
)
|
||||
model_id: Optional[str] = Field(None, description="create_model=false 时需要提供的现有模型 ID")
|
||||
|
||||
|
||||
class BatchAssignModelMappingRequest(BaseModel):
|
||||
"""批量添加模型映射请求(方案 A:暂不支持,需要重构)"""
|
||||
|
||||
global_model_id: str # 要分配的 GlobalModel ID
|
||||
providers: List[BatchAssignProviderConfig]
|
||||
|
||||
|
||||
class BatchAssignProviderResult(BaseModel):
|
||||
"""批量映射结果条目"""
|
||||
|
||||
provider_id: str
|
||||
mapping_id: Optional[str]
|
||||
created_model: bool
|
||||
model_id: Optional[str]
|
||||
updated: bool = False
|
||||
|
||||
|
||||
class BatchAssignError(BaseModel):
|
||||
"""批量映射错误信息"""
|
||||
|
||||
provider_id: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchAssignModelMappingResponse(BaseModel):
|
||||
"""批量映射响应"""
|
||||
|
||||
success: bool
|
||||
created_mappings: List[BatchAssignProviderResult]
|
||||
errors: List[BatchAssignError]
|
||||
|
||||
|
||||
# ========== 阶段二:GlobalModel 相关模型 ==========
|
||||
|
||||
|
||||
class GlobalModelCreate(BaseModel):
|
||||
"""创建 GlobalModel 请求"""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="统一模型名(唯一)")
|
||||
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
official_url: Optional[str] = Field(None, max_length=500, description="官方文档链接")
|
||||
icon_url: Optional[str] = Field(None, max_length=500, description="图标 URL")
|
||||
# 按次计费配置(可选,与阶梯计费叠加)
|
||||
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
# 统一阶梯计费配置(必填)
|
||||
# 固定价格也用单阶梯表示: {"tiers": [{"up_to": null, "input_price_per_1m": X, ...}]}
|
||||
default_tiered_pricing: TieredPricingConfig = Field(
|
||||
..., description="阶梯计费配置(固定价格用单阶梯表示)"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool] = Field(False, description="默认是否支持视觉")
|
||||
default_supports_function_calling: Optional[bool] = Field(
|
||||
False, description="默认是否支持函数调用"
|
||||
)
|
||||
default_supports_streaming: Optional[bool] = Field(True, description="默认是否支持流式输出")
|
||||
default_supports_extended_thinking: Optional[bool] = Field(
|
||||
False, description="默认是否支持扩展思考"
|
||||
)
|
||||
default_supports_image_generation: Optional[bool] = Field(
|
||||
False, description="默认是否支持图像生成"
|
||||
)
|
||||
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"])
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
None, description="支持的 Key 能力列表"
|
||||
)
|
||||
is_active: Optional[bool] = Field(True, description="是否激活")
|
||||
|
||||
|
||||
class GlobalModelUpdate(BaseModel):
|
||||
"""更新 GlobalModel 请求"""
|
||||
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = None
|
||||
official_url: Optional[str] = Field(None, max_length=500)
|
||||
icon_url: Optional[str] = Field(None, max_length=500)
|
||||
is_active: Optional[bool] = None
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
|
||||
None, description="阶梯计费配置"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool] = None
|
||||
default_supports_function_calling: Optional[bool] = None
|
||||
default_supports_streaming: Optional[bool] = None
|
||||
default_supports_extended_thinking: Optional[bool] = None
|
||||
default_supports_image_generation: Optional[bool] = None
|
||||
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"])
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
None, description="支持的 Key 能力列表"
|
||||
)
|
||||
|
||||
|
||||
class GlobalModelResponse(BaseModel):
|
||||
"""GlobalModel 响应"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
official_url: Optional[str]
|
||||
icon_url: Optional[str]
|
||||
is_active: bool
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing: TieredPricingConfig = Field(
|
||||
..., description="阶梯计费配置"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool]
|
||||
default_supports_function_calling: Optional[bool]
|
||||
default_supports_streaming: Optional[bool]
|
||||
default_supports_extended_thinking: Optional[bool]
|
||||
default_supports_image_generation: Optional[bool]
|
||||
# Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
default=None, description="支持的 Key 能力列表"
|
||||
)
|
||||
# 统计数据(可选)
|
||||
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
|
||||
alias_count: Optional[int] = Field(default=0, description="别名数量")
|
||||
usage_count: Optional[int] = Field(default=0, description="调用次数")
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class GlobalModelWithStats(GlobalModelResponse):
|
||||
"""带统计信息的 GlobalModel"""
|
||||
|
||||
total_models: int = Field(..., description="关联的 Model 数量")
|
||||
total_providers: int = Field(..., description="支持的 Provider 数量")
|
||||
price_range: ModelPriceRange
|
||||
|
||||
|
||||
class GlobalModelListResponse(BaseModel):
|
||||
"""GlobalModel 列表响应"""
|
||||
|
||||
models: List[GlobalModelResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class BatchAssignToProvidersRequest(BaseModel):
|
||||
"""批量为 Provider 添加 GlobalModel 实现"""
|
||||
|
||||
provider_ids: List[str] = Field(..., min_items=1, description="Provider ID 列表")
|
||||
create_models: bool = Field(default=False, description="是否自动创建 Model 记录")
|
||||
|
||||
|
||||
class BatchAssignToProvidersResponse(BaseModel):
|
||||
"""批量分配响应"""
|
||||
|
||||
success: List[dict]
|
||||
errors: List[dict]
|
||||
|
||||
|
||||
class BatchAssignModelsToProviderRequest(BaseModel):
|
||||
"""批量为 Provider 关联 GlobalModel"""
|
||||
|
||||
global_model_ids: List[str] = Field(..., min_length=1, description="GlobalModel ID 列表")
|
||||
|
||||
|
||||
class BatchAssignModelsToProviderResponse(BaseModel):
|
||||
"""批量关联 GlobalModel 到 Provider 的响应"""
|
||||
|
||||
success: List[dict]
|
||||
errors: List[dict]
|
||||
|
||||
|
||||
class UpdateModelMappingRequest(BaseModel):
|
||||
"""更新模型映射请求"""
|
||||
|
||||
source_model: Optional[str] = Field(
|
||||
None, min_length=1, max_length=200, description="源模型名或别名"
|
||||
)
|
||||
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
|
||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时为全局别名)")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class UpdateModelMappingResponse(BaseModel):
|
||||
"""更新模型映射响应"""
|
||||
|
||||
success: bool
|
||||
mapping_id: str
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class DeleteModelMappingResponse(BaseModel):
|
||||
"""删除模型映射响应"""
|
||||
|
||||
success: bool
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchAssignError",
|
||||
"BatchAssignModelMappingRequest",
|
||||
"BatchAssignModelMappingResponse",
|
||||
"BatchAssignModelsToProviderRequest",
|
||||
"BatchAssignModelsToProviderResponse",
|
||||
"BatchAssignProviderConfig",
|
||||
"BatchAssignProviderResult",
|
||||
"BatchAssignToProvidersRequest",
|
||||
"BatchAssignToProvidersResponse",
|
||||
"DeleteModelMappingResponse",
|
||||
"GlobalModelCreate",
|
||||
"GlobalModelListResponse",
|
||||
"GlobalModelResponse",
|
||||
"GlobalModelUpdate",
|
||||
"GlobalModelWithStats",
|
||||
"ModelCapabilities",
|
||||
"ModelCatalogItem",
|
||||
"ModelCatalogProviderDetail",
|
||||
"ModelCatalogResponse",
|
||||
"ModelPriceRange",
|
||||
"OrphanedModel",
|
||||
"ProviderAvailableSourceModel",
|
||||
"ProviderAvailableSourceModelsResponse",
|
||||
"ProviderModelPriceInfo",
|
||||
"UpdateModelMappingRequest",
|
||||
"UpdateModelMappingResponse",
|
||||
]
|
||||
Reference in New Issue
Block a user