2025-12-10 20:52:44 +08:00
|
|
|
|
"""
|
|
|
|
|
|
ProviderEndpoint CRUD 管理 API
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, Depends, Query, Request
|
|
|
|
|
|
from sqlalchemy import and_, func
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
|
|
from src.api.base.admin_adapter import AdminApiAdapter
|
|
|
|
|
|
from src.api.base.pipeline import ApiRequestPipeline
|
|
|
|
|
|
from src.core.exceptions import InvalidRequestException, NotFoundException
|
|
|
|
|
|
from src.core.logger import logger
|
|
|
|
|
|
from src.database import get_db
|
|
|
|
|
|
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
|
|
|
|
|
from src.models.endpoint_models import (
|
|
|
|
|
|
ProviderEndpointCreate,
|
|
|
|
|
|
ProviderEndpointResponse,
|
|
|
|
|
|
ProviderEndpointUpdate,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Router on which the endpoint-management routes of this module are registered.
router = APIRouter(tags=["Endpoint Management"])
|
|
|
|
|
|
# Module-level pipeline instance; every route handler in this module runs its adapter through it.
pipeline = ApiRequestPipeline()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
async def list_provider_endpoints(
    provider_id: str,
    request: Request,
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
    db: Session = Depends(get_db),
) -> List[ProviderEndpointResponse]:
    """List the Endpoints that belong to the given Provider.

    Args:
        provider_id: ID of the Provider, taken from the URL path.
        request: Incoming HTTP request, forwarded to the pipeline.
        skip: Number of records to skip (>= 0).
        limit: Maximum number of records to return (1..1000).
        db: Database session injected through ``get_db``.

    Returns:
        Whatever the pipeline produces for ``AdminListProviderEndpointsAdapter``.
    """
    list_adapter = AdminListProviderEndpointsAdapter(provider_id=provider_id, skip=skip, limit=limit)
    response = await pipeline.run(
        adapter=list_adapter,
        http_request=request,
        db=db,
        mode=list_adapter.mode,
    )
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/providers/{provider_id}/endpoints", response_model=ProviderEndpointResponse)
async def create_provider_endpoint(
    provider_id: str,
    endpoint_data: ProviderEndpointCreate,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Create a new Endpoint for a Provider.

    Args:
        provider_id: ID of the Provider, taken from the URL path.
        endpoint_data: Request body describing the Endpoint to create.
        request: Incoming HTTP request, forwarded to the pipeline.
        db: Database session injected through ``get_db``.

    Returns:
        Whatever the pipeline produces for ``AdminCreateProviderEndpointAdapter``.
    """
    create_adapter = AdminCreateProviderEndpointAdapter(
        provider_id=provider_id, endpoint_data=endpoint_data
    )
    response = await pipeline.run(
        adapter=create_adapter,
        http_request=request,
        db=db,
        mode=create_adapter.mode,
    )
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def get_endpoint(
    endpoint_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Fetch the details of a single Endpoint.

    Args:
        endpoint_id: ID of the Endpoint, taken from the URL path.
        request: Incoming HTTP request, forwarded to the pipeline.
        db: Database session injected through ``get_db``.

    Returns:
        Whatever the pipeline produces for ``AdminGetProviderEndpointAdapter``.
    """
    get_adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
    response = await pipeline.run(
        adapter=get_adapter,
        http_request=request,
        db=db,
        mode=get_adapter.mode,
    )
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def update_endpoint(
    endpoint_id: str,
    endpoint_data: ProviderEndpointUpdate,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Update an existing Endpoint.

    Args:
        endpoint_id: ID of the Endpoint, taken from the URL path.
        endpoint_data: Request body carrying the fields to change.
        request: Incoming HTTP request, forwarded to the pipeline.
        db: Database session injected through ``get_db``.

    Returns:
        Whatever the pipeline produces for ``AdminUpdateProviderEndpointAdapter``.
    """
    update_adapter = AdminUpdateProviderEndpointAdapter(
        endpoint_id=endpoint_id, endpoint_data=endpoint_data
    )
    response = await pipeline.run(
        adapter=update_adapter,
        http_request=request,
        db=db,
        mode=update_adapter.mode,
    )
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/{endpoint_id}")
async def delete_endpoint(
    endpoint_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """Delete an Endpoint.

    The original note states that all Keys associated with the Endpoint are
    cascade-deleted; the cascade itself is not configured in this function
    (presumably on the ORM model or schema -- confirm).

    Args:
        endpoint_id: ID of the Endpoint, taken from the URL path.
        request: Incoming HTTP request, forwarded to the pipeline.
        db: Database session injected through ``get_db``.

    Returns:
        Whatever the pipeline produces for ``AdminDeleteProviderEndpointAdapter``.
    """
    delete_adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
    response = await pipeline.run(
        adapter=delete_adapter,
        http_request=request,
        db=db,
        mode=delete_adapter.mode,
    )
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -------- Adapters --------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AdminListProviderEndpointsAdapter(AdminApiAdapter):
    """Admin adapter returning one page of a Provider's Endpoints.

    Every item is enriched with the Provider name and with the total and
    active API-key counts of that Endpoint.
    """

    # ID of the Provider whose Endpoints are listed.
    provider_id: str
    # Pagination offset (number of rows skipped).
    skip: int
    # Maximum number of rows returned.
    limit: int

    async def handle(self, context):  # type: ignore[override]
        """Load the requested page and build the response models.

        Args:
            context: Request context; only its ``db`` session is used here.

        Returns:
            A list of ``ProviderEndpointResponse``, newest Endpoint first.

        Raises:
            NotFoundException: If no Provider with ``provider_id`` exists.
        """
        session = context.db

        owner = session.query(Provider).filter(Provider.id == self.provider_id).first()
        if not owner:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")

        page_query = (
            session.query(ProviderEndpoint)
            .filter(ProviderEndpoint.provider_id == self.provider_id)
            .order_by(ProviderEndpoint.created_at.desc())
            .offset(self.skip)
            .limit(self.limit)
        )
        page = page_query.all()
        if not page:
            # Nothing on this page, so the key-count queries are unnecessary.
            return []

        page_ids = [item.id for item in page]

        # Total number of keys per Endpoint, grouped in a single query.
        total_rows = (
            session.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
            .filter(ProviderAPIKey.endpoint_id.in_(page_ids))
            .group_by(ProviderAPIKey.endpoint_id)
            .all()
        )
        totals = {row.endpoint_id: row.total for row in total_rows}

        # Same grouping, restricted to keys flagged as active.
        active_rows = (
            session.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
            .filter(
                ProviderAPIKey.endpoint_id.in_(page_ids),
                ProviderAPIKey.is_active.is_(True),
            )
            .group_by(ProviderAPIKey.endpoint_id)
            .all()
        )
        actives = {row.endpoint_id: row.active for row in active_rows}

        def to_response(item):
            # Start from the instance attributes, minus SQLAlchemy's bookkeeping entry.
            fields = dict(vars(item))
            fields.pop("_sa_instance_state", None)
            fields.update(
                provider_name=owner.name,
                api_format=item.api_format,
                total_keys=totals.get(item.id, 0),
                active_keys=actives.get(item.id, 0),
            )
            return ProviderEndpointResponse(**fields)

        return [to_response(item) for item in page]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
    """Admin adapter that creates a new Endpoint under a Provider."""

    # ID of the Provider taken from the URL path.
    provider_id: str
    # Validated request body describing the new Endpoint.
    endpoint_data: ProviderEndpointCreate

    async def handle(self, context):  # type: ignore[override]
        """Validate the request, persist the Endpoint and return it.

        Args:
            context: Request context; only its ``db`` session is used here.

        Returns:
            A ``ProviderEndpointResponse`` for the new Endpoint, with both
            key counters set to 0.

        Raises:
            NotFoundException: If the Provider does not exist.
            InvalidRequestException: If the body's ``provider_id`` differs
                from the path value, or the Provider already has an Endpoint
                with the same ``api_format``.
        """
        session = context.db
        data = self.endpoint_data

        provider = session.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")

        if data.provider_id != self.provider_id:
            raise InvalidRequestException("provider_id 不匹配")

        # This code allows only one Endpoint per (provider, api_format) pair.
        duplicate = (
            session.query(ProviderEndpoint)
            .filter(ProviderEndpoint.provider_id == self.provider_id)
            .filter(ProviderEndpoint.api_format == data.api_format)
            .first()
        )
        if duplicate:
            raise InvalidRequestException(
                f"Provider {provider.name} 已存在 {self.endpoint_data.api_format} 格式的 Endpoint"
            )

        now = datetime.now(timezone.utc)

        # The proxy setting is stored as a plain dict (or None when absent).
        proxy_config = None
        if data.proxy:
            proxy_config = data.proxy.model_dump()

        new_endpoint = ProviderEndpoint(
            id=str(uuid.uuid4()),
            provider_id=self.provider_id,
            api_format=data.api_format,
            base_url=data.base_url,
            headers=data.headers,
            timeout=data.timeout,
            max_retries=data.max_retries,
            max_concurrent=data.max_concurrent,
            rate_limit=data.rate_limit,
            is_active=True,
            config=data.config,
            proxy=proxy_config,
            created_at=now,
            updated_at=now,
        )

        session.add(new_endpoint)
        session.commit()
        session.refresh(new_endpoint)

        logger.info(f"[OK] 创建 Endpoint: Provider={provider.name}, Format={self.endpoint_data.api_format}, ID={new_endpoint.id}")

        # api_format is passed explicitly below; the SQLAlchemy state entry is not model data.
        skipped = {"api_format", "_sa_instance_state"}
        fields = {name: value for name, value in vars(new_endpoint).items() if name not in skipped}
        return ProviderEndpointResponse(
            **fields,
            provider_name=provider.name,
            api_format=new_endpoint.api_format,
            total_keys=0,
            active_keys=0,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AdminGetProviderEndpointAdapter(AdminApiAdapter):
    """Admin adapter that returns the details of one Endpoint."""

    # ID of the Endpoint to look up.
    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        """Load the Endpoint with its Provider and key counts.

        Args:
            context: Request context; only its ``db`` session is used here.

        Returns:
            A ``ProviderEndpointResponse`` including the Provider name and the
            total / active key counts.

        Raises:
            NotFoundException: If no Endpoint (joined to its Provider) matches
                ``endpoint_id``.
        """
        session = context.db

        row = (
            session.query(ProviderEndpoint, Provider)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(ProviderEndpoint.id == self.endpoint_id)
            .first()
        )
        if not row:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        endpoint_obj, provider = row

        # Both counts share the same endpoint filter; the active count adds one more criterion.
        keys_query = session.query(ProviderAPIKey).filter(
            ProviderAPIKey.endpoint_id == self.endpoint_id
        )
        total_keys = keys_query.count()
        active_keys = keys_query.filter(ProviderAPIKey.is_active.is_(True)).count()

        # api_format is passed explicitly below; the SQLAlchemy state entry is not model data.
        skipped = {"api_format", "_sa_instance_state"}
        fields = {name: value for name, value in vars(endpoint_obj).items() if name not in skipped}
        return ProviderEndpointResponse(
            **fields,
            provider_name=provider.name,
            api_format=endpoint_obj.api_format,
            total_keys=total_keys,
            active_keys=active_keys,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
    """Admin adapter that applies a partial update to an Endpoint."""

    # ID of the Endpoint to update.
    endpoint_id: str
    # Validated request body; only explicitly set fields are applied.
    endpoint_data: ProviderEndpointUpdate

    async def handle(self, context):  # type: ignore[override]
        """Apply the submitted fields, commit, and return the refreshed Endpoint.

        Args:
            context: Request context; only its ``db`` session is used here.

        Returns:
            A ``ProviderEndpointResponse`` with the Provider name (or
            ``"Unknown"`` when the Provider row is missing) and key counts.

        Raises:
            NotFoundException: If no Endpoint matches ``endpoint_id``.
        """
        session = context.db

        endpoint = (
            session.query(ProviderEndpoint)
            .filter(ProviderEndpoint.id == self.endpoint_id)
            .first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        update_data = self.endpoint_data.model_dump(exclude_unset=True)

        # Store proxy as a plain dict.
        proxy_value = update_data.get("proxy")
        if proxy_value is not None:
            update_data["proxy"] = dict(proxy_value)

        for attr_name, attr_value in update_data.items():
            setattr(endpoint, attr_name, attr_value)
        endpoint.updated_at = datetime.now(timezone.utc)

        session.commit()
        session.refresh(endpoint)

        provider = session.query(Provider).filter(Provider.id == endpoint.provider_id).first()
        logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")

        # Both counts share the same endpoint filter; the active count adds one more criterion.
        keys_query = session.query(ProviderAPIKey).filter(
            ProviderAPIKey.endpoint_id == self.endpoint_id
        )
        total_keys = keys_query.count()
        active_keys = keys_query.filter(ProviderAPIKey.is_active.is_(True)).count()

        # api_format is passed explicitly below; the SQLAlchemy state entry is not model data.
        skipped = {"api_format", "_sa_instance_state"}
        fields = {name: value for name, value in vars(endpoint).items() if name not in skipped}

        if provider:
            provider_name = provider.name
        else:
            provider_name = "Unknown"

        return ProviderEndpointResponse(
            **fields,
            provider_name=provider_name,
            api_format=endpoint.api_format,
            total_keys=total_keys,
            active_keys=active_keys,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
    """Admin adapter that deletes an Endpoint and reports how many Keys it had."""

    # ID of the Endpoint to delete.
    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        """Delete the Endpoint and commit.

        The number of Keys attached to the Endpoint is counted before the
        delete and reported back; removal of those Keys is not performed
        explicitly here (presumably handled by a cascade -- confirm).

        Args:
            context: Request context; only its ``db`` session is used here.

        Returns:
            A dict with a ``message`` and the ``deleted_keys_count``.

        Raises:
            NotFoundException: If no Endpoint matches ``endpoint_id``.
        """
        session = context.db

        target = (
            session.query(ProviderEndpoint)
            .filter(ProviderEndpoint.id == self.endpoint_id)
            .first()
        )
        if not target:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # Count first: after the delete is committed the rows can no longer be counted.
        keys_query = session.query(ProviderAPIKey).filter(
            ProviderAPIKey.endpoint_id == self.endpoint_id
        )
        keys_count = keys_query.count()

        session.delete(target)
        session.commit()

        logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")

        outcome = {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
        return outcome
|