mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-11 03:58:28 +08:00
Initial commit
This commit is contained in:
249
src/api/admin/providers/routes.py
Normal file
249
src/api/admin/providers/routes.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""管理员 Provider 管理路由。"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
|
||||
from src.models.database import Provider
|
||||
|
||||
router = APIRouter(tags=["Provider CRUD"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_providers(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
is_active: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_provider(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminCreateProviderAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{provider_id}")
|
||||
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{provider_id}")
|
||||
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AdminListProvidersAdapter(AdminApiAdapter):
|
||||
def __init__(self, skip: int, limit: int, is_active: Optional[bool]):
|
||||
self.skip = skip
|
||||
self.limit = limit
|
||||
self.is_active = is_active
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
query = db.query(Provider)
|
||||
if self.is_active is not None:
|
||||
query = query.filter(Provider.is_active == self.is_active)
|
||||
providers = query.offset(self.skip).limit(self.limit).all()
|
||||
|
||||
data = []
|
||||
for provider in providers:
|
||||
api_format = getattr(provider, "api_format", None)
|
||||
base_url = getattr(provider, "base_url", None)
|
||||
api_key = getattr(provider, "api_key", None)
|
||||
priority = getattr(provider, "priority", provider.provider_priority)
|
||||
|
||||
data.append(
|
||||
{
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"api_format": api_format.value if api_format else None,
|
||||
"base_url": base_url,
|
||||
"api_key": "***" if api_key else None,
|
||||
"priority": priority,
|
||||
"is_active": provider.is_active,
|
||||
"created_at": provider.created_at.isoformat(),
|
||||
"updated_at": provider.updated_at.isoformat() if provider.updated_at else None,
|
||||
}
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="list_providers",
|
||||
filter_is_active=self.is_active,
|
||||
limit=self.limit,
|
||||
skip=self.skip,
|
||||
result_count=len(data),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class AdminCreateProviderAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
|
||||
validated_data = CreateProviderRequest.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
# 将 Pydantic 验证错误转换为友好的错误信息
|
||||
errors = []
|
||||
for error in exc.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
errors.append(f"{field}: {error['msg']}")
|
||||
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
|
||||
|
||||
try:
|
||||
# 检查名称是否已存在
|
||||
existing = db.query(Provider).filter(Provider.name == validated_data.name).first()
|
||||
if existing:
|
||||
raise InvalidRequestException(f"提供商名称 '{validated_data.name}' 已存在")
|
||||
|
||||
# 将验证后的数据转换为枚举类型
|
||||
billing_type = (
|
||||
ProviderBillingType(validated_data.billing_type)
|
||||
if validated_data.billing_type
|
||||
else ProviderBillingType.PAY_AS_YOU_GO
|
||||
)
|
||||
|
||||
# 创建 Provider 对象
|
||||
provider = Provider(
|
||||
name=validated_data.name,
|
||||
display_name=validated_data.display_name,
|
||||
description=validated_data.description,
|
||||
website=validated_data.website,
|
||||
billing_type=billing_type,
|
||||
monthly_quota_usd=validated_data.monthly_quota_usd,
|
||||
quota_reset_day=validated_data.quota_reset_day,
|
||||
quota_last_reset_at=validated_data.quota_last_reset_at,
|
||||
quota_expires_at=validated_data.quota_expires_at,
|
||||
rpm_limit=validated_data.rpm_limit,
|
||||
provider_priority=validated_data.provider_priority,
|
||||
is_active=validated_data.is_active,
|
||||
rate_limit=validated_data.rate_limit,
|
||||
concurrent_limit=validated_data.concurrent_limit,
|
||||
config=validated_data.config,
|
||||
)
|
||||
|
||||
db.add(provider)
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="create_provider",
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.name,
|
||||
billing_type=provider.billing_type.value if provider.billing_type else None,
|
||||
is_active=provider.is_active,
|
||||
provider_priority=provider.provider_priority,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"message": "提供商创建成功",
|
||||
}
|
||||
except InvalidRequestException:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
class AdminUpdateProviderAdapter(AdminApiAdapter):
|
||||
def __init__(self, provider_id: str):
|
||||
self.provider_id = provider_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
# 查找 Provider
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("提供商不存在", "provider")
|
||||
|
||||
try:
|
||||
# 使用 Pydantic 模型进行验证(自动进行 SQL 注入、XSS、SSRF 检测)
|
||||
validated_data = UpdateProviderRequest.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
# 将 Pydantic 验证错误转换为友好的错误信息
|
||||
errors = []
|
||||
for error in exc.errors():
|
||||
field = " -> ".join(str(x) for x in error["loc"])
|
||||
errors.append(f"{field}: {error['msg']}")
|
||||
raise InvalidRequestException("输入验证失败: " + "; ".join(errors))
|
||||
|
||||
try:
|
||||
# 更新字段(只更新非 None 的字段)
|
||||
update_data = validated_data.model_dump(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field == "billing_type" and value is not None:
|
||||
# billing_type 需要转换为枚举
|
||||
setattr(provider, field, ProviderBillingType(value))
|
||||
else:
|
||||
setattr(provider, field, value)
|
||||
|
||||
provider.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="update_provider",
|
||||
provider_id=provider.id,
|
||||
changed_fields=list(update_data.keys()),
|
||||
is_active=provider.is_active,
|
||||
provider_priority=provider.provider_priority,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"is_active": provider.is_active,
|
||||
"message": "提供商更新成功",
|
||||
}
|
||||
except InvalidRequestException:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
class AdminDeleteProviderAdapter(AdminApiAdapter):
|
||||
def __init__(self, provider_id: str):
|
||||
self.provider_id = provider_id
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("提供商不存在", "provider")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="delete_provider",
|
||||
provider_id=provider.id,
|
||||
provider_name=provider.name,
|
||||
)
|
||||
db.delete(provider)
|
||||
db.commit()
|
||||
return {"message": "提供商已删除"}
|
||||
Reference in New Issue
Block a user