Files
Aether/src/api/admin/providers/routes.py

339 lines
13 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""管理员 Provider 管理路由。"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, Query, Request
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider
router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline()
@router.get("/")
async def list_providers(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
is_active: Optional[bool] = None,
db: Session = Depends(get_db),
):
"""
获取提供商列表
获取所有提供商的基本信息列表支持分页和状态过滤
**查询参数**:
- `skip`: 跳过的记录数用于分页默认为 0
- `limit`: 返回的最大记录数范围 1-500默认为 100
- `is_active`: 可选的活跃状态过滤true 仅返回活跃提供商false 返回禁用提供商不传则返回全部
**返回字段**:
- `id`: 提供商 ID
- `name`: 提供商名称唯一标识
- `display_name`: 显示名称
- `api_format`: API 格式 claudeopenaigemini
- `base_url`: API 基础 URL
- `api_key`: API 密钥脱敏显示
- `priority`: 优先级
- `is_active`: 是否活跃
- `created_at`: 创建时间
- `updated_at`: 更新时间
"""
2025-12-10 20:52:44 +08:00
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/")
async def create_provider(request: Request, db: Session = Depends(get_db)):
"""
创建新提供商
创建一个新的 AI 模型提供商配置
**请求体字段**:
- `name`: 提供商名称必填唯一用于系统标识
- `display_name`: 显示名称必填
- `description`: 描述信息可选
- `website`: 官网地址可选
- `billing_type`: 计费类型可选pay_as_you_go/subscription/prepaid默认 pay_as_you_go
- `monthly_quota_usd`: 月度配额美元可选
- `quota_reset_day`: 配额重置日期1-31可选
- `quota_last_reset_at`: 上次配额重置时间可选
- `quota_expires_at`: 配额过期时间可选
- `rpm_limit`: 每分钟请求数限制可选
- `provider_priority`: 提供商优先级数字越小优先级越高默认 100
- `is_active`: 是否启用默认 true
- `concurrent_limit`: 并发限制可选
- `config`: 额外配置信息JSON可选
**返回字段**:
- `id`: 新创建的提供商 ID
- `name`: 提供商名称
- `display_name`: 显示名称
- `message`: 成功提示信息
"""
2025-12-10 20:52:44 +08:00
adapter = AdminCreateProviderAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{provider_id}")
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
"""
更新提供商配置
更新指定提供商的配置信息只需传入需要更新的字段未传入的字段保持不变
**路径参数**:
- `provider_id`: 提供商 ID
**请求体字段**所有字段可选:
- `name`: 提供商名称
- `display_name`: 显示名称
- `description`: 描述信息
- `website`: 官网地址
- `billing_type`: 计费类型pay_as_you_go/subscription/prepaid
- `monthly_quota_usd`: 月度配额美元
- `quota_reset_day`: 配额重置日期1-31
- `quota_last_reset_at`: 上次配额重置时间
- `quota_expires_at`: 配额过期时间
- `rpm_limit`: 每分钟请求数限制
- `provider_priority`: 提供商优先级
- `is_active`: 是否启用
- `concurrent_limit`: 并发限制
- `config`: 额外配置信息JSON
**返回字段**:
- `id`: 提供商 ID
- `name`: 提供商名称
- `is_active`: 是否启用
- `message`: 成功提示信息
"""
2025-12-10 20:52:44 +08:00
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
"""
删除提供商
删除指定的提供商注意此操作会级联删除关联的端点密钥和模型配置
**路径参数**:
- `provider_id`: 提供商 ID
**返回字段**:
- `message`: 删除成功提示信息
"""
2025-12-10 20:52:44 +08:00
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminListProvidersAdapter(AdminApiAdapter):
def __init__(self, skip: int, limit: int, is_active: Optional[bool]):
self.skip = skip
self.limit = limit
self.is_active = is_active
async def handle(self, context): # type: ignore[override]
db = context.db
query = db.query(Provider)
if self.is_active is not None:
query = query.filter(Provider.is_active == self.is_active)
providers = query.offset(self.skip).limit(self.limit).all()
data = []
for provider in providers:
api_format = getattr(provider, "api_format", None)
base_url = getattr(provider, "base_url", None)
api_key = getattr(provider, "api_key", None)
priority = getattr(provider, "priority", provider.provider_priority)
data.append(
{
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"api_format": api_format.value if api_format else None,
"base_url": base_url,
"api_key": "***" if api_key else None,
"priority": priority,
"is_active": provider.is_active,
"created_at": provider.created_at.isoformat(),
"updated_at": provider.updated_at.isoformat() if provider.updated_at else None,
}
)
context.add_audit_metadata(
action="list_providers",
filter_is_active=self.is_active,
limit=self.limit,
skip=self.skip,
result_count=len(data),
)
return data
class AdminCreateProviderAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
validated_data = CreateProviderRequest.model_validate(payload)
except ValidationError as exc:
# 将 Pydantic 验证错误转换为友好的错误信息
errors = []
for error in exc.errors():
field = " -> ".join(str(x) for x in error["loc"])
errors.append(f"{field}: {error['msg']}")
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
try:
# 检查名称是否已存在
existing = db.query(Provider).filter(Provider.name == validated_data.name).first()
if existing:
raise InvalidRequestException(f"提供商名称 '{validated_data.name}' 已存在")
# 将验证后的数据转换为枚举类型
billing_type = (
ProviderBillingType(validated_data.billing_type)
if validated_data.billing_type
else ProviderBillingType.PAY_AS_YOU_GO
)
# 创建 Provider 对象
provider = Provider(
name=validated_data.name,
display_name=validated_data.display_name,
description=validated_data.description,
website=validated_data.website,
billing_type=billing_type,
monthly_quota_usd=validated_data.monthly_quota_usd,
quota_reset_day=validated_data.quota_reset_day,
quota_last_reset_at=validated_data.quota_last_reset_at,
quota_expires_at=validated_data.quota_expires_at,
rpm_limit=validated_data.rpm_limit,
provider_priority=validated_data.provider_priority,
is_active=validated_data.is_active,
concurrent_limit=validated_data.concurrent_limit,
config=validated_data.config,
)
db.add(provider)
db.commit()
db.refresh(provider)
context.add_audit_metadata(
action="create_provider",
provider_id=provider.id,
provider_name=provider.name,
billing_type=provider.billing_type.value if provider.billing_type else None,
is_active=provider.is_active,
provider_priority=provider.provider_priority,
)
return {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"message": "提供商创建成功",
}
except InvalidRequestException:
db.rollback()
raise
except Exception:
db.rollback()
raise
class AdminUpdateProviderAdapter(AdminApiAdapter):
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context): # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
# 查找 Provider
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("提供商不存在", "provider")
try:
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
validated_data = UpdateProviderRequest.model_validate(payload)
except ValidationError as exc:
# 将 Pydantic 验证错误转换为友好的错误信息
errors = []
for error in exc.errors():
field = " -> ".join(str(x) for x in error["loc"])
errors.append(f"{field}: {error['msg']}")
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
try:
# 更新字段(只更新非 None 的字段)
update_data = validated_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "billing_type" and value is not None:
# billing_type 需要转换为枚举
setattr(provider, field, ProviderBillingType(value))
else:
setattr(provider, field, value)
provider.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(provider)
context.add_audit_metadata(
action="update_provider",
provider_id=provider.id,
changed_fields=list(update_data.keys()),
is_active=provider.is_active,
provider_priority=provider.provider_priority,
)
return {
"id": provider.id,
"name": provider.name,
"is_active": provider.is_active,
"message": "提供商更新成功",
}
except InvalidRequestException:
db.rollback()
raise
except Exception:
db.rollback()
raise
class AdminDeleteProviderAdapter(AdminApiAdapter):
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("提供商不存在", "provider")
context.add_audit_metadata(
action="delete_provider",
provider_id=provider.id,
provider_name=provider.name,
)
db.delete(provider)
db.commit()
return {"message": "提供商已删除"}