Files
Aether/src/api/admin/providers/routes.py
fawney19 a12b43ce5c refactor: 清理数据库字段命名歧义
- users 表:重命名 allowed_endpoints 为 allowed_api_formats(修正历史命名错误)
- api_keys 表:删除 allowed_endpoints 字段(未使用的功能)
- providers 表:删除 rate_limit 字段(与 rpm_limit 重复)
- usage 表:重命名 provider 为 provider_name(避免与 provider_id 外键混淆)

同步更新前后端所有相关代码
2026-01-07 19:53:32 +08:00

339 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""管理员 Provider 管理路由。"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, Query, Request
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider
router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline()
@router.get("/")
async def list_providers(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
is_active: Optional[bool] = None,
db: Session = Depends(get_db),
):
"""
获取提供商列表
获取所有提供商的基本信息列表,支持分页和状态过滤。
**查询参数**:
- `skip`: 跳过的记录数,用于分页,默认为 0
- `limit`: 返回的最大记录数,范围 1-500默认为 100
- `is_active`: 可选的活跃状态过滤true 仅返回活跃提供商false 返回禁用提供商,不传则返回全部
**返回字段**:
- `id`: 提供商 ID
- `name`: 提供商名称(唯一标识)
- `display_name`: 显示名称
- `api_format`: API 格式(如 claude、openai、gemini 等)
- `base_url`: API 基础 URL
- `api_key`: API 密钥(脱敏显示)
- `priority`: 优先级
- `is_active`: 是否活跃
- `created_at`: 创建时间
- `updated_at`: 更新时间
"""
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/")
async def create_provider(request: Request, db: Session = Depends(get_db)):
"""
创建新提供商
创建一个新的 AI 模型提供商配置。
**请求体字段**:
- `name`: 提供商名称(必填,唯一,用于系统标识)
- `display_name`: 显示名称(必填)
- `description`: 描述信息(可选)
- `website`: 官网地址(可选)
- `billing_type`: 计费类型可选pay_as_you_go/subscription/prepaid默认 pay_as_you_go
- `monthly_quota_usd`: 月度配额(美元)(可选)
- `quota_reset_day`: 配额重置日期1-31可选
- `quota_last_reset_at`: 上次配额重置时间(可选)
- `quota_expires_at`: 配额过期时间(可选)
- `rpm_limit`: 每分钟请求数限制(可选)
- `provider_priority`: 提供商优先级(数字越小优先级越高,默认 100
- `is_active`: 是否启用(默认 true
- `concurrent_limit`: 并发限制(可选)
- `config`: 额外配置信息JSON可选
**返回字段**:
- `id`: 新创建的提供商 ID
- `name`: 提供商名称
- `display_name`: 显示名称
- `message`: 成功提示信息
"""
adapter = AdminCreateProviderAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{provider_id}")
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
"""
更新提供商配置
更新指定提供商的配置信息。只需传入需要更新的字段,未传入的字段保持不变。
**路径参数**:
- `provider_id`: 提供商 ID
**请求体字段**(所有字段可选):
- `name`: 提供商名称
- `display_name`: 显示名称
- `description`: 描述信息
- `website`: 官网地址
- `billing_type`: 计费类型pay_as_you_go/subscription/prepaid
- `monthly_quota_usd`: 月度配额(美元)
- `quota_reset_day`: 配额重置日期1-31
- `quota_last_reset_at`: 上次配额重置时间
- `quota_expires_at`: 配额过期时间
- `rpm_limit`: 每分钟请求数限制
- `provider_priority`: 提供商优先级
- `is_active`: 是否启用
- `concurrent_limit`: 并发限制
- `config`: 额外配置信息JSON
**返回字段**:
- `id`: 提供商 ID
- `name`: 提供商名称
- `is_active`: 是否启用
- `message`: 成功提示信息
"""
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
"""
删除提供商
删除指定的提供商。注意:此操作会级联删除关联的端点、密钥和模型配置。
**路径参数**:
- `provider_id`: 提供商 ID
**返回字段**:
- `message`: 删除成功提示信息
"""
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminListProvidersAdapter(AdminApiAdapter):
def __init__(self, skip: int, limit: int, is_active: Optional[bool]):
self.skip = skip
self.limit = limit
self.is_active = is_active
async def handle(self, context): # type: ignore[override]
db = context.db
query = db.query(Provider)
if self.is_active is not None:
query = query.filter(Provider.is_active == self.is_active)
providers = query.offset(self.skip).limit(self.limit).all()
data = []
for provider in providers:
api_format = getattr(provider, "api_format", None)
base_url = getattr(provider, "base_url", None)
api_key = getattr(provider, "api_key", None)
priority = getattr(provider, "priority", provider.provider_priority)
data.append(
{
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"api_format": api_format.value if api_format else None,
"base_url": base_url,
"api_key": "***" if api_key else None,
"priority": priority,
"is_active": provider.is_active,
"created_at": provider.created_at.isoformat(),
"updated_at": provider.updated_at.isoformat() if provider.updated_at else None,
}
)
context.add_audit_metadata(
action="list_providers",
filter_is_active=self.is_active,
limit=self.limit,
skip=self.skip,
result_count=len(data),
)
return data
class AdminCreateProviderAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
validated_data = CreateProviderRequest.model_validate(payload)
except ValidationError as exc:
# 将 Pydantic 验证错误转换为友好的错误信息
errors = []
for error in exc.errors():
field = " -> ".join(str(x) for x in error["loc"])
errors.append(f"{field}: {error['msg']}")
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
try:
# 检查名称是否已存在
existing = db.query(Provider).filter(Provider.name == validated_data.name).first()
if existing:
raise InvalidRequestException(f"提供商名称 '{validated_data.name}' 已存在")
# 将验证后的数据转换为枚举类型
billing_type = (
ProviderBillingType(validated_data.billing_type)
if validated_data.billing_type
else ProviderBillingType.PAY_AS_YOU_GO
)
# 创建 Provider 对象
provider = Provider(
name=validated_data.name,
display_name=validated_data.display_name,
description=validated_data.description,
website=validated_data.website,
billing_type=billing_type,
monthly_quota_usd=validated_data.monthly_quota_usd,
quota_reset_day=validated_data.quota_reset_day,
quota_last_reset_at=validated_data.quota_last_reset_at,
quota_expires_at=validated_data.quota_expires_at,
rpm_limit=validated_data.rpm_limit,
provider_priority=validated_data.provider_priority,
is_active=validated_data.is_active,
concurrent_limit=validated_data.concurrent_limit,
config=validated_data.config,
)
db.add(provider)
db.commit()
db.refresh(provider)
context.add_audit_metadata(
action="create_provider",
provider_id=provider.id,
provider_name=provider.name,
billing_type=provider.billing_type.value if provider.billing_type else None,
is_active=provider.is_active,
provider_priority=provider.provider_priority,
)
return {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"message": "提供商创建成功",
}
except InvalidRequestException:
db.rollback()
raise
except Exception:
db.rollback()
raise
class AdminUpdateProviderAdapter(AdminApiAdapter):
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
# 查找 Provider
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("提供商不存在", "provider")
try:
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
validated_data = UpdateProviderRequest.model_validate(payload)
except ValidationError as exc:
# 将 Pydantic 验证错误转换为友好的错误信息
errors = []
for error in exc.errors():
field = " -> ".join(str(x) for x in error["loc"])
errors.append(f"{field}: {error['msg']}")
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
try:
# 更新字段(只更新非 None 的字段)
update_data = validated_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "billing_type" and value is not None:
# billing_type 需要转换为枚举
setattr(provider, field, ProviderBillingType(value))
else:
setattr(provider, field, value)
provider.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(provider)
context.add_audit_metadata(
action="update_provider",
provider_id=provider.id,
changed_fields=list(update_data.keys()),
is_active=provider.is_active,
provider_priority=provider.provider_priority,
)
return {
"id": provider.id,
"name": provider.name,
"is_active": provider.is_active,
"message": "提供商更新成功",
}
except InvalidRequestException:
db.rollback()
raise
except Exception:
db.rollback()
raise
class AdminDeleteProviderAdapter(AdminApiAdapter):
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("提供商不存在", "provider")
context.add_audit_metadata(
action="delete_provider",
provider_id=provider.id,
provider_name=provider.name,
)
db.delete(provider)
db.commit()
return {"message": "提供商已删除"}