Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
"""Endpoint management API routers."""
from fastapi import APIRouter
from .concurrency import router as concurrency_router
from .health import router as health_router
from .keys import router as keys_router
from .routes import router as routes_router
# Aggregate router for the endpoint-management admin API. Every sub-router
# included below is served under the "/api/admin/endpoints" prefix.
# NOTE(review): the include order below is kept as-is; route matching may be
# order-sensitive, so confirm before reordering.
router = APIRouter(prefix="/api/admin/endpoints", tags=["Endpoint Management"])
# Endpoint CRUD
router.include_router(routes_router)
# Endpoint Keys management
router.include_router(keys_router)
# Health monitoring
router.include_router(health_router)
# Concurrency control
router.include_router(concurrency_router)
# Only the aggregate router is part of this package's public surface.
__all__ = ["router"]

View File

@@ -0,0 +1,116 @@
"""
Endpoint 并发控制管理 API
"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.database import get_db
from src.models.database import ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
ConcurrencyStatusResponse,
ResetConcurrencyRequest,
)
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
# Router holding the concurrency-control routes; it declares no prefix of its own.
router = APIRouter(tags=["Concurrency Control"])
# Module-level pipeline instance; every route handler in this module runs its
# adapter through it.
pipeline = ApiRequestPipeline()
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
async def get_endpoint_concurrency(
    endpoint_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
    """Return the current concurrency status of an endpoint.

    Wraps ``endpoint_id`` in an ``AdminEndpointConcurrencyAdapter`` and runs
    it through the module-level request pipeline.
    """
    concurrency_adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
    response = await pipeline.run(
        adapter=concurrency_adapter,
        http_request=request,
        db=db,
        mode=concurrency_adapter.mode,
    )
    return response
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
async def get_key_concurrency(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
    """Return the current concurrency status of a key.

    Wraps ``key_id`` in an ``AdminKeyConcurrencyAdapter`` and runs it through
    the module-level request pipeline.
    """
    concurrency_adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
    response = await pipeline.run(
        adapter=concurrency_adapter,
        http_request=request,
        db=db,
        mode=concurrency_adapter.mode,
    )
    return response
@router.delete("/concurrency")
async def reset_concurrency(
    request: ResetConcurrencyRequest,
    http_request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """Reset concurrency counters (admin function, use with caution).

    The ``endpoint_id`` and ``key_id`` carried by the request body are handed
    to an ``AdminResetConcurrencyAdapter``, which is run through the
    module-level request pipeline.
    """
    reset_adapter = AdminResetConcurrencyAdapter(
        endpoint_id=request.endpoint_id,
        key_id=request.key_id,
    )
    outcome = await pipeline.run(
        adapter=reset_adapter,
        http_request=http_request,
        db=db,
        mode=reset_adapter.mode,
    )
    return outcome
# -------- Adapters --------
@dataclass
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
    """Admin adapter reporting the live concurrency of a single endpoint."""

    # ID of the ProviderEndpoint row to inspect.
    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        """Look up the endpoint and return its concurrency status.

        Raises:
            NotFoundException: if no ``ProviderEndpoint`` has this ID.
        """
        endpoint_query = context.db.query(ProviderEndpoint).filter(
            ProviderEndpoint.id == self.endpoint_id
        )
        endpoint = endpoint_query.first()
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        manager = await get_concurrency_manager()
        # The manager returns a pair; only the endpoint-side counter is used here.
        counters = await manager.get_current_concurrency(endpoint_id=self.endpoint_id)
        endpoint_count, _ = counters

        return ConcurrencyStatusResponse(
            endpoint_id=self.endpoint_id,
            endpoint_current_concurrency=endpoint_count,
            endpoint_max_concurrent=endpoint.max_concurrent,
        )
@dataclass
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
    """Admin adapter reporting the live concurrency of a single API key."""

    # ID of the ProviderAPIKey row to inspect.
    key_id: str

    async def handle(self, context):  # type: ignore[override]
        """Look up the key and return its concurrency status.

        Raises:
            NotFoundException: if no ``ProviderAPIKey`` has this ID.
        """
        key_query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id)
        key = key_query.first()
        if not key:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        manager = await get_concurrency_manager()
        # The manager returns a pair; only the key-side counter is used here.
        counters = await manager.get_current_concurrency(key_id=self.key_id)
        _, key_count = counters

        return ConcurrencyStatusResponse(
            key_id=self.key_id,
            key_current_concurrency=key_count,
            key_max_concurrent=key.max_concurrent,
        )
@dataclass
class AdminResetConcurrencyAdapter(AdminApiAdapter):
    """Admin adapter that asks the concurrency manager to reset counters."""

    # Both identifiers are optional and are forwarded to the manager unchanged.
    endpoint_id: Optional[str]
    key_id: Optional[str]

    async def handle(self, context):  # type: ignore[override]
        """Reset the counters and return a confirmation message."""
        manager = await get_concurrency_manager()
        reset_targets = {"endpoint_id": self.endpoint_id, "key_id": self.key_id}
        await manager.reset_concurrency(**reset_targets)
        return {"message": "并发计数已重置"}

View File

@@ -0,0 +1,476 @@
"""
Endpoint 健康监控 API
"""
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate
from src.models.endpoint_models import (
ApiFormatHealthMonitor,
ApiFormatHealthMonitorResponse,
EndpointHealthEvent,
HealthStatusResponse,
HealthSummaryResponse,
)
from src.services.health.endpoint import EndpointHealthService
from src.services.health.monitor import health_monitor
# Router holding the endpoint-health routes; it declares no prefix of its own.
router = APIRouter(tags=["Endpoint Health"])
# Module-level pipeline instance; every route handler in this module runs its
# adapter through it.
pipeline = ApiRequestPipeline()
@router.get("/health/summary", response_model=HealthSummaryResponse)
async def get_health_summary(
    request: Request,
    db: Session = Depends(get_db),
) -> HealthSummaryResponse:
    """Return the health status summary.

    Runs an ``AdminHealthSummaryAdapter`` through the module-level request
    pipeline.
    """
    summary_adapter = AdminHealthSummaryAdapter()
    response = await pipeline.run(
        adapter=summary_adapter,
        http_request=request,
        db=db,
        mode=summary_adapter.mode,
    )
    return response
@router.get("/health/status")
async def get_endpoint_health_status(
    request: Request,
    lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
    db: Session = Depends(get_db),
):
    """
    Return endpoint health status (simplified view, aligned with the user-facing endpoint).

    Difference from /health/api-formats:
    - /health/status: returns an aggregated timeline status (50 time buckets, based on the Usage table)
    - /health/api-formats: returns a detailed event list, based on the RequestCandidate table
    """
    status_adapter = AdminEndpointHealthStatusAdapter(lookback_hours=lookback_hours)
    response = await pipeline.run(
        adapter=status_adapter,
        http_request=request,
        db=db,
        mode=status_adapter.mode,
    )
    return response
@router.get("/health/api-formats", response_model=ApiFormatHealthMonitorResponse)
async def get_api_format_health_monitor(
    request: Request,
    lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
    per_format_limit: int = Query(60, ge=10, le=200, description="每个 API 格式的事件数量"),
    db: Session = Depends(get_db),
) -> ApiFormatHealthMonitorResponse:
    """Return the health-monitor timeline aggregated per API format (detailed event list).

    Both query parameters are forwarded to an
    ``AdminApiFormatHealthMonitorAdapter`` that is run through the
    module-level request pipeline.
    """
    adapter_kwargs = {
        "lookback_hours": lookback_hours,
        "per_format_limit": per_format_limit,
    }
    monitor_adapter = AdminApiFormatHealthMonitorAdapter(**adapter_kwargs)
    response = await pipeline.run(
        adapter=monitor_adapter,
        http_request=request,
        db=db,
        mode=monitor_adapter.mode,
    )
    return response
@router.get("/health/key/{key_id}", response_model=HealthStatusResponse)
async def get_key_health(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> HealthStatusResponse:
    """Return the health status of a key.

    Wraps ``key_id`` in an ``AdminKeyHealthAdapter`` and runs it through the
    module-level request pipeline.
    """
    health_adapter = AdminKeyHealthAdapter(key_id=key_id)
    response = await pipeline.run(
        adapter=health_adapter,
        http_request=request,
        db=db,
        mode=health_adapter.mode,
    )
    return response
@router.patch("/health/keys/{key_id}")
async def recover_key_health(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """
    Recover the health status of one key.

    Resets health_score to 1.0, closes the circuit breaker,
    cancels auto-disable, and resets all failure counts.

    Parameters:
    - key_id: Key ID (path parameter)
    """
    recover_adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
    outcome = await pipeline.run(
        adapter=recover_adapter,
        http_request=request,
        db=db,
        mode=recover_adapter.mode,
    )
    return outcome
@router.patch("/health/keys")
async def recover_all_keys_health(
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """
    Batch-recover every circuit-broken key.

    Finds all keys with circuit_breaker_open=True and:
    1. Resets health_score to 1.0
    2. Closes the circuit breaker
    3. Resets failure counts
    """
    recover_adapter = AdminRecoverAllKeysHealthAdapter()
    outcome = await pipeline.run(
        adapter=recover_adapter,
        http_request=request,
        db=db,
        mode=recover_adapter.mode,
    )
    return outcome
# -------- Adapters --------
class AdminHealthSummaryAdapter(AdminApiAdapter):
    """Admin adapter returning the overall health summary."""

    async def handle(self, context):  # type: ignore[override]
        """Fetch the summary mapping from ``health_monitor`` and wrap it in the response model."""
        summary_fields = health_monitor.get_all_health_status(context.db)
        response = HealthSummaryResponse(**summary_fields)
        return response
@dataclass
class AdminEndpointHealthStatusAdapter(AdminApiAdapter):
    """Admin endpoint-health adapter (same service as the user view, plus admin fields)."""

    # Size of the look-back window, in hours.
    lookback_hours: int

    async def handle(self, context):  # type: ignore[override]
        """Return per-format endpoint health and record audit metadata.

        ``EndpointHealthService`` is already imported at module level, so no
        function-scope import is needed here.
        """
        db = context.db

        # Shared service, admin view: admin fields are included and the cache
        # is bypassed so the administrator sees fresh data.
        result = EndpointHealthService.get_endpoint_health_by_format(
            db=db,
            lookback_hours=self.lookback_hours,
            include_admin_fields=True,
            use_cache=False,
        )

        context.add_audit_metadata(
            action="endpoint_health_status",
            format_count=len(result),
            lookback_hours=self.lookback_hours,
        )
        return result
@dataclass
class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
    """Admin adapter that builds the per-API-format health monitor response.

    For every API format with at least one active endpoint under an active
    provider, the response carries provider/key counts, status totals for the
    look-back window, a bounded list of recent events, and a timeline obtained
    from ``EndpointHealthService``.
    """

    # Size of the look-back window, in hours.
    lookback_hours: int
    # Maximum number of recent events kept per API format.
    per_format_limit: int

    @staticmethod
    def _format_key(api_format_value):
        """Return the plain form of an ``api_format`` column value.

        Values exposing a ``value`` attribute (presumably enum members) are
        unwrapped; anything else is converted with ``str``.
        """
        if hasattr(api_format_value, "value"):
            return api_format_value.value
        return str(api_format_value)

    async def handle(self, context):  # type: ignore[override]
        """Assemble and return the ``ApiFormatHealthMonitorResponse``."""
        db = context.db
        now = datetime.now(timezone.utc)
        since = now - timedelta(hours=self.lookback_hours)

        # 1. Active API formats and the number of distinct providers behind each.
        active_formats = (
            db.query(
                ProviderEndpoint.api_format,
                func.count(func.distinct(ProviderEndpoint.provider_id)).label("provider_count"),
            )
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .group_by(ProviderEndpoint.api_format)
            .all()
        )
        all_formats: Dict[str, int] = {
            self._format_key(api_format_enum): provider_count
            for api_format_enum, provider_count in active_formats
        }

        # 1.1 Number of active API keys per active API format.
        active_keys = (
            db.query(
                ProviderEndpoint.api_format,
                func.count(ProviderAPIKey.id).label("key_count"),
            )
            .join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
                ProviderAPIKey.is_active.is_(True),
            )
            .group_by(ProviderEndpoint.api_format)
            .all()
        )
        key_counts: Dict[str, int] = {
            self._format_key(api_format_enum): key_count
            for api_format_enum, key_count in active_keys
        }

        # 1.2 Endpoint IDs per API format, used later for the Usage-based timeline.
        endpoint_rows = (
            db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .all()
        )
        endpoint_map: Dict[str, List[str]] = defaultdict(list)
        for api_format_enum, endpoint_id in endpoint_rows:
            endpoint_map[self._format_key(api_format_enum)].append(endpoint_id)

        # 2. Status distribution per API format inside the window.
        #    Only final statuses are counted; intermediate ones are excluded.
        final_statuses = ["success", "failed", "skipped"]
        status_counts_query = (
            db.query(
                ProviderEndpoint.api_format,
                RequestCandidate.status,
                func.count(RequestCandidate.id).label("count"),
            )
            .join(RequestCandidate, ProviderEndpoint.id == RequestCandidate.endpoint_id)
            .filter(
                RequestCandidate.created_at >= since,
                RequestCandidate.status.in_(final_statuses),
            )
            .group_by(ProviderEndpoint.api_format, RequestCandidate.status)
            .all()
        )
        status_counts: Dict[str, Dict[str, int]] = {}
        for api_format_enum, status, count in status_counts_query:
            per_format = status_counts.setdefault(
                self._format_key(api_format_enum),
                {"success": 0, "failed": 0, "skipped": 0},
            )
            per_format[status] = count

        # 3. Most recent RequestCandidate rows (bounded), final statuses only.
        limit_rows = max(500, self.per_format_limit * 10)
        rows = (
            db.query(
                RequestCandidate,
                ProviderEndpoint.api_format,
                ProviderEndpoint.provider_id,
            )
            .join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
            .filter(
                RequestCandidate.created_at >= since,
                RequestCandidate.status.in_(final_statuses),
            )
            .order_by(RequestCandidate.created_at.desc())
            .limit(limit_rows)
            .all()
        )
        grouped_attempts: Dict[str, List[RequestCandidate]] = {}
        for attempt, api_format_enum, _provider_id in rows:
            bucket = grouped_attempts.setdefault(self._format_key(api_format_enum), [])
            # Keep only the newest per_format_limit records for each format.
            if len(bucket) < self.per_format_limit:
                bucket.append(attempt)

        # 4. Build a monitor for every active format, including formats
        #    without any request record in the window.
        monitors: List[ApiFormatHealthMonitor] = []
        for api_format in all_formats:
            attempts = grouped_attempts.get(api_format, [])

            format_stats = status_counts.get(api_format, {"success": 0, "failed": 0, "skipped": 0})
            real_success_count = format_stats.get("success", 0)
            real_failed_count = format_stats.get("failed", 0)
            real_skipped_count = format_stats.get("skipped", 0)
            # total_attempts covers final-status requests only.
            total_attempts = real_success_count + real_failed_count + real_skipped_count

            # Rows were fetched newest-first; the event list is oldest-first.
            events: List[EndpointHealthEvent] = [
                EndpointHealthEvent(
                    timestamp=attempt.finished_at or attempt.started_at or attempt.created_at,
                    status=attempt.status,
                    status_code=attempt.status_code,
                    latency_ms=attempt.latency_ms,
                    error_type=attempt.error_type,
                    error_message=attempt.error_message,
                )
                for attempt in reversed(attempts)
            ]

            # success_rate = success / (success + failed); skipped requests are
            # not failures and stay out of the denominator. With no completed
            # request the rate defaults to 1.0.
            actual_completed = real_success_count + real_failed_count
            success_rate = real_success_count / actual_completed if actual_completed > 0 else 1.0
            last_event_at = events[-1].timestamp if events else None

            # Usage-based health timeline for this format's endpoints.
            timeline_data = EndpointHealthService._generate_timeline_from_usage(
                db=db,
                endpoint_ids=endpoint_map.get(api_format, []),
                now=now,
                lookback_hours=self.lookback_hours,
            )

            monitors.append(
                ApiFormatHealthMonitor(
                    api_format=api_format,
                    total_attempts=total_attempts,
                    success_count=real_success_count,
                    failed_count=real_failed_count,
                    skipped_count=real_skipped_count,
                    success_rate=success_rate,
                    provider_count=all_formats[api_format],
                    key_count=key_counts.get(api_format, 0),
                    last_event_at=last_event_at,
                    events=events,  # bounded by per_format_limit
                    timeline=timeline_data.get("timeline", []),
                    time_range_start=timeline_data.get("time_range_start"),
                    time_range_end=timeline_data.get("time_range_end"),
                )
            )

        response = ApiFormatHealthMonitorResponse(
            generated_at=now,
            formats=monitors,
        )
        context.add_audit_metadata(
            action="api_format_health_monitor",
            format_count=len(monitors),
            lookback_hours=self.lookback_hours,
            per_format_limit=self.per_format_limit,
        )
        return response
@dataclass
class AdminKeyHealthAdapter(AdminApiAdapter):
    """Admin adapter returning the health status of a single key."""

    # ID of the key whose health data is requested.
    key_id: str

    async def handle(self, context):  # type: ignore[override]
        """Fetch health data from ``health_monitor`` and map it onto the response model.

        Raises:
            NotFoundException: if the monitor returns nothing for this key.
        """
        health_data = health_monitor.get_key_health(context.db, self.key_id)
        if not health_data:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        # Response field name -> key in the monitor's health mapping.
        field_sources = {
            "key_id": "key_id",
            "key_health_score": "health_score",
            "key_consecutive_failures": "consecutive_failures",
            "key_last_failure_at": "last_failure_at",
            "key_is_active": "is_active",
            "key_statistics": "statistics",
            "circuit_breaker_open": "circuit_breaker_open",
            "circuit_breaker_open_at": "circuit_breaker_open_at",
            "next_probe_at": "next_probe_at",
        }
        response_fields = {
            target: health_data[source] for target, source in field_sources.items()
        }
        return HealthStatusResponse(**response_fields)
@dataclass
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
    """Admin adapter that fully restores the health state of one key."""

    # ID of the key to recover.
    key_id: str

    async def handle(self, context):  # type: ignore[override]
        """Reset health fields on the key, re-activate it, commit and log.

        Raises:
            NotFoundException: if no ``ProviderAPIKey`` has this ID.
        """
        db = context.db
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if not key:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        # Clear every failure / circuit-breaker marker.
        key.health_score = 1.0
        key.consecutive_failures = 0
        key.last_failure_at = None
        key.circuit_breaker_open = False
        key.circuit_breaker_open_at = None
        key.next_probe_at = None
        if not key.is_active:
            key.is_active = True
        db.commit()

        # The acting admin was previously computed but dropped from the log;
        # include it so the log shows who triggered the recovery.
        admin_name = context.user.username if context.user else "admin"
        logger.info(
            f"管理员 {admin_name} 恢复Key健康状态: {self.key_id} "
            "(health_score: 1.0, circuit_breaker: closed)"
        )
        return {
            "message": "Key已完全恢复",
            "details": {
                "health_score": 1.0,
                "circuit_breaker_open": False,
                "is_active": True,
            },
        }
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
    """Batch-recover the health state of every key whose circuit breaker is open."""

    async def handle(self, context):  # type: ignore[override]
        """Reset health fields on all circuit-broken keys and report what was recovered.

        Returns a dict with ``message``, ``recovered_count`` and
        ``recovered_keys`` (one ``key_id`` / ``key_name`` / ``endpoint_id``
        entry per key).
        """
        db = context.db

        # All keys whose circuit breaker is currently open. ``.is_(True)``
        # matches the boolean-filter style used by the other queries in this file.
        circuit_open_keys = (
            db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open.is_(True)).all()
        )
        if not circuit_open_keys:
            return {
                "message": "没有需要恢复的 Key",
                "recovered_count": 0,
                "recovered_keys": [],
            }

        recovered_keys = []
        for key in circuit_open_keys:
            key.health_score = 1.0
            key.consecutive_failures = 0
            key.last_failure_at = None
            key.circuit_breaker_open = False
            key.circuit_breaker_open_at = None
            key.next_probe_at = None
            recovered_keys.append(
                {
                    "key_id": key.id,
                    "key_name": key.name,
                    "endpoint_id": key.endpoint_id,
                }
            )
        db.commit()

        # Reset the health monitor's open-circuit counter and its gauge.
        from src.services.health.monitor import HealthMonitor, health_open_circuits

        HealthMonitor._open_circuit_keys = 0
        health_open_circuits.set(0)

        # The acting admin was previously computed but dropped from the log;
        # include it so the log shows who triggered the recovery.
        admin_name = context.user.username if context.user else "admin"
        logger.info(f"管理员 {admin_name} 批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
        return {
            "message": f"已恢复 {len(recovered_keys)} 个 Key",
            "recovered_count": len(recovered_keys),
            "recovered_keys": recovered_keys,
        }

View File

@@ -0,0 +1,425 @@
"""
Endpoint API Keys 管理
"""
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.crypto import crypto_service
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.key_capabilities import get_capability
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate,
EndpointAPIKeyResponse,
EndpointAPIKeyUpdate,
)
# Router holding the endpoint-key routes; it declares no prefix of its own.
router = APIRouter(tags=["Endpoint Keys"])
# Module-level pipeline instance; every route handler in this module runs its
# adapter through it.
pipeline = ApiRequestPipeline()
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
async def list_endpoint_keys(
    endpoint_id: str,
    request: Request,
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
    db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
    """Return the keys of an endpoint, paginated by ``skip`` / ``limit``.

    The parameters are forwarded to an ``AdminListEndpointKeysAdapter`` that
    is run through the module-level request pipeline.
    """
    adapter_kwargs = {"endpoint_id": endpoint_id, "skip": skip, "limit": limit}
    list_adapter = AdminListEndpointKeysAdapter(**adapter_kwargs)
    response = await pipeline.run(
        adapter=list_adapter,
        http_request=request,
        db=db,
        mode=list_adapter.mode,
    )
    return response
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_endpoint_key(
    endpoint_id: str,
    key_data: EndpointAPIKeyCreate,
    request: Request,
    db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
    """Add a key to an endpoint.

    The path ID and request body are handed to an
    ``AdminCreateEndpointKeyAdapter`` that is run through the module-level
    request pipeline.
    """
    create_adapter = AdminCreateEndpointKeyAdapter(
        endpoint_id=endpoint_id,
        key_data=key_data,
    )
    response = await pipeline.run(
        adapter=create_adapter,
        http_request=request,
        db=db,
        mode=create_adapter.mode,
    )
    return response
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
async def update_endpoint_key(
    key_id: str,
    key_data: EndpointAPIKeyUpdate,
    request: Request,
    db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
    """Update an endpoint key.

    The path ID and request body are handed to an
    ``AdminUpdateEndpointKeyAdapter`` that is run through the module-level
    request pipeline.
    """
    update_adapter = AdminUpdateEndpointKeyAdapter(key_id=key_id, key_data=key_data)
    response = await pipeline.run(
        adapter=update_adapter,
        http_request=request,
        db=db,
        mode=update_adapter.mode,
    )
    return response
@router.get("/keys/grouped-by-format")
async def get_keys_grouped_by_format(
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """Return all keys grouped by API format (used for global priority management).

    Runs an ``AdminGetKeysGroupedByFormatAdapter`` through the module-level
    request pipeline.
    """
    grouped_adapter = AdminGetKeysGroupedByFormatAdapter()
    response = await pipeline.run(
        adapter=grouped_adapter,
        http_request=request,
        db=db,
        mode=grouped_adapter.mode,
    )
    return response
@router.delete("/keys/{key_id}")
async def delete_endpoint_key(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """Delete an endpoint key.

    Wraps ``key_id`` in an ``AdminDeleteEndpointKeyAdapter`` and runs it
    through the module-level request pipeline.
    """
    delete_adapter = AdminDeleteEndpointKeyAdapter(key_id=key_id)
    outcome = await pipeline.run(
        adapter=delete_adapter,
        http_request=request,
        db=db,
        mode=delete_adapter.mode,
    )
    return outcome
@router.put("/{endpoint_id}/keys/batch-priority")
async def batch_update_key_priority(
    endpoint_id: str,
    request: Request,
    priority_data: BatchUpdateKeyPriorityRequest,
    db: Session = Depends(get_db),
) -> dict:
    """Batch-update the priority of an endpoint's keys (used for drag-and-drop ordering).

    The path ID and request body are handed to an
    ``AdminBatchUpdateKeyPriorityAdapter`` that is run through the
    module-level request pipeline.
    """
    priority_adapter = AdminBatchUpdateKeyPriorityAdapter(
        endpoint_id=endpoint_id,
        priority_data=priority_data,
    )
    outcome = await pipeline.run(
        adapter=priority_adapter,
        http_request=request,
        db=db,
        mode=priority_adapter.mode,
    )
    return outcome
# -------- Adapters --------
@dataclass
class AdminListEndpointKeysAdapter(AdminApiAdapter):
    """Admin adapter listing the API keys attached to one endpoint."""

    # Endpoint whose keys are listed, plus the pagination window.
    endpoint_id: str
    skip: int
    limit: int

    async def handle(self, context):  # type: ignore[override]
        """Return one ``EndpointAPIKeyResponse`` per key in the requested page.

        Keys are ordered by ``internal_priority`` then ``created_at``, both
        ascending. The plaintext key is never returned, only a masked form.

        Raises:
            NotFoundException: if no ``ProviderEndpoint`` has this ID.
        """
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        ordering = (
            ProviderAPIKey.internal_priority.asc(),
            ProviderAPIKey.created_at.asc(),
        )
        page_query = (
            db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
            .order_by(*ordering)
            .offset(self.skip)
            .limit(self.limit)
        )

        responses: List[EndpointAPIKeyResponse] = []
        for key in page_query.all():
            # Mask the decrypted key; fall back to a placeholder on any failure.
            try:
                plaintext = crypto_service.decrypt(key.api_key)
                masked_key = f"{plaintext[:8]}***{plaintext[-4:]}"
            except Exception:
                masked_key = "***ERROR***"

            if key.request_count > 0:
                success_rate = key.success_count / key.request_count
            else:
                success_rate = 0.0
            if key.success_count > 0:
                avg_response_time_ms = key.total_response_time_ms / key.success_count
            else:
                avg_response_time_ms = 0.0
            # No explicit max_concurrent means the key is in adaptive mode.
            is_adaptive = key.max_concurrent is None

            # Start from the ORM instance's attribute dict, minus SQLAlchemy state.
            payload = key.__dict__.copy()
            payload.pop("_sa_instance_state", None)
            payload.update(
                {
                    "api_key_masked": masked_key,
                    "api_key_plain": None,
                    "success_rate": success_rate,
                    "avg_response_time_ms": round(avg_response_time_ms, 2),
                    "is_adaptive": is_adaptive,
                    "effective_limit": (
                        key.learned_max_concurrent if is_adaptive else key.max_concurrent
                    ),
                }
            )
            responses.append(EndpointAPIKeyResponse(**payload))
        return responses
@dataclass
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
    """Admin adapter that creates a new API key under an endpoint."""

    # Endpoint ID from the URL path.
    endpoint_id: str
    # Validated request body describing the new key.
    key_data: EndpointAPIKeyCreate

    async def handle(self, context):  # type: ignore[override]
        """Encrypt and persist the key, then return it (including the plaintext once).

        Raises:
            NotFoundException: if no ``ProviderEndpoint`` has this ID.
            InvalidRequestException: if the body's ``endpoint_id`` differs
                from the path's.
        """
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        if self.key_data.endpoint_id != self.endpoint_id:
            raise InvalidRequestException("endpoint_id 不匹配")

        encrypted_key = crypto_service.encrypt(self.key_data.api_key)
        now = datetime.now(timezone.utc)

        # max_concurrent left as None means adaptive mode; a number is a fixed limit.
        new_key = ProviderAPIKey(
            id=str(uuid.uuid4()),
            endpoint_id=self.endpoint_id,
            api_key=encrypted_key,
            name=self.key_data.name,
            note=self.key_data.note,
            rate_multiplier=self.key_data.rate_multiplier,
            internal_priority=self.key_data.internal_priority,
            max_concurrent=self.key_data.max_concurrent,
            rate_limit=self.key_data.rate_limit,
            daily_limit=self.key_data.daily_limit,
            monthly_limit=self.key_data.monthly_limit,
            allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
            capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
            request_count=0,
            success_count=0,
            error_count=0,
            total_response_time_ms=0,
            is_active=True,
            last_used_at=None,
            created_at=now,
            updated_at=now,
        )
        db.add(new_key)
        # Roll back on commit failure, matching the delete adapter's error
        # handling, so the session is not left in a failed transaction.
        try:
            db.commit()
        except Exception as exc:
            db.rollback()
            logger.error(f"添加 Key 失败: Endpoint={self.endpoint_id}, Error={exc}")
            raise
        db.refresh(new_key)

        logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")

        masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
        is_adaptive = new_key.max_concurrent is None

        # Start from the ORM instance's attribute dict, minus SQLAlchemy state.
        response_dict = new_key.__dict__.copy()
        response_dict.pop("_sa_instance_state", None)
        response_dict.update(
            {
                "api_key_masked": masked_key,
                # The plaintext key is echoed back in the creation response.
                "api_key_plain": self.key_data.api_key,
                "success_rate": 0.0,
                "avg_response_time_ms": 0.0,
                "is_adaptive": is_adaptive,
                "effective_limit": (
                    new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
                ),
            }
        )
        return EndpointAPIKeyResponse(**response_dict)
@dataclass
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
    """Admin adapter that applies a partial update to an existing API key."""

    # ID of the key to update.
    key_id: str
    # Validated request body; only explicitly-set fields are applied.
    key_data: EndpointAPIKeyUpdate

    async def handle(self, context):  # type: ignore[override]
        """Apply the update, commit, and return the refreshed key (masked).

        Raises:
            NotFoundException: if no ``ProviderAPIKey`` has this ID.
        """
        db = context.db
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if not key:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        update_data = self.key_data.model_dump(exclude_unset=True)
        # A new plaintext key is encrypted before being stored.
        if "api_key" in update_data:
            update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
        for field, value in update_data.items():
            setattr(key, field, value)
        key.updated_at = datetime.now(timezone.utc)

        # Roll back on commit failure, matching the delete adapter's error
        # handling, so the session is not left in a failed transaction.
        try:
            db.commit()
        except Exception as exc:
            db.rollback()
            logger.error(f"更新 Key 失败: ID={self.key_id}, Error={exc}")
            raise
        db.refresh(key)

        logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")

        # Mask the decrypted key; fall back to a placeholder on any failure.
        try:
            decrypted_key = crypto_service.decrypt(key.api_key)
            masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
        except Exception:
            masked_key = "***ERROR***"

        success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
        avg_response_time_ms = (
            key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
        )
        # No explicit max_concurrent means the key is in adaptive mode.
        is_adaptive = key.max_concurrent is None

        # Start from the ORM instance's attribute dict, minus SQLAlchemy state.
        response_dict = key.__dict__.copy()
        response_dict.pop("_sa_instance_state", None)
        response_dict.update(
            {
                "api_key_masked": masked_key,
                "api_key_plain": None,
                "success_rate": success_rate,
                "avg_response_time_ms": round(avg_response_time_ms, 2),
                "is_adaptive": is_adaptive,
                "effective_limit": (
                    key.learned_max_concurrent if is_adaptive else key.max_concurrent
                ),
            }
        )
        return EndpointAPIKeyResponse(**response_dict)
@dataclass
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
    """Admin adapter that deletes a single API key."""

    # ID of the key to delete.
    key_id: str

    async def handle(self, context):  # type: ignore[override]
        """Delete the key and return a confirmation message.

        On any failure during delete/commit the session is rolled back, the
        error is logged, and the original exception is re-raised.

        Raises:
            NotFoundException: if no ``ProviderAPIKey`` has this ID.
        """
        session = context.db
        target = session.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if not target:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        # Captured before deletion so it is still available for the log line.
        endpoint_id = target.endpoint_id
        try:
            session.delete(target)
            session.commit()
        except Exception as exc:
            session.rollback()
            logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
            raise
        else:
            logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
            return {"message": f"Key {self.key_id} 已删除"}
class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
    """Admin adapter returning every active key, grouped by its endpoint's API format."""

    async def handle(self, context):  # type: ignore[override]
        """Return ``{api_format: [key summary dict, ...]}``.

        Only keys that are active, on an active endpoint, under an active
        provider are included. Rows are ordered by ``global_priority``
        ascending (NULLs last) and then ``internal_priority`` ascending.
        """
        db = context.db
        active_filters = (
            ProviderAPIKey.is_active.is_(True),
            ProviderEndpoint.is_active.is_(True),
            Provider.is_active.is_(True),
        )
        rows = (
            db.query(ProviderAPIKey, ProviderEndpoint, Provider)
            .join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(*active_filters)
            .order_by(
                ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
            )
            .all()
        )

        grouped: Dict[str, List[dict]] = {}
        for key, endpoint, provider in rows:
            api_format = endpoint.api_format
            bucket = grouped.setdefault(api_format, [])

            # Mask the decrypted key; fall back to a placeholder on any failure.
            try:
                plaintext = crypto_service.decrypt(key.api_key)
                masked_key = f"{plaintext[:8]}***{plaintext[-4:]}"
            except Exception:
                masked_key = "***ERROR***"

            # Health metrics; None when there is no data to compute them from.
            success_rate = None
            if key.request_count > 0:
                success_rate = key.success_count / key.request_count
            avg_response_time_ms = None
            if key.success_count > 0:
                avg_response_time_ms = round(key.total_response_time_ms / key.success_count, 2)

            # Turn the capabilities mapping into short names of the enabled
            # ones; unknown capabilities fall back to their raw name.
            caps_list = []
            if key.capabilities:
                enabled_names = [name for name, enabled in key.capabilities.items() if enabled]
                for cap_name in enabled_names:
                    cap_def = get_capability(cap_name)
                    caps_list.append(cap_def.short_name if cap_def else cap_name)

            summary = {
                "id": key.id,
                "name": key.name,
                "api_key_masked": masked_key,
                "internal_priority": key.internal_priority,
                "global_priority": key.global_priority,
                "rate_multiplier": key.rate_multiplier,
                "is_active": key.is_active,
                "circuit_breaker_open": key.circuit_breaker_open,
                "provider_name": provider.display_name or provider.name,
                "endpoint_base_url": endpoint.base_url,
                "api_format": api_format,
                "capabilities": caps_list,
                "success_rate": success_rate,
                "avg_response_time_ms": avg_response_time_ms,
                "request_count": key.request_count,
            }
            bucket.append(summary)

        # The grouped mapping is returned as-is for the frontend.
        return grouped
