mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 10:12:27 +08:00
Initial commit
This commit is contained in:
24
src/api/admin/endpoints/__init__.py
Normal file
24
src/api/admin/endpoints/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Endpoint management API routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .concurrency import router as concurrency_router
|
||||
from .health import router as health_router
|
||||
from .keys import router as keys_router
|
||||
from .routes import router as routes_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/endpoints", tags=["Endpoint Management"])
|
||||
|
||||
# Endpoint CRUD
|
||||
router.include_router(routes_router)
|
||||
|
||||
# Endpoint Keys management
|
||||
router.include_router(keys_router)
|
||||
|
||||
# Health monitoring
|
||||
router.include_router(health_router)
|
||||
|
||||
# Concurrency control
|
||||
router.include_router(concurrency_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
116
src/api/admin/endpoints/concurrency.py
Normal file
116
src/api/admin/endpoints/concurrency.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Endpoint 并发控制管理 API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
ConcurrencyStatusResponse,
|
||||
ResetConcurrencyRequest,
|
||||
)
|
||||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||||
|
||||
router = APIRouter(tags=["Concurrency Control"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
|
||||
async def get_endpoint_concurrency(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""获取 Endpoint 当前并发状态"""
|
||||
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
|
||||
async def get_key_concurrency(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""获取 Key 当前并发状态"""
|
||||
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/concurrency")
|
||||
async def reset_concurrency(
|
||||
request: ResetConcurrencyRequest,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Reset concurrency counters (admin function, use with caution)"""
|
||||
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
endpoint_count, _ = await concurrency_manager.get_current_concurrency(
|
||||
endpoint_id=self.endpoint_id
|
||||
)
|
||||
|
||||
return ConcurrencyStatusResponse(
|
||||
endpoint_id=self.endpoint_id,
|
||||
endpoint_current_concurrency=endpoint_count,
|
||||
endpoint_max_concurrent=endpoint.max_concurrent,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
_, key_count = await concurrency_manager.get_current_concurrency(key_id=self.key_id)
|
||||
|
||||
return ConcurrencyStatusResponse(
|
||||
key_id=self.key_id,
|
||||
key_current_concurrency=key_count,
|
||||
key_max_concurrent=key.max_concurrent,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminResetConcurrencyAdapter(AdminApiAdapter):
|
||||
endpoint_id: Optional[str]
|
||||
key_id: Optional[str]
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
concurrency_manager = await get_concurrency_manager()
|
||||
await concurrency_manager.reset_concurrency(
|
||||
endpoint_id=self.endpoint_id, key_id=self.key_id
|
||||
)
|
||||
return {"message": "并发计数已重置"}
|
||||
476
src/api/admin/endpoints/health.py
Normal file
476
src/api/admin/endpoints/health.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
Endpoint 健康监控 API
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate
|
||||
from src.models.endpoint_models import (
|
||||
ApiFormatHealthMonitor,
|
||||
ApiFormatHealthMonitorResponse,
|
||||
EndpointHealthEvent,
|
||||
HealthStatusResponse,
|
||||
HealthSummaryResponse,
|
||||
)
|
||||
from src.services.health.endpoint import EndpointHealthService
|
||||
from src.services.health.monitor import health_monitor
|
||||
|
||||
router = APIRouter(tags=["Endpoint Health"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/health/summary", response_model=HealthSummaryResponse)
|
||||
async def get_health_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthSummaryResponse:
|
||||
"""获取健康状态摘要"""
|
||||
adapter = AdminHealthSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/status")
|
||||
async def get_endpoint_health_status(
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取端点健康状态(简化视图,与用户端点统一)
|
||||
|
||||
与 /health/api-formats 的区别:
|
||||
- /health/status: 返回聚合的时间线状态(50个时间段),基于 Usage 表
|
||||
- /health/api-formats: 返回详细的事件列表,基于 RequestCandidate 表
|
||||
"""
|
||||
adapter = AdminEndpointHealthStatusAdapter(lookback_hours=lookback_hours)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/api-formats", response_model=ApiFormatHealthMonitorResponse)
|
||||
async def get_api_format_health_monitor(
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
|
||||
per_format_limit: int = Query(60, ge=10, le=200, description="每个 API 格式的事件数量"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ApiFormatHealthMonitorResponse:
|
||||
"""获取按 API 格式聚合的健康监控时间线(详细事件列表)"""
|
||||
adapter = AdminApiFormatHealthMonitorAdapter(
|
||||
lookback_hours=lookback_hours,
|
||||
per_format_limit=per_format_limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/key/{key_id}", response_model=HealthStatusResponse)
|
||||
async def get_key_health(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthStatusResponse:
|
||||
"""获取 Key 健康状态"""
|
||||
adapter = AdminKeyHealthAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/health/keys/{key_id}")
|
||||
async def recover_key_health(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Recover key health status
|
||||
|
||||
Resets health_score to 1.0, closes circuit breaker,
|
||||
cancels auto-disable, and resets all failure counts.
|
||||
|
||||
Parameters:
|
||||
- key_id: Key ID (path parameter)
|
||||
"""
|
||||
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/health/keys")
|
||||
async def recover_all_keys_health(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Batch recover all circuit-broken keys
|
||||
|
||||
Finds all keys with circuit_breaker_open=True and:
|
||||
1. Resets health_score to 1.0
|
||||
2. Closes circuit breaker
|
||||
3. Resets failure counts
|
||||
"""
|
||||
adapter = AdminRecoverAllKeysHealthAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
class AdminHealthSummaryAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
summary = health_monitor.get_all_health_status(context.db)
|
||||
return HealthSummaryResponse(**summary)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminEndpointHealthStatusAdapter(AdminApiAdapter):
|
||||
"""管理员端点健康状态适配器(与用户端点统一,但包含管理员字段)"""
|
||||
|
||||
lookback_hours: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.services.health.endpoint import EndpointHealthService
|
||||
|
||||
db = context.db
|
||||
|
||||
# 使用共享服务获取健康状态(管理员视图)
|
||||
result = EndpointHealthService.get_endpoint_health_by_format(
|
||||
db=db,
|
||||
lookback_hours=self.lookback_hours,
|
||||
include_admin_fields=True, # 包含管理员字段
|
||||
use_cache=False, # 管理员不使用缓存,确保实时性
|
||||
)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="endpoint_health_status",
|
||||
format_count=len(result),
|
||||
lookback_hours=self.lookback_hours,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
|
||||
lookback_hours: int
|
||||
per_format_limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
now = datetime.now(timezone.utc)
|
||||
since = now - timedelta(hours=self.lookback_hours)
|
||||
|
||||
# 1. 获取所有活跃的 API 格式及其 Provider 数量
|
||||
active_formats = (
|
||||
db.query(
|
||||
ProviderEndpoint.api_format,
|
||||
func.count(func.distinct(ProviderEndpoint.provider_id)).label("provider_count"),
|
||||
)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.group_by(ProviderEndpoint.api_format)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构建所有格式的 provider_count 映射
|
||||
all_formats: Dict[str, int] = {}
|
||||
for api_format_enum, provider_count in active_formats:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
all_formats[api_format] = provider_count
|
||||
|
||||
# 1.1 获取所有活跃的 API 格式及其 API Key 数量
|
||||
active_keys = (
|
||||
db.query(
|
||||
ProviderEndpoint.api_format,
|
||||
func.count(ProviderAPIKey.id).label("key_count"),
|
||||
)
|
||||
.join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
.group_by(ProviderEndpoint.api_format)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构建所有格式的 key_count 映射
|
||||
key_counts: Dict[str, int] = {}
|
||||
for api_format_enum, key_count in active_keys:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
key_counts[api_format] = key_count
|
||||
|
||||
# 1.2 建立每个 API 格式对应的 Endpoint ID 列表,供 Usage 时间线生成使用
|
||||
endpoint_rows = (
|
||||
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
endpoint_map: Dict[str, List[str]] = defaultdict(list)
|
||||
for api_format_enum, endpoint_id in endpoint_rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
endpoint_map[api_format].append(endpoint_id)
|
||||
|
||||
# 2. 统计窗口内每个 API 格式的请求状态分布(真实统计)
|
||||
# 只统计最终状态:success, failed, skipped
|
||||
final_statuses = ["success", "failed", "skipped"]
|
||||
status_counts_query = (
|
||||
db.query(
|
||||
ProviderEndpoint.api_format,
|
||||
RequestCandidate.status,
|
||||
func.count(RequestCandidate.id).label("count"),
|
||||
)
|
||||
.join(RequestCandidate, ProviderEndpoint.id == RequestCandidate.endpoint_id)
|
||||
.filter(
|
||||
RequestCandidate.created_at >= since,
|
||||
RequestCandidate.status.in_(final_statuses),
|
||||
)
|
||||
.group_by(ProviderEndpoint.api_format, RequestCandidate.status)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构建每个格式的状态统计
|
||||
status_counts: Dict[str, Dict[str, int]] = {}
|
||||
for api_format_enum, status, count in status_counts_query:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
if api_format not in status_counts:
|
||||
status_counts[api_format] = {"success": 0, "failed": 0, "skipped": 0}
|
||||
status_counts[api_format][status] = count
|
||||
|
||||
# 3. 获取最近一段时间的 RequestCandidate(限制数量)
|
||||
# 使用上面定义的 final_statuses,排除中间状态
|
||||
limit_rows = max(500, self.per_format_limit * 10)
|
||||
rows = (
|
||||
db.query(
|
||||
RequestCandidate,
|
||||
ProviderEndpoint.api_format,
|
||||
ProviderEndpoint.provider_id,
|
||||
)
|
||||
.join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
|
||||
.filter(
|
||||
RequestCandidate.created_at >= since,
|
||||
RequestCandidate.status.in_(final_statuses),
|
||||
)
|
||||
.order_by(RequestCandidate.created_at.desc())
|
||||
.limit(limit_rows)
|
||||
.all()
|
||||
)
|
||||
|
||||
grouped_attempts: Dict[str, List[RequestCandidate]] = {}
|
||||
|
||||
for attempt, api_format_enum, provider_id in rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
if api_format not in grouped_attempts:
|
||||
grouped_attempts[api_format] = []
|
||||
|
||||
# 只保留每个 API 格式最近 per_format_limit 条记录
|
||||
if len(grouped_attempts[api_format]) < self.per_format_limit:
|
||||
grouped_attempts[api_format].append(attempt)
|
||||
|
||||
# 4. 为所有活跃格式生成监控数据(包括没有请求记录的)
|
||||
monitors: List[ApiFormatHealthMonitor] = []
|
||||
for api_format in all_formats:
|
||||
attempts = grouped_attempts.get(api_format, [])
|
||||
# 获取窗口内的真实统计数据
|
||||
# 只统计最终状态:success, failed, skipped
|
||||
# 中间状态(available, pending, used, started)不计入统计
|
||||
format_stats = status_counts.get(api_format, {"success": 0, "failed": 0, "skipped": 0})
|
||||
real_success_count = format_stats.get("success", 0)
|
||||
real_failed_count = format_stats.get("failed", 0)
|
||||
real_skipped_count = format_stats.get("skipped", 0)
|
||||
# total_attempts 只包含最终状态的请求数
|
||||
total_attempts = real_success_count + real_failed_count + real_skipped_count
|
||||
|
||||
# 时间线按时间正序
|
||||
attempts_sorted = list(reversed(attempts))
|
||||
events: List[EndpointHealthEvent] = []
|
||||
for attempt in attempts_sorted:
|
||||
event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
|
||||
events.append(
|
||||
EndpointHealthEvent(
|
||||
timestamp=event_timestamp,
|
||||
status=attempt.status,
|
||||
status_code=attempt.status_code,
|
||||
latency_ms=attempt.latency_ms,
|
||||
error_type=attempt.error_type,
|
||||
error_message=attempt.error_message,
|
||||
)
|
||||
)
|
||||
|
||||
# 成功率 = success / (success + failed)
|
||||
# skipped 不算失败,不计入成功率分母
|
||||
# 无实际完成请求时成功率为 1.0(灰色状态)
|
||||
actual_completed = real_success_count + real_failed_count
|
||||
success_rate = real_success_count / actual_completed if actual_completed > 0 else 1.0
|
||||
last_event_at = events[-1].timestamp if events else None
|
||||
|
||||
# 生成 Usage 基于时间窗口的健康时间线
|
||||
timeline_data = EndpointHealthService._generate_timeline_from_usage(
|
||||
db=db,
|
||||
endpoint_ids=endpoint_map.get(api_format, []),
|
||||
now=now,
|
||||
lookback_hours=self.lookback_hours,
|
||||
)
|
||||
|
||||
monitors.append(
|
||||
ApiFormatHealthMonitor(
|
||||
api_format=api_format,
|
||||
total_attempts=total_attempts, # 真实总请求数
|
||||
success_count=real_success_count, # 真实成功数
|
||||
failed_count=real_failed_count, # 真实失败数
|
||||
skipped_count=real_skipped_count, # 真实跳过数
|
||||
success_rate=success_rate, # 基于真实统计的成功率
|
||||
provider_count=all_formats[api_format],
|
||||
key_count=key_counts.get(api_format, 0),
|
||||
last_event_at=last_event_at,
|
||||
events=events, # 限制为 per_format_limit 条(用于时间线显示)
|
||||
timeline=timeline_data.get("timeline", []),
|
||||
time_range_start=timeline_data.get("time_range_start"),
|
||||
time_range_end=timeline_data.get("time_range_end"),
|
||||
)
|
||||
)
|
||||
|
||||
response = ApiFormatHealthMonitorResponse(
|
||||
generated_at=now,
|
||||
formats=monitors,
|
||||
)
|
||||
context.add_audit_metadata(
|
||||
action="api_format_health_monitor",
|
||||
format_count=len(monitors),
|
||||
lookback_hours=self.lookback_hours,
|
||||
per_format_limit=self.per_format_limit,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminKeyHealthAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
health_data = health_monitor.get_key_health(context.db, self.key_id)
|
||||
if not health_data:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
return HealthStatusResponse(
|
||||
key_id=health_data["key_id"],
|
||||
key_health_score=health_data["health_score"],
|
||||
key_consecutive_failures=health_data["consecutive_failures"],
|
||||
key_last_failure_at=health_data["last_failure_at"],
|
||||
key_is_active=health_data["is_active"],
|
||||
key_statistics=health_data["statistics"],
|
||||
circuit_breaker_open=health_data["circuit_breaker_open"],
|
||||
circuit_breaker_open_at=health_data["circuit_breaker_open_at"],
|
||||
next_probe_at=health_data["next_probe_at"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
key.health_score = 1.0
|
||||
key.consecutive_failures = 0
|
||||
key.last_failure_at = None
|
||||
key.circuit_breaker_open = False
|
||||
key.circuit_breaker_open_at = None
|
||||
key.next_probe_at = None
|
||||
if not key.is_active:
|
||||
key.is_active = True
|
||||
|
||||
db.commit()
|
||||
|
||||
admin_name = context.user.username if context.user else "admin"
|
||||
logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")
|
||||
|
||||
return {
|
||||
"message": "Key已完全恢复",
|
||||
"details": {
|
||||
"health_score": 1.0,
|
||||
"circuit_breaker_open": False,
|
||||
"is_active": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
||||
"""批量恢复所有熔断 Key 的健康状态"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
# 查找所有熔断的 Key
|
||||
circuit_open_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open == True).all()
|
||||
)
|
||||
|
||||
if not circuit_open_keys:
|
||||
return {
|
||||
"message": "没有需要恢复的 Key",
|
||||
"recovered_count": 0,
|
||||
"recovered_keys": [],
|
||||
}
|
||||
|
||||
recovered_keys = []
|
||||
for key in circuit_open_keys:
|
||||
key.health_score = 1.0
|
||||
key.consecutive_failures = 0
|
||||
key.last_failure_at = None
|
||||
key.circuit_breaker_open = False
|
||||
key.circuit_breaker_open_at = None
|
||||
key.next_probe_at = None
|
||||
recovered_keys.append(
|
||||
{
|
||||
"key_id": key.id,
|
||||
"key_name": key.name,
|
||||
"endpoint_id": key.endpoint_id,
|
||||
}
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 重置健康监控器的计数
|
||||
from src.services.health.monitor import HealthMonitor, health_open_circuits
|
||||
|
||||
HealthMonitor._open_circuit_keys = 0
|
||||
health_open_circuits.set(0)
|
||||
|
||||
admin_name = context.user.username if context.user else "admin"
|
||||
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
|
||||
|
||||
return {
|
||||
"message": f"已恢复 {len(recovered_keys)} 个 Key",
|
||||
"recovered_count": len(recovered_keys),
|
||||
"recovered_keys": recovered_keys,
|
||||
}
|
||||
425
src/api/admin/endpoints/keys.py
Normal file
425
src/api/admin/endpoints/keys.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Endpoint API Keys 管理
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.key_capabilities import get_capability
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
BatchUpdateKeyPriorityRequest,
|
||||
EndpointAPIKeyCreate,
|
||||
EndpointAPIKeyResponse,
|
||||
EndpointAPIKeyUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Endpoint Keys"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
|
||||
async def list_endpoint_keys(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[EndpointAPIKeyResponse]:
|
||||
"""获取 Endpoint 的所有 Keys"""
|
||||
adapter = AdminListEndpointKeysAdapter(
|
||||
endpoint_id=endpoint_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
|
||||
async def add_endpoint_key(
|
||||
endpoint_id: str,
|
||||
key_data: EndpointAPIKeyCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""为 Endpoint 添加 Key"""
|
||||
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
|
||||
async def update_endpoint_key(
|
||||
key_id: str,
|
||||
key_data: EndpointAPIKeyUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""更新 Endpoint Key"""
|
||||
adapter = AdminUpdateEndpointKeyAdapter(key_id=key_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/keys/grouped-by-format")
|
||||
async def get_keys_grouped_by_format(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取按 API 格式分组的所有 Keys(用于全局优先级管理)"""
|
||||
adapter = AdminGetKeysGroupedByFormatAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/keys/{key_id}")
|
||||
async def delete_endpoint_key(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除 Endpoint Key"""
|
||||
adapter = AdminDeleteEndpointKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{endpoint_id}/keys/batch-priority")
|
||||
async def batch_update_key_priority(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
priority_data: BatchUpdateKeyPriorityRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""批量更新 Endpoint 下 Keys 的优先级(用于拖动排序)"""
|
||||
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListEndpointKeysAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
|
||||
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
||||
.offset(self.skip)
|
||||
.limit(self.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
result: List[EndpointAPIKeyResponse] = []
|
||||
for key in keys:
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
||||
avg_response_time_ms = (
|
||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
is_adaptive = key.max_concurrent is None
|
||||
key_dict = key.__dict__.copy()
|
||||
key_dict.pop("_sa_instance_state", None)
|
||||
key_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": None,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
result.append(EndpointAPIKeyResponse(**key_dict))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
key_data: EndpointAPIKeyCreate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
if self.key_data.endpoint_id != self.endpoint_id:
|
||||
raise InvalidRequestException("endpoint_id 不匹配")
|
||||
|
||||
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
|
||||
now = datetime.now(timezone.utc)
|
||||
# max_concurrent=NULL 表示自适应模式,数字表示固定限制
|
||||
new_key = ProviderAPIKey(
|
||||
id=str(uuid.uuid4()),
|
||||
endpoint_id=self.endpoint_id,
|
||||
api_key=encrypted_key,
|
||||
name=self.key_data.name,
|
||||
note=self.key_data.note,
|
||||
rate_multiplier=self.key_data.rate_multiplier,
|
||||
internal_priority=self.key_data.internal_priority,
|
||||
max_concurrent=self.key_data.max_concurrent, # NULL=自适应模式
|
||||
rate_limit=self.key_data.rate_limit,
|
||||
daily_limit=self.key_data.daily_limit,
|
||||
monthly_limit=self.key_data.monthly_limit,
|
||||
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
|
||||
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
|
||||
request_count=0,
|
||||
success_count=0,
|
||||
error_count=0,
|
||||
total_response_time_ms=0,
|
||||
is_active=True,
|
||||
last_used_at=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
db.add(new_key)
|
||||
db.commit()
|
||||
db.refresh(new_key)
|
||||
|
||||
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
|
||||
|
||||
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
|
||||
is_adaptive = new_key.max_concurrent is None
|
||||
response_dict = new_key.__dict__.copy()
|
||||
response_dict.pop("_sa_instance_state", None)
|
||||
response_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": self.key_data.api_key,
|
||||
"success_rate": 0.0,
|
||||
"avg_response_time_ms": 0.0,
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return EndpointAPIKeyResponse(**response_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
key_data: EndpointAPIKeyUpdate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
update_data = self.key_data.model_dump(exclude_unset=True)
|
||||
if "api_key" in update_data:
|
||||
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(key, field, value)
|
||||
key.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
|
||||
logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")
|
||||
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
||||
avg_response_time_ms = (
|
||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
||||
)
|
||||
|
||||
is_adaptive = key.max_concurrent is None
|
||||
response_dict = key.__dict__.copy()
|
||||
response_dict.pop("_sa_instance_state", None)
|
||||
response_dict.update(
|
||||
{
|
||||
"api_key_masked": masked_key,
|
||||
"api_key_plain": None,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||
"is_adaptive": is_adaptive,
|
||||
"effective_limit": (
|
||||
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
||||
),
|
||||
}
|
||||
)
|
||||
return EndpointAPIKeyResponse(**response_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||
if not key:
|
||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||
|
||||
endpoint_id = key.endpoint_id
|
||||
try:
|
||||
db.delete(key)
|
||||
db.commit()
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
|
||||
raise
|
||||
|
||||
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
|
||||
return {"message": f"Key {self.key_id} 已删除"}
|
||||
|
||||
|
||||
class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
|
||||
keys = (
|
||||
db.query(ProviderAPIKey, ProviderEndpoint, Provider)
|
||||
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.order_by(
|
||||
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
grouped: Dict[str, List[dict]] = {}
|
||||
for key, endpoint, provider in keys:
|
||||
api_format = endpoint.api_format
|
||||
if api_format not in grouped:
|
||||
grouped[api_format] = []
|
||||
|
||||
try:
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||
except Exception:
|
||||
masked_key = "***ERROR***"
|
||||
|
||||
# 计算健康度指标
|
||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else None
|
||||
avg_response_time_ms = (
|
||||
round(key.total_response_time_ms / key.success_count, 2)
|
||||
if key.success_count > 0
|
||||
else None
|
||||
)
|
||||
|
||||
# 将 capabilities dict 转换为启用的能力简短名称列表
|
||||
caps_list = []
|
||||
if key.capabilities:
|
||||
for cap_name, enabled in key.capabilities.items():
|
||||
if enabled:
|
||||
cap_def = get_capability(cap_name)
|
||||
caps_list.append(cap_def.short_name if cap_def else cap_name)
|
||||
|
||||
grouped[api_format].append(
|
||||
{
|
||||
"id": key.id,
|
||||
"name": key.name,
|
||||
"api_key_masked": masked_key,
|
||||
"internal_priority": key.internal_priority,
|
||||
"global_priority": key.global_priority,
|
||||
"rate_multiplier": key.rate_multiplier,
|
||||
"is_active": key.is_active,
|
||||
"circuit_breaker_open": key.circuit_breaker_open,
|
||||
"provider_name": provider.display_name or provider.name,
|
||||
"endpoint_base_url": endpoint.base_url,
|
||||
"api_format": api_format,
|
||||
"capabilities": caps_list,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": avg_response_time_ms,
|
||||
"request_count": key.request_count,
|
||||
}
|
||||
)
|
||||
|
||||
# 直接返回分组对象,供前端使用
|
||||
return grouped
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
priority_data: BatchUpdateKeyPriorityRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
# 获取所有需要更新的 Key ID
|
||||
key_ids = [item.key_id for item in self.priority_data.priorities]
|
||||
|
||||
# 验证所有 Key 都属于该 Endpoint
|
||||
keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
ProviderAPIKey.id.in_(key_ids),
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(keys) != len(key_ids):
|
||||
found_ids = {k.id for k in keys}
|
||||
missing_ids = set(key_ids) - found_ids
|
||||
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
|
||||
|
||||
# 批量更新优先级
|
||||
key_map = {k.id: k for k in keys}
|
||||
updated_count = 0
|
||||
for item in self.priority_data.priorities:
|
||||
key = key_map.get(item.key_id)
|
||||
if key and key.internal_priority != item.internal_priority:
|
||||
key.internal_priority = item.internal_priority
|
||||
key.updated_at = datetime.now(timezone.utc)
|
||||
updated_count += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
|
||||
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}
|
||||
345
src/api/admin/endpoints/routes.py
Normal file
345
src/api/admin/endpoints/routes.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
ProviderEndpoint CRUD 管理 API
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
ProviderEndpointCreate,
|
||||
ProviderEndpointResponse,
|
||||
ProviderEndpointUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Endpoint Management"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
|
||||
async def list_provider_endpoints(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ProviderEndpointResponse]:
|
||||
"""获取指定 Provider 的所有 Endpoints"""
|
||||
adapter = AdminListProviderEndpointsAdapter(
|
||||
provider_id=provider_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id}/endpoints", response_model=ProviderEndpointResponse)
|
||||
async def create_provider_endpoint(
|
||||
provider_id: str,
|
||||
endpoint_data: ProviderEndpointCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""为 Provider 创建新的 Endpoint"""
|
||||
adapter = AdminCreateProviderEndpointAdapter(
|
||||
provider_id=provider_id,
|
||||
endpoint_data=endpoint_data,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{endpoint_id}", response_model=ProviderEndpointResponse)
|
||||
async def get_endpoint(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""获取 Endpoint 详情"""
|
||||
adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{endpoint_id}", response_model=ProviderEndpointResponse)
|
||||
async def update_endpoint(
|
||||
endpoint_id: str,
|
||||
endpoint_data: ProviderEndpointUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""更新 Endpoint"""
|
||||
adapter = AdminUpdateProviderEndpointAdapter(
|
||||
endpoint_id=endpoint_id,
|
||||
endpoint_data=endpoint_data,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{endpoint_id}")
|
||||
async def delete_endpoint(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除 Endpoint(级联删除所有关联的 Keys)"""
|
||||
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListProviderEndpointsAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||
|
||||
endpoints = (
|
||||
db.query(ProviderEndpoint)
|
||||
.filter(ProviderEndpoint.provider_id == self.provider_id)
|
||||
.order_by(ProviderEndpoint.created_at.desc())
|
||||
.offset(self.skip)
|
||||
.limit(self.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
endpoint_ids = [ep.id for ep in endpoints]
|
||||
total_keys_map = {}
|
||||
active_keys_map = {}
|
||||
if endpoint_ids:
|
||||
total_rows = (
|
||||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
|
||||
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
|
||||
.group_by(ProviderAPIKey.endpoint_id)
|
||||
.all()
|
||||
)
|
||||
total_keys_map = {row.endpoint_id: row.total for row in total_rows}
|
||||
|
||||
active_rows = (
|
||||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.group_by(ProviderAPIKey.endpoint_id)
|
||||
.all()
|
||||
)
|
||||
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
|
||||
|
||||
result: List[ProviderEndpointResponse] = []
|
||||
for endpoint in endpoints:
|
||||
endpoint_dict = {
|
||||
**endpoint.__dict__,
|
||||
"provider_name": provider.name,
|
||||
"api_format": endpoint.api_format,
|
||||
"total_keys": total_keys_map.get(endpoint.id, 0),
|
||||
"active_keys": active_keys_map.get(endpoint.id, 0),
|
||||
}
|
||||
endpoint_dict.pop("_sa_instance_state", None)
|
||||
result.append(ProviderEndpointResponse(**endpoint_dict))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
endpoint_data: ProviderEndpointCreate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||
|
||||
if self.endpoint_data.provider_id != self.provider_id:
|
||||
raise InvalidRequestException("provider_id 不匹配")
|
||||
|
||||
existing = (
|
||||
db.query(ProviderEndpoint)
|
||||
.filter(
|
||||
and_(
|
||||
ProviderEndpoint.provider_id == self.provider_id,
|
||||
ProviderEndpoint.api_format == self.endpoint_data.api_format,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise InvalidRequestException(
|
||||
f"Provider {provider.name} 已存在 {self.endpoint_data.api_format} 格式的 Endpoint"
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
new_endpoint = ProviderEndpoint(
|
||||
id=str(uuid.uuid4()),
|
||||
provider_id=self.provider_id,
|
||||
api_format=self.endpoint_data.api_format,
|
||||
base_url=self.endpoint_data.base_url,
|
||||
headers=self.endpoint_data.headers,
|
||||
timeout=self.endpoint_data.timeout,
|
||||
max_retries=self.endpoint_data.max_retries,
|
||||
max_concurrent=self.endpoint_data.max_concurrent,
|
||||
rate_limit=self.endpoint_data.rate_limit,
|
||||
is_active=True,
|
||||
config=self.endpoint_data.config,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
db.add(new_endpoint)
|
||||
db.commit()
|
||||
db.refresh(new_endpoint)
|
||||
|
||||
logger.info(f"[OK] 创建 Endpoint: Provider={provider.name}, Format={self.endpoint_data.api_format}, ID={new_endpoint.id}")
|
||||
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in new_endpoint.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name,
|
||||
api_format=new_endpoint.api_format,
|
||||
total_keys=0,
|
||||
active_keys=0,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint, Provider)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(ProviderEndpoint.id == self.endpoint_id)
|
||||
.first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
endpoint_obj, provider = endpoint
|
||||
total_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
)
|
||||
active_keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in endpoint_obj.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name,
|
||||
api_format=endpoint_obj.api_format,
|
||||
total_keys=total_keys,
|
||||
active_keys=active_keys,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
endpoint_data: ProviderEndpointUpdate
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
update_data = self.endpoint_data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(endpoint, field, value)
|
||||
endpoint.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
db.refresh(endpoint)
|
||||
|
||||
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
|
||||
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
|
||||
|
||||
total_keys = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
)
|
||||
active_keys = (
|
||||
db.query(ProviderAPIKey)
|
||||
.filter(
|
||||
and_(
|
||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
||||
ProviderAPIKey.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in endpoint.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name if provider else "Unknown",
|
||||
api_format=endpoint.api_format,
|
||||
total_keys=total_keys,
|
||||
active_keys=active_keys,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
endpoint = (
|
||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
||||
)
|
||||
if not endpoint:
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
keys_count = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
||||
)
|
||||
db.delete(endpoint)
|
||||
db.commit()
|
||||
|
||||
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
|
||||
|
||||
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
|
||||
Reference in New Issue
Block a user