mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 19:22:26 +08:00
- users 表:重命名 allowed_endpoints 为 allowed_api_formats(修正历史命名错误) - api_keys 表:删除 allowed_endpoints 字段(未使用的功能) - providers 表:删除 rate_limit 字段(与 rpm_limit 重复) - usage 表:重命名 provider 为 provider_name(避免与 provider_id 外键混淆) 同步更新前后端所有相关代码
339 lines
13 KiB
Python
339 lines
13 KiB
Python
"""管理员 Provider 管理路由。"""
|
||
|
||
from datetime import datetime, timezone
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Depends, Query, Request
|
||
from pydantic import ValidationError
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.api.base.admin_adapter import AdminApiAdapter
|
||
from src.api.base.pipeline import ApiRequestPipeline
|
||
from src.core.enums import ProviderBillingType
|
||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||
from src.database import get_db
|
||
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
|
||
from src.models.database import Provider
|
||
|
||
router = APIRouter(tags=["Provider CRUD"])
|
||
pipeline = ApiRequestPipeline()
|
||
|
||
|
||
@router.get("/")
|
||
async def list_providers(
|
||
request: Request,
|
||
skip: int = Query(0, ge=0),
|
||
limit: int = Query(100, ge=1, le=500),
|
||
is_active: Optional[bool] = None,
|
||
db: Session = Depends(get_db),
|
||
):
|
||
"""
|
||
获取提供商列表
|
||
|
||
获取所有提供商的基本信息列表,支持分页和状态过滤。
|
||
|
||
**查询参数**:
|
||
- `skip`: 跳过的记录数,用于分页,默认为 0
|
||
- `limit`: 返回的最大记录数,范围 1-500,默认为 100
|
||
- `is_active`: 可选的活跃状态过滤,true 仅返回活跃提供商,false 返回禁用提供商,不传则返回全部
|
||
|
||
**返回字段**:
|
||
- `id`: 提供商 ID
|
||
- `name`: 提供商名称(唯一标识)
|
||
- `display_name`: 显示名称
|
||
- `api_format`: API 格式(如 claude、openai、gemini 等)
|
||
- `base_url`: API 基础 URL
|
||
- `api_key`: API 密钥(脱敏显示)
|
||
- `priority`: 优先级
|
||
- `is_active`: 是否活跃
|
||
- `created_at`: 创建时间
|
||
- `updated_at`: 更新时间
|
||
"""
|
||
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.post("/")
|
||
async def create_provider(request: Request, db: Session = Depends(get_db)):
|
||
"""
|
||
创建新提供商
|
||
|
||
创建一个新的 AI 模型提供商配置。
|
||
|
||
**请求体字段**:
|
||
- `name`: 提供商名称(必填,唯一,用于系统标识)
|
||
- `display_name`: 显示名称(必填)
|
||
- `description`: 描述信息(可选)
|
||
- `website`: 官网地址(可选)
|
||
- `billing_type`: 计费类型(可选,pay_as_you_go/subscription/prepaid,默认 pay_as_you_go)
|
||
- `monthly_quota_usd`: 月度配额(美元)(可选)
|
||
- `quota_reset_day`: 配额重置日期(1-31)(可选)
|
||
- `quota_last_reset_at`: 上次配额重置时间(可选)
|
||
- `quota_expires_at`: 配额过期时间(可选)
|
||
- `rpm_limit`: 每分钟请求数限制(可选)
|
||
- `provider_priority`: 提供商优先级(数字越小优先级越高,默认 100)
|
||
- `is_active`: 是否启用(默认 true)
|
||
- `concurrent_limit`: 并发限制(可选)
|
||
- `config`: 额外配置信息(JSON,可选)
|
||
|
||
**返回字段**:
|
||
- `id`: 新创建的提供商 ID
|
||
- `name`: 提供商名称
|
||
- `display_name`: 显示名称
|
||
- `message`: 成功提示信息
|
||
"""
|
||
adapter = AdminCreateProviderAdapter()
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.put("/{provider_id}")
|
||
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||
"""
|
||
更新提供商配置
|
||
|
||
更新指定提供商的配置信息。只需传入需要更新的字段,未传入的字段保持不变。
|
||
|
||
**路径参数**:
|
||
- `provider_id`: 提供商 ID
|
||
|
||
**请求体字段**(所有字段可选):
|
||
- `name`: 提供商名称
|
||
- `display_name`: 显示名称
|
||
- `description`: 描述信息
|
||
- `website`: 官网地址
|
||
- `billing_type`: 计费类型(pay_as_you_go/subscription/prepaid)
|
||
- `monthly_quota_usd`: 月度配额(美元)
|
||
- `quota_reset_day`: 配额重置日期(1-31)
|
||
- `quota_last_reset_at`: 上次配额重置时间
|
||
- `quota_expires_at`: 配额过期时间
|
||
- `rpm_limit`: 每分钟请求数限制
|
||
- `provider_priority`: 提供商优先级
|
||
- `is_active`: 是否启用
|
||
- `concurrent_limit`: 并发限制
|
||
- `config`: 额外配置信息(JSON)
|
||
|
||
**返回字段**:
|
||
- `id`: 提供商 ID
|
||
- `name`: 提供商名称
|
||
- `is_active`: 是否启用
|
||
- `message`: 成功提示信息
|
||
"""
|
||
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.delete("/{provider_id}")
|
||
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||
"""
|
||
删除提供商
|
||
|
||
删除指定的提供商。注意:此操作会级联删除关联的端点、密钥和模型配置。
|
||
|
||
**路径参数**:
|
||
- `provider_id`: 提供商 ID
|
||
|
||
**返回字段**:
|
||
- `message`: 删除成功提示信息
|
||
"""
|
||
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
class AdminListProvidersAdapter(AdminApiAdapter):
|
||
def __init__(self, skip: int, limit: int, is_active: Optional[bool]):
|
||
self.skip = skip
|
||
self.limit = limit
|
||
self.is_active = is_active
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
query = db.query(Provider)
|
||
if self.is_active is not None:
|
||
query = query.filter(Provider.is_active == self.is_active)
|
||
providers = query.offset(self.skip).limit(self.limit).all()
|
||
|
||
data = []
|
||
for provider in providers:
|
||
api_format = getattr(provider, "api_format", None)
|
||
base_url = getattr(provider, "base_url", None)
|
||
api_key = getattr(provider, "api_key", None)
|
||
priority = getattr(provider, "priority", provider.provider_priority)
|
||
|
||
data.append(
|
||
{
|
||
"id": provider.id,
|
||
"name": provider.name,
|
||
"display_name": provider.display_name,
|
||
"api_format": api_format.value if api_format else None,
|
||
"base_url": base_url,
|
||
"api_key": "***" if api_key else None,
|
||
"priority": priority,
|
||
"is_active": provider.is_active,
|
||
"created_at": provider.created_at.isoformat(),
|
||
"updated_at": provider.updated_at.isoformat() if provider.updated_at else None,
|
||
}
|
||
)
|
||
context.add_audit_metadata(
|
||
action="list_providers",
|
||
filter_is_active=self.is_active,
|
||
limit=self.limit,
|
||
skip=self.skip,
|
||
result_count=len(data),
|
||
)
|
||
return data
|
||
|
||
|
||
class AdminCreateProviderAdapter(AdminApiAdapter):
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
payload = context.ensure_json_body()
|
||
|
||
try:
|
||
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
|
||
validated_data = CreateProviderRequest.model_validate(payload)
|
||
except ValidationError as exc:
|
||
# 将 Pydantic 验证错误转换为友好的错误信息
|
||
errors = []
|
||
for error in exc.errors():
|
||
field = " -> ".join(str(x) for x in error["loc"])
|
||
errors.append(f"{field}: {error['msg']}")
|
||
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
|
||
|
||
try:
|
||
# 检查名称是否已存在
|
||
existing = db.query(Provider).filter(Provider.name == validated_data.name).first()
|
||
if existing:
|
||
raise InvalidRequestException(f"提供商名称 '{validated_data.name}' 已存在")
|
||
|
||
# 将验证后的数据转换为枚举类型
|
||
billing_type = (
|
||
ProviderBillingType(validated_data.billing_type)
|
||
if validated_data.billing_type
|
||
else ProviderBillingType.PAY_AS_YOU_GO
|
||
)
|
||
|
||
# 创建 Provider 对象
|
||
provider = Provider(
|
||
name=validated_data.name,
|
||
display_name=validated_data.display_name,
|
||
description=validated_data.description,
|
||
website=validated_data.website,
|
||
billing_type=billing_type,
|
||
monthly_quota_usd=validated_data.monthly_quota_usd,
|
||
quota_reset_day=validated_data.quota_reset_day,
|
||
quota_last_reset_at=validated_data.quota_last_reset_at,
|
||
quota_expires_at=validated_data.quota_expires_at,
|
||
rpm_limit=validated_data.rpm_limit,
|
||
provider_priority=validated_data.provider_priority,
|
||
is_active=validated_data.is_active,
|
||
concurrent_limit=validated_data.concurrent_limit,
|
||
config=validated_data.config,
|
||
)
|
||
|
||
db.add(provider)
|
||
db.commit()
|
||
db.refresh(provider)
|
||
|
||
context.add_audit_metadata(
|
||
action="create_provider",
|
||
provider_id=provider.id,
|
||
provider_name=provider.name,
|
||
billing_type=provider.billing_type.value if provider.billing_type else None,
|
||
is_active=provider.is_active,
|
||
provider_priority=provider.provider_priority,
|
||
)
|
||
|
||
return {
|
||
"id": provider.id,
|
||
"name": provider.name,
|
||
"display_name": provider.display_name,
|
||
"message": "提供商创建成功",
|
||
}
|
||
except InvalidRequestException:
|
||
db.rollback()
|
||
raise
|
||
except Exception:
|
||
db.rollback()
|
||
raise
|
||
|
||
|
||
class AdminUpdateProviderAdapter(AdminApiAdapter):
|
||
def __init__(self, provider_id: str):
|
||
self.provider_id = provider_id
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
payload = context.ensure_json_body()
|
||
|
||
# 查找 Provider
|
||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||
if not provider:
|
||
raise NotFoundException("提供商不存在", "provider")
|
||
|
||
try:
|
||
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
|
||
validated_data = UpdateProviderRequest.model_validate(payload)
|
||
except ValidationError as exc:
|
||
# 将 Pydantic 验证错误转换为友好的错误信息
|
||
errors = []
|
||
for error in exc.errors():
|
||
field = " -> ".join(str(x) for x in error["loc"])
|
||
errors.append(f"{field}: {error['msg']}")
|
||
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
|
||
|
||
try:
|
||
# 更新字段(只更新非 None 的字段)
|
||
update_data = validated_data.model_dump(exclude_unset=True)
|
||
|
||
for field, value in update_data.items():
|
||
if field == "billing_type" and value is not None:
|
||
# billing_type 需要转换为枚举
|
||
setattr(provider, field, ProviderBillingType(value))
|
||
else:
|
||
setattr(provider, field, value)
|
||
|
||
provider.updated_at = datetime.now(timezone.utc)
|
||
db.commit()
|
||
db.refresh(provider)
|
||
|
||
context.add_audit_metadata(
|
||
action="update_provider",
|
||
provider_id=provider.id,
|
||
changed_fields=list(update_data.keys()),
|
||
is_active=provider.is_active,
|
||
provider_priority=provider.provider_priority,
|
||
)
|
||
|
||
return {
|
||
"id": provider.id,
|
||
"name": provider.name,
|
||
"is_active": provider.is_active,
|
||
"message": "提供商更新成功",
|
||
}
|
||
except InvalidRequestException:
|
||
db.rollback()
|
||
raise
|
||
except Exception:
|
||
db.rollback()
|
||
raise
|
||
|
||
|
||
class AdminDeleteProviderAdapter(AdminApiAdapter):
|
||
def __init__(self, provider_id: str):
|
||
self.provider_id = provider_id
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||
if not provider:
|
||
raise NotFoundException("提供商不存在", "provider")
|
||
|
||
context.add_audit_metadata(
|
||
action="delete_provider",
|
||
provider_id=provider.id,
|
||
provider_name=provider.name,
|
||
)
|
||
db.delete(provider)
|
||
db.commit()
|
||
return {"message": "提供商已删除"}
|