@dataclass
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
    """Admin adapter that updates ``internal_priority`` for several keys of one endpoint."""

    # Endpoint that must own every key in the request.
    endpoint_id: str
    # Validated request body carrying (key_id, internal_priority) items.
    priority_data: BatchUpdateKeyPriorityRequest

    async def handle(self, context):  # type: ignore[override]
        """Validate ownership of all keys, apply changed priorities, and commit.

        Raises:
            NotFoundException: if no ``ProviderEndpoint`` has this ID.
            InvalidRequestException: if the request repeats a key ID, or if
                any key is missing or belongs to another endpoint.
        """
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # IDs of every key the request wants to update.
        key_ids = [item.key_id for item in self.priority_data.priorities]

        # Reject repeated IDs explicitly. Previously a duplicate tripped the
        # length comparison below and produced an error listing an empty set.
        seen_ids = set()
        duplicate_ids = set()
        for key_id in key_ids:
            if key_id in seen_ids:
                duplicate_ids.add(key_id)
            else:
                seen_ids.add(key_id)
        if duplicate_ids:
            raise InvalidRequestException(f"请求中包含重复的 Key ID: {duplicate_ids}")

        # Verify that every key exists and belongs to this endpoint.
        keys = (
            db.query(ProviderAPIKey)
            .filter(
                ProviderAPIKey.id.in_(key_ids),
                ProviderAPIKey.endpoint_id == self.endpoint_id,
            )
            .all()
        )
        missing_ids = seen_ids - {k.id for k in keys}
        if missing_ids:
            raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")

        # Apply only real changes; one timestamp is shared by the whole batch.
        key_map = {k.id: k for k in keys}
        now = datetime.now(timezone.utc)
        updated_count = 0
        for item in self.priority_data.priorities:
            key = key_map.get(item.key_id)
            if key and key.internal_priority != item.internal_priority:
                key.internal_priority = item.internal_priority
                key.updated_at = now
                updated_count += 1
        db.commit()

        logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
        return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}

