mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
346 lines
12 KiB
Python
346 lines
12 KiB
Python
"""
|
||
ProviderEndpoint CRUD 管理 API
|
||
"""
|
||
|
||
import uuid
|
||
from dataclasses import dataclass
|
||
from datetime import datetime, timezone
|
||
from typing import List
|
||
|
||
from fastapi import APIRouter, Depends, Query, Request
|
||
from sqlalchemy import and_, func
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.api.base.admin_adapter import AdminApiAdapter
|
||
from src.api.base.pipeline import ApiRequestPipeline
|
||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||
from src.core.logger import logger
|
||
from src.database import get_db
|
||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||
from src.models.endpoint_models import (
|
||
ProviderEndpointCreate,
|
||
ProviderEndpointResponse,
|
||
ProviderEndpointUpdate,
|
||
)
|
||
|
||
router = APIRouter(tags=["Endpoint Management"])
|
||
pipeline = ApiRequestPipeline()
|
||
|
||
|
||
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
|
||
async def list_provider_endpoints(
|
||
provider_id: str,
|
||
request: Request,
|
||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||
db: Session = Depends(get_db),
|
||
) -> List[ProviderEndpointResponse]:
|
||
"""获取指定 Provider 的所有 Endpoints"""
|
||
adapter = AdminListProviderEndpointsAdapter(
|
||
provider_id=provider_id,
|
||
skip=skip,
|
||
limit=limit,
|
||
)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.post("/providers/{provider_id}/endpoints", response_model=ProviderEndpointResponse)
|
||
async def create_provider_endpoint(
|
||
provider_id: str,
|
||
endpoint_data: ProviderEndpointCreate,
|
||
request: Request,
|
||
db: Session = Depends(get_db),
|
||
) -> ProviderEndpointResponse:
|
||
"""为 Provider 创建新的 Endpoint"""
|
||
adapter = AdminCreateProviderEndpointAdapter(
|
||
provider_id=provider_id,
|
||
endpoint_data=endpoint_data,
|
||
)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.get("/{endpoint_id}", response_model=ProviderEndpointResponse)
|
||
async def get_endpoint(
|
||
endpoint_id: str,
|
||
request: Request,
|
||
db: Session = Depends(get_db),
|
||
) -> ProviderEndpointResponse:
|
||
"""获取 Endpoint 详情"""
|
||
adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.put("/{endpoint_id}", response_model=ProviderEndpointResponse)
|
||
async def update_endpoint(
|
||
endpoint_id: str,
|
||
endpoint_data: ProviderEndpointUpdate,
|
||
request: Request,
|
||
db: Session = Depends(get_db),
|
||
) -> ProviderEndpointResponse:
|
||
"""更新 Endpoint"""
|
||
adapter = AdminUpdateProviderEndpointAdapter(
|
||
endpoint_id=endpoint_id,
|
||
endpoint_data=endpoint_data,
|
||
)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
@router.delete("/{endpoint_id}")
|
||
async def delete_endpoint(
|
||
endpoint_id: str,
|
||
request: Request,
|
||
db: Session = Depends(get_db),
|
||
) -> dict:
|
||
"""删除 Endpoint(级联删除所有关联的 Keys)"""
|
||
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||
|
||
|
||
# -------- Adapters --------
|
||
|
||
|
||
@dataclass
|
||
class AdminListProviderEndpointsAdapter(AdminApiAdapter):
|
||
provider_id: str
|
||
skip: int
|
||
limit: int
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||
if not provider:
|
||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||
|
||
endpoints = (
|
||
db.query(ProviderEndpoint)
|
||
.filter(ProviderEndpoint.provider_id == self.provider_id)
|
||
.order_by(ProviderEndpoint.created_at.desc())
|
||
.offset(self.skip)
|
||
.limit(self.limit)
|
||
.all()
|
||
)
|
||
|
||
endpoint_ids = [ep.id for ep in endpoints]
|
||
total_keys_map = {}
|
||
active_keys_map = {}
|
||
if endpoint_ids:
|
||
total_rows = (
|
||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
|
||
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
|
||
.group_by(ProviderAPIKey.endpoint_id)
|
||
.all()
|
||
)
|
||
total_keys_map = {row.endpoint_id: row.total for row in total_rows}
|
||
|
||
active_rows = (
|
||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
|
||
.filter(
|
||
and_(
|
||
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
|
||
ProviderAPIKey.is_active.is_(True),
|
||
)
|
||
)
|
||
.group_by(ProviderAPIKey.endpoint_id)
|
||
.all()
|
||
)
|
||
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
|
||
|
||
result: List[ProviderEndpointResponse] = []
|
||
for endpoint in endpoints:
|
||
endpoint_dict = {
|
||
**endpoint.__dict__,
|
||
"provider_name": provider.name,
|
||
"api_format": endpoint.api_format,
|
||
"total_keys": total_keys_map.get(endpoint.id, 0),
|
||
"active_keys": active_keys_map.get(endpoint.id, 0),
|
||
}
|
||
endpoint_dict.pop("_sa_instance_state", None)
|
||
result.append(ProviderEndpointResponse(**endpoint_dict))
|
||
|
||
return result
|
||
|
||
|
||
@dataclass
|
||
class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
||
provider_id: str
|
||
endpoint_data: ProviderEndpointCreate
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||
if not provider:
|
||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||
|
||
if self.endpoint_data.provider_id != self.provider_id:
|
||
raise InvalidRequestException("provider_id 不匹配")
|
||
|
||
existing = (
|
||
db.query(ProviderEndpoint)
|
||
.filter(
|
||
and_(
|
||
ProviderEndpoint.provider_id == self.provider_id,
|
||
ProviderEndpoint.api_format == self.endpoint_data.api_format,
|
||
)
|
||
)
|
||
.first()
|
||
)
|
||
if existing:
|
||
raise InvalidRequestException(
|
||
f"Provider {provider.name} 已存在 {self.endpoint_data.api_format} 格式的 Endpoint"
|
||
)
|
||
|
||
now = datetime.now(timezone.utc)
|
||
new_endpoint = ProviderEndpoint(
|
||
id=str(uuid.uuid4()),
|
||
provider_id=self.provider_id,
|
||
api_format=self.endpoint_data.api_format,
|
||
base_url=self.endpoint_data.base_url,
|
||
headers=self.endpoint_data.headers,
|
||
timeout=self.endpoint_data.timeout,
|
||
max_retries=self.endpoint_data.max_retries,
|
||
max_concurrent=self.endpoint_data.max_concurrent,
|
||
rate_limit=self.endpoint_data.rate_limit,
|
||
is_active=True,
|
||
config=self.endpoint_data.config,
|
||
created_at=now,
|
||
updated_at=now,
|
||
)
|
||
|
||
db.add(new_endpoint)
|
||
db.commit()
|
||
db.refresh(new_endpoint)
|
||
|
||
logger.info(f"[OK] 创建 Endpoint: Provider={provider.name}, Format={self.endpoint_data.api_format}, ID={new_endpoint.id}")
|
||
|
||
endpoint_dict = {
|
||
k: v
|
||
for k, v in new_endpoint.__dict__.items()
|
||
if k not in {"api_format", "_sa_instance_state"}
|
||
}
|
||
return ProviderEndpointResponse(
|
||
**endpoint_dict,
|
||
provider_name=provider.name,
|
||
api_format=new_endpoint.api_format,
|
||
total_keys=0,
|
||
active_keys=0,
|
||
)
|
||
|
||
|
||
@dataclass
|
||
class AdminGetProviderEndpointAdapter(AdminApiAdapter):
|
||
endpoint_id: str
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
endpoint = (
|
||
db.query(ProviderEndpoint, Provider)
|
||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||
.filter(ProviderEndpoint.id == self.endpoint_id)
|
||
.first()
|
||
)
|
||
if not endpoint:
|
||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||
|
||
endpoint_obj, provider = endpoint
|
||
total_keys = (
|
||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||
)
|
||
active_keys = (
|
||
db.query(ProviderAPIKey)
|
||
.filter(
|
||
and_(
|
||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||
ProviderAPIKey.is_active.is_(True),
|
||
)
|
||
)
|
||
.count()
|
||
)
|
||
|
||
endpoint_dict = {
|
||
k: v
|
||
for k, v in endpoint_obj.__dict__.items()
|
||
if k not in {"api_format", "_sa_instance_state"}
|
||
}
|
||
return ProviderEndpointResponse(
|
||
**endpoint_dict,
|
||
provider_name=provider.name,
|
||
api_format=endpoint_obj.api_format,
|
||
total_keys=total_keys,
|
||
active_keys=active_keys,
|
||
)
|
||
|
||
|
||
@dataclass
|
||
class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
||
endpoint_id: str
|
||
endpoint_data: ProviderEndpointUpdate
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
endpoint = (
|
||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||
)
|
||
if not endpoint:
|
||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||
|
||
update_data = self.endpoint_data.model_dump(exclude_unset=True)
|
||
for field, value in update_data.items():
|
||
setattr(endpoint, field, value)
|
||
endpoint.updated_at = datetime.now(timezone.utc)
|
||
|
||
db.commit()
|
||
db.refresh(endpoint)
|
||
|
||
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
|
||
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
|
||
|
||
total_keys = (
|
||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||
)
|
||
active_keys = (
|
||
db.query(ProviderAPIKey)
|
||
.filter(
|
||
and_(
|
||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||
ProviderAPIKey.is_active.is_(True),
|
||
)
|
||
)
|
||
.count()
|
||
)
|
||
|
||
endpoint_dict = {
|
||
k: v
|
||
for k, v in endpoint.__dict__.items()
|
||
if k not in {"api_format", "_sa_instance_state"}
|
||
}
|
||
return ProviderEndpointResponse(
|
||
**endpoint_dict,
|
||
provider_name=provider.name if provider else "Unknown",
|
||
api_format=endpoint.api_format,
|
||
total_keys=total_keys,
|
||
active_keys=active_keys,
|
||
)
|
||
|
||
|
||
@dataclass
|
||
class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
|
||
endpoint_id: str
|
||
|
||
async def handle(self, context): # type: ignore[override]
|
||
db = context.db
|
||
endpoint = (
|
||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||
)
|
||
if not endpoint:
|
||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||
|
||
keys_count = (
|
||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||
)
|
||
db.delete(endpoint)
|
||
db.commit()
|
||
|
||
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
|
||
|
||
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
|