View File

@@ -0,0 +1,345 @@
"""
ProviderEndpoint CRUD 管理 API
"""
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
ProviderEndpointCreate,
ProviderEndpointResponse,
ProviderEndpointUpdate,
)
# Router holding the ProviderEndpoint CRUD routes; it declares no prefix of its own.
router = APIRouter(tags=["Endpoint Management"])
# Module-level pipeline instance; every route handler in this module runs its
# adapter through it.
pipeline = ApiRequestPipeline()
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
async def list_provider_endpoints(
    provider_id: str,
    request: Request,
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
    db: Session = Depends(get_db),
) -> List[ProviderEndpointResponse]:
    """Return the endpoints of a provider, paginated by ``skip`` / ``limit``.

    The parameters are forwarded to an ``AdminListProviderEndpointsAdapter``
    that is run through the module-level request pipeline.
    """
    adapter_kwargs = {"provider_id": provider_id, "skip": skip, "limit": limit}
    list_adapter = AdminListProviderEndpointsAdapter(**adapter_kwargs)
    response = await pipeline.run(
        adapter=list_adapter,
        http_request=request,
        db=db,
        mode=list_adapter.mode,
    )
    return response
@router.post("/providers/{provider_id}/endpoints", response_model=ProviderEndpointResponse)
async def create_provider_endpoint(
    provider_id: str,
    endpoint_data: ProviderEndpointCreate,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Create a new endpoint for a provider.

    The path ID and request body are handed to an
    ``AdminCreateProviderEndpointAdapter`` that is run through the
    module-level request pipeline.
    """
    adapter_kwargs = {"provider_id": provider_id, "endpoint_data": endpoint_data}
    create_adapter = AdminCreateProviderEndpointAdapter(**adapter_kwargs)
    response = await pipeline.run(
        adapter=create_adapter,
        http_request=request,
        db=db,
        mode=create_adapter.mode,
    )
    return response
@router.get("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def get_endpoint(
    endpoint_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Return the details of an endpoint.

    Wraps ``endpoint_id`` in an ``AdminGetProviderEndpointAdapter`` and runs
    it through the module-level request pipeline.
    """
    detail_adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
    response = await pipeline.run(
        adapter=detail_adapter,
        http_request=request,
        db=db,
        mode=detail_adapter.mode,
    )
    return response
@router.put("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def update_endpoint(
    endpoint_id: str,
    endpoint_data: ProviderEndpointUpdate,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Update an existing Endpoint.

    Args:
        endpoint_id: Endpoint identifier taken from the URL path.
        endpoint_data: Request body carrying the fields to change.
        request: Incoming HTTP request, forwarded to the pipeline.
        db: Database session supplied by ``get_db``.

    Returns:
        The result of running ``AdminUpdateProviderEndpointAdapter`` through
        the module-level pipeline.
    """
    update_adapter = AdminUpdateProviderEndpointAdapter(endpoint_id=endpoint_id, endpoint_data=endpoint_data)
    outcome = await pipeline.run(
        adapter=update_adapter,
        http_request=request,
        db=db,
        mode=update_adapter.mode,
    )
    return outcome
@router.delete("/{endpoint_id}")
async def delete_endpoint(
    endpoint_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """Delete an Endpoint (per the original note, its associated Keys are cascade-deleted).

    Args:
        endpoint_id: Endpoint identifier taken from the URL path.
        request: Incoming HTTP request, forwarded to the pipeline.
        db: Database session supplied by ``get_db``.

    Returns:
        The result of running ``AdminDeleteProviderEndpointAdapter`` through
        the module-level pipeline.
    """
    delete_adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
    outcome = await pipeline.run(
        adapter=delete_adapter,
        http_request=request,
        db=db,
        mode=delete_adapter.mode,
    )
    return outcome
# -------- Adapters --------
@dataclass
class AdminListProviderEndpointsAdapter(AdminApiAdapter):
    """Adapter that lists one Provider's Endpoints together with their key counts."""

    # ID of the Provider whose endpoints are listed.
    provider_id: str
    # Number of endpoint rows to skip (pagination offset).
    skip: int
    # Maximum number of endpoint rows to return.
    limit: int

    async def handle(self, context):  # type: ignore[override]
        """Return one page of the provider's endpoints, newest first.

        Args:
            context: Request context; only ``context.db`` (the database
                session) is used here.

        Returns:
            A list of ``ProviderEndpointResponse`` objects, each carrying the
            endpoint's own columns plus ``provider_name``, ``total_keys`` and
            ``active_keys``.

        Raises:
            NotFoundException: If no Provider row has ``self.provider_id``.
        """
        db = context.db

        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")

        endpoints = (
            db.query(ProviderEndpoint)
            .filter(ProviderEndpoint.provider_id == self.provider_id)
            # created_at alone is not guaranteed unique; without a tie-breaker
            # the database may order equal timestamps differently per query,
            # so offset/limit pages could repeat or skip rows.
            .order_by(ProviderEndpoint.created_at.desc(), ProviderEndpoint.id)
            .offset(self.skip)
            .limit(self.limit)
            .all()
        )

        # Count keys for the whole page with two grouped queries rather than
        # issuing per-endpoint counts.
        endpoint_ids = [ep.id for ep in endpoints]
        total_keys_map = {}
        active_keys_map = {}
        if endpoint_ids:
            total_rows = (
                db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
                .filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
                .group_by(ProviderAPIKey.endpoint_id)
                .all()
            )
            total_keys_map = {row.endpoint_id: row.total for row in total_rows}

            active_rows = (
                db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
                .filter(
                    and_(
                        ProviderAPIKey.endpoint_id.in_(endpoint_ids),
                        ProviderAPIKey.is_active.is_(True),
                    )
                )
                .group_by(ProviderAPIKey.endpoint_id)
                .all()
            )
            active_keys_map = {row.endpoint_id: row.active for row in active_rows}

        result: List[ProviderEndpointResponse] = []
        for endpoint in endpoints:
            endpoint_dict = {
                **endpoint.__dict__,
                "provider_name": provider.name,
                "api_format": endpoint.api_format,
                # Endpoints without any key rows are absent from the maps.
                "total_keys": total_keys_map.get(endpoint.id, 0),
                "active_keys": active_keys_map.get(endpoint.id, 0),
            }
            # SQLAlchemy bookkeeping attribute; not part of the response model.
            endpoint_dict.pop("_sa_instance_state", None)
            result.append(ProviderEndpointResponse(**endpoint_dict))
        return result
@dataclass
class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
    """Adapter that creates a new Endpoint under an existing Provider."""

    # Provider ID taken from the request path.
    provider_id: str
    # Creation payload for the new endpoint.
    endpoint_data: ProviderEndpointCreate

    async def handle(self, context):  # type: ignore[override]
        """Validate the request, insert the Endpoint and return it.

        Args:
            context: Request context; only ``context.db`` is used here.

        Returns:
            A ``ProviderEndpointResponse`` for the new row, with both key
            counts set to zero.

        Raises:
            NotFoundException: If the Provider does not exist.
            InvalidRequestException: If the payload's ``provider_id`` differs
                from the path value, or the Provider already has an Endpoint
                with the same ``api_format``.
        """
        db = context.db
        data = self.endpoint_data

        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")

        if data.provider_id != self.provider_id:
            raise InvalidRequestException("provider_id 不匹配")

        # Reject a second endpoint with the same API format for this provider.
        duplicate = (
            db.query(ProviderEndpoint)
            .filter(
                ProviderEndpoint.provider_id == self.provider_id,
                ProviderEndpoint.api_format == data.api_format,
            )
            .first()
        )
        if duplicate:
            raise InvalidRequestException(
                f"Provider {provider.name} 已存在 {self.endpoint_data.api_format} 格式的 Endpoint"
            )

        # One timestamp is shared so created_at and updated_at start equal.
        timestamp = datetime.now(timezone.utc)
        columns = dict(
            id=str(uuid.uuid4()),
            provider_id=self.provider_id,
            api_format=data.api_format,
            base_url=data.base_url,
            headers=data.headers,
            timeout=data.timeout,
            max_retries=data.max_retries,
            max_concurrent=data.max_concurrent,
            rate_limit=data.rate_limit,
            is_active=True,
            config=data.config,
            created_at=timestamp,
            updated_at=timestamp,
        )
        new_endpoint = ProviderEndpoint(**columns)

        db.add(new_endpoint)
        db.commit()
        db.refresh(new_endpoint)

        logger.info(f"[OK] 创建 Endpoint: Provider={provider.name}, Format={self.endpoint_data.api_format}, ID={new_endpoint.id}")

        # api_format is passed explicitly below; _sa_instance_state is
        # SQLAlchemy bookkeeping and not part of the response model.
        omitted = {"api_format", "_sa_instance_state"}
        endpoint_dict = {}
        for key, value in vars(new_endpoint).items():
            if key not in omitted:
                endpoint_dict[key] = value

        return ProviderEndpointResponse(
            **endpoint_dict,
            provider_name=provider.name,
            api_format=new_endpoint.api_format,
            total_keys=0,
            active_keys=0,
        )
@dataclass
class AdminGetProviderEndpointAdapter(AdminApiAdapter):
    """Adapter that returns one Endpoint with its Provider name and key counts."""

    # ID of the endpoint to fetch.
    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        """Load the Endpoint joined with its Provider and build the response.

        Args:
            context: Request context; only ``context.db`` is used here.

        Returns:
            A ``ProviderEndpointResponse`` including ``provider_name``,
            ``total_keys`` and ``active_keys``.

        Raises:
            NotFoundException: If the joined query returns no row for
                ``self.endpoint_id``.
        """
        db = context.db

        joined_row = (
            db.query(ProviderEndpoint, Provider)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(ProviderEndpoint.id == self.endpoint_id)
            .first()
        )
        if not joined_row:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        endpoint_obj, provider = joined_row

        # Both counts share the same endpoint filter; the active count only
        # adds the is_active condition on top of it.
        keys_for_endpoint = db.query(ProviderAPIKey).filter(
            ProviderAPIKey.endpoint_id == self.endpoint_id
        )
        total_keys = keys_for_endpoint.count()
        active_keys = keys_for_endpoint.filter(ProviderAPIKey.is_active.is_(True)).count()

        # api_format is passed explicitly below; _sa_instance_state is
        # SQLAlchemy bookkeeping and not part of the response model.
        omitted = {"api_format", "_sa_instance_state"}
        endpoint_dict = {}
        for key, value in vars(endpoint_obj).items():
            if key not in omitted:
                endpoint_dict[key] = value

        return ProviderEndpointResponse(
            **endpoint_dict,
            provider_name=provider.name,
            api_format=endpoint_obj.api_format,
            total_keys=total_keys,
            active_keys=active_keys,
        )
@dataclass
class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
    """Adapter that applies a partial update to an existing Endpoint."""

    # ID of the endpoint to update.
    endpoint_id: str
    # Update payload; only fields the client explicitly set are applied.
    endpoint_data: ProviderEndpointUpdate

    async def handle(self, context):  # type: ignore[override]
        """Apply the explicitly-set fields, commit, and return the fresh row.

        Args:
            context: Request context; only ``context.db`` is used here.

        Returns:
            A ``ProviderEndpointResponse`` for the updated row. The provider
            name falls back to ``"Unknown"`` when the Provider row is missing.

        Raises:
            NotFoundException: If no Endpoint has ``self.endpoint_id``.
        """
        db = context.db

        endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # exclude_unset keeps fields the client did not send out of the update.
        update_data = self.endpoint_data.model_dump(exclude_unset=True)
        for attr_name, new_value in update_data.items():
            setattr(endpoint, attr_name, new_value)
        endpoint.updated_at = datetime.now(timezone.utc)

        db.commit()
        db.refresh(endpoint)

        provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()

        logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")

        # Both counts share the same endpoint filter; the active count only
        # adds the is_active condition on top of it.
        keys_for_endpoint = db.query(ProviderAPIKey).filter(
            ProviderAPIKey.endpoint_id == self.endpoint_id
        )
        total_keys = keys_for_endpoint.count()
        active_keys = keys_for_endpoint.filter(ProviderAPIKey.is_active.is_(True)).count()

        # api_format is passed explicitly below; _sa_instance_state is
        # SQLAlchemy bookkeeping and not part of the response model.
        omitted = {"api_format", "_sa_instance_state"}
        endpoint_dict = {}
        for key, value in vars(endpoint).items():
            if key not in omitted:
                endpoint_dict[key] = value

        if provider:
            provider_name = provider.name
        else:
            provider_name = "Unknown"

        return ProviderEndpointResponse(
            **endpoint_dict,
            provider_name=provider_name,
            api_format=endpoint.api_format,
            total_keys=total_keys,
            active_keys=active_keys,
        )
@dataclass
class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
    """Adapter that deletes an Endpoint and reports how many Keys it had."""

    # ID of the endpoint to delete.
    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        """Delete the Endpoint and return a confirmation payload.

        Args:
            context: Request context; only ``context.db`` is used here.

        Returns:
            A dict with a ``message`` and ``deleted_keys_count``, the number
            of key rows that referenced the endpoint just before deletion.

        Raises:
            NotFoundException: If no Endpoint has ``self.endpoint_id``.
        """
        db = context.db

        doomed = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        if not doomed:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # Count first: once the endpoint is gone the key rows may be too.
        # NOTE(review): only the endpoint is deleted explicitly here; removal
        # of its keys presumably relies on a cascade configured on the model
        # or database - confirm.
        key_rows = db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
        keys_count = key_rows.count()

        db.delete(doomed)
        db.commit()

        logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")

        summary = {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
        return summary