mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-13 21:17:21 +08:00
426 lines
16 KiB
Python
426 lines
16 KiB
Python
|
|
"""
|
|||
|
|
Endpoint API Keys 管理
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import uuid
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from datetime import datetime, timezone
|
|||
|
|
from typing import Dict, List
|
|||
|
|
|
|||
|
|
from fastapi import APIRouter, Depends, Query, Request
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from src.api.base.admin_adapter import AdminApiAdapter
|
|||
|
|
from src.api.base.pipeline import ApiRequestPipeline
|
|||
|
|
from src.core.crypto import crypto_service
|
|||
|
|
from src.core.exceptions import InvalidRequestException, NotFoundException
|
|||
|
|
from src.core.key_capabilities import get_capability
|
|||
|
|
from src.core.logger import logger
|
|||
|
|
from src.database import get_db
|
|||
|
|
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
|||
|
|
from src.models.endpoint_models import (
|
|||
|
|
BatchUpdateKeyPriorityRequest,
|
|||
|
|
EndpointAPIKeyCreate,
|
|||
|
|
EndpointAPIKeyResponse,
|
|||
|
|
EndpointAPIKeyUpdate,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
router = APIRouter(tags=["Endpoint Keys"])
|
|||
|
|
pipeline = ApiRequestPipeline()
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
|
|||
|
|
async def list_endpoint_keys(
|
|||
|
|
endpoint_id: str,
|
|||
|
|
request: Request,
|
|||
|
|
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
|||
|
|
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
|||
|
|
db: Session = Depends(get_db),
|
|||
|
|
) -> List[EndpointAPIKeyResponse]:
|
|||
|
|
"""获取 Endpoint 的所有 Keys"""
|
|||
|
|
adapter = AdminListEndpointKeysAdapter(
|
|||
|
|
endpoint_id=endpoint_id,
|
|||
|
|
skip=skip,
|
|||
|
|
limit=limit,
|
|||
|
|
)
|
|||
|
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
|
|||
|
|
async def add_endpoint_key(
|
|||
|
|
endpoint_id: str,
|
|||
|
|
key_data: EndpointAPIKeyCreate,
|
|||
|
|
request: Request,
|
|||
|
|
db: Session = Depends(get_db),
|
|||
|
|
) -> EndpointAPIKeyResponse:
|
|||
|
|
"""为 Endpoint 添加 Key"""
|
|||
|
|
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
|
|||
|
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
|
|||
|
|
async def update_endpoint_key(
|
|||
|
|
key_id: str,
|
|||
|
|
key_data: EndpointAPIKeyUpdate,
|
|||
|
|
request: Request,
|
|||
|
|
db: Session = Depends(get_db),
|
|||
|
|
) -> EndpointAPIKeyResponse:
|
|||
|
|
"""更新 Endpoint Key"""
|
|||
|
|
adapter = AdminUpdateEndpointKeyAdapter(key_id=key_id, key_data=key_data)
|
|||
|
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/keys/grouped-by-format")
|
|||
|
|
async def get_keys_grouped_by_format(
|
|||
|
|
request: Request,
|
|||
|
|
db: Session = Depends(get_db),
|
|||
|
|
) -> dict:
|
|||
|
|
"""获取按 API 格式分组的所有 Keys(用于全局优先级管理)"""
|
|||
|
|
adapter = AdminGetKeysGroupedByFormatAdapter()
|
|||
|
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/keys/{key_id}")
|
|||
|
|
async def delete_endpoint_key(
|
|||
|
|
key_id: str,
|
|||
|
|
request: Request,
|
|||
|
|
db: Session = Depends(get_db),
|
|||
|
|
) -> dict:
|
|||
|
|
"""删除 Endpoint Key"""
|
|||
|
|
adapter = AdminDeleteEndpointKeyAdapter(key_id=key_id)
|
|||
|
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.put("/{endpoint_id}/keys/batch-priority")
|
|||
|
|
async def batch_update_key_priority(
|
|||
|
|
endpoint_id: str,
|
|||
|
|
request: Request,
|
|||
|
|
priority_data: BatchUpdateKeyPriorityRequest,
|
|||
|
|
db: Session = Depends(get_db),
|
|||
|
|
) -> dict:
|
|||
|
|
"""批量更新 Endpoint 下 Keys 的优先级(用于拖动排序)"""
|
|||
|
|
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
|
|||
|
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# -------- Adapters --------
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class AdminListEndpointKeysAdapter(AdminApiAdapter):
|
|||
|
|
endpoint_id: str
|
|||
|
|
skip: int
|
|||
|
|
limit: int
|
|||
|
|
|
|||
|
|
async def handle(self, context): # type: ignore[override]
|
|||
|
|
db = context.db
|
|||
|
|
endpoint = (
|
|||
|
|
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
|||
|
|
)
|
|||
|
|
if not endpoint:
|
|||
|
|
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
|||
|
|
|
|||
|
|
keys = (
|
|||
|
|
db.query(ProviderAPIKey)
|
|||
|
|
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
|
|||
|
|
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
|||
|
|
.offset(self.skip)
|
|||
|
|
.limit(self.limit)
|
|||
|
|
.all()
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
result: List[EndpointAPIKeyResponse] = []
|
|||
|
|
for key in keys:
|
|||
|
|
try:
|
|||
|
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
|||
|
|
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
|||
|
|
except Exception:
|
|||
|
|
masked_key = "***ERROR***"
|
|||
|
|
|
|||
|
|
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
|||
|
|
avg_response_time_ms = (
|
|||
|
|
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
is_adaptive = key.max_concurrent is None
|
|||
|
|
key_dict = key.__dict__.copy()
|
|||
|
|
key_dict.pop("_sa_instance_state", None)
|
|||
|
|
key_dict.update(
|
|||
|
|
{
|
|||
|
|
"api_key_masked": masked_key,
|
|||
|
|
"api_key_plain": None,
|
|||
|
|
"success_rate": success_rate,
|
|||
|
|
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
|||
|
|
"is_adaptive": is_adaptive,
|
|||
|
|
"effective_limit": (
|
|||
|
|
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
|||
|
|
),
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
result.append(EndpointAPIKeyResponse(**key_dict))
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
|
|||
|
|
endpoint_id: str
|
|||
|
|
key_data: EndpointAPIKeyCreate
|
|||
|
|
|
|||
|
|
async def handle(self, context): # type: ignore[override]
|
|||
|
|
db = context.db
|
|||
|
|
endpoint = (
|
|||
|
|
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
|||
|
|
)
|
|||
|
|
if not endpoint:
|
|||
|
|
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
|||
|
|
|
|||
|
|
if self.key_data.endpoint_id != self.endpoint_id:
|
|||
|
|
raise InvalidRequestException("endpoint_id 不匹配")
|
|||
|
|
|
|||
|
|
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
|
|||
|
|
now = datetime.now(timezone.utc)
|
|||
|
|
# max_concurrent=NULL 表示自适应模式,数字表示固定限制
|
|||
|
|
new_key = ProviderAPIKey(
|
|||
|
|
id=str(uuid.uuid4()),
|
|||
|
|
endpoint_id=self.endpoint_id,
|
|||
|
|
api_key=encrypted_key,
|
|||
|
|
name=self.key_data.name,
|
|||
|
|
note=self.key_data.note,
|
|||
|
|
rate_multiplier=self.key_data.rate_multiplier,
|
|||
|
|
internal_priority=self.key_data.internal_priority,
|
|||
|
|
max_concurrent=self.key_data.max_concurrent, # NULL=自适应模式
|
|||
|
|
rate_limit=self.key_data.rate_limit,
|
|||
|
|
daily_limit=self.key_data.daily_limit,
|
|||
|
|
monthly_limit=self.key_data.monthly_limit,
|
|||
|
|
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
|
|||
|
|
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
|
|||
|
|
request_count=0,
|
|||
|
|
success_count=0,
|
|||
|
|
error_count=0,
|
|||
|
|
total_response_time_ms=0,
|
|||
|
|
is_active=True,
|
|||
|
|
last_used_at=None,
|
|||
|
|
created_at=now,
|
|||
|
|
updated_at=now,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
db.add(new_key)
|
|||
|
|
db.commit()
|
|||
|
|
db.refresh(new_key)
|
|||
|
|
|
|||
|
|
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
|
|||
|
|
|
|||
|
|
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
|
|||
|
|
is_adaptive = new_key.max_concurrent is None
|
|||
|
|
response_dict = new_key.__dict__.copy()
|
|||
|
|
response_dict.pop("_sa_instance_state", None)
|
|||
|
|
response_dict.update(
|
|||
|
|
{
|
|||
|
|
"api_key_masked": masked_key,
|
|||
|
|
"api_key_plain": self.key_data.api_key,
|
|||
|
|
"success_rate": 0.0,
|
|||
|
|
"avg_response_time_ms": 0.0,
|
|||
|
|
"is_adaptive": is_adaptive,
|
|||
|
|
"effective_limit": (
|
|||
|
|
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
|
|||
|
|
),
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return EndpointAPIKeyResponse(**response_dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
|||
|
|
key_id: str
|
|||
|
|
key_data: EndpointAPIKeyUpdate
|
|||
|
|
|
|||
|
|
async def handle(self, context): # type: ignore[override]
|
|||
|
|
db = context.db
|
|||
|
|
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
|||
|
|
if not key:
|
|||
|
|
raise NotFoundException(f"Key {self.key_id} 不存在")
|
|||
|
|
|
|||
|
|
update_data = self.key_data.model_dump(exclude_unset=True)
|
|||
|
|
if "api_key" in update_data:
|
|||
|
|
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
|||
|
|
|
|||
|
|
for field, value in update_data.items():
|
|||
|
|
setattr(key, field, value)
|
|||
|
|
key.updated_at = datetime.now(timezone.utc)
|
|||
|
|
|
|||
|
|
db.commit()
|
|||
|
|
db.refresh(key)
|
|||
|
|
|
|||
|
|
logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
|||
|
|
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
|||
|
|
except Exception:
|
|||
|
|
masked_key = "***ERROR***"
|
|||
|
|
|
|||
|
|
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
|||
|
|
avg_response_time_ms = (
|
|||
|
|
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
is_adaptive = key.max_concurrent is None
|
|||
|
|
response_dict = key.__dict__.copy()
|
|||
|
|
response_dict.pop("_sa_instance_state", None)
|
|||
|
|
response_dict.update(
|
|||
|
|
{
|
|||
|
|
"api_key_masked": masked_key,
|
|||
|
|
"api_key_plain": None,
|
|||
|
|
"success_rate": success_rate,
|
|||
|
|
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
|||
|
|
"is_adaptive": is_adaptive,
|
|||
|
|
"effective_limit": (
|
|||
|
|
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
|||
|
|
),
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
return EndpointAPIKeyResponse(**response_dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
|||
|
|
key_id: str
|
|||
|
|
|
|||
|
|
async def handle(self, context): # type: ignore[override]
|
|||
|
|
db = context.db
|
|||
|
|
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
|||
|
|
if not key:
|
|||
|
|
raise NotFoundException(f"Key {self.key_id} 不存在")
|
|||
|
|
|
|||
|
|
endpoint_id = key.endpoint_id
|
|||
|
|
try:
|
|||
|
|
db.delete(key)
|
|||
|
|
db.commit()
|
|||
|
|
except Exception as exc:
|
|||
|
|
db.rollback()
|
|||
|
|
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
|
|||
|
|
return {"message": f"Key {self.key_id} 已删除"}
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
|
|||
|
|
async def handle(self, context): # type: ignore[override]
|
|||
|
|
db = context.db
|
|||
|
|
|
|||
|
|
keys = (
|
|||
|
|
db.query(ProviderAPIKey, ProviderEndpoint, Provider)
|
|||
|
|
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
|||
|
|
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
|||
|
|
.filter(
|
|||
|
|
ProviderAPIKey.is_active.is_(True),
|
|||
|
|
ProviderEndpoint.is_active.is_(True),
|
|||
|
|
Provider.is_active.is_(True),
|
|||
|
|
)
|
|||
|
|
.order_by(
|
|||
|
|
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
|
|||
|
|
)
|
|||
|
|
.all()
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
grouped: Dict[str, List[dict]] = {}
|
|||
|
|
for key, endpoint, provider in keys:
|
|||
|
|
api_format = endpoint.api_format
|
|||
|
|
if api_format not in grouped:
|
|||
|
|
grouped[api_format] = []
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
|||
|
|
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
|||
|
|
except Exception:
|
|||
|
|
masked_key = "***ERROR***"
|
|||
|
|
|
|||
|
|
# 计算健康度指标
|
|||
|
|
success_rate = key.success_count / key.request_count if key.request_count > 0 else None
|
|||
|
|
avg_response_time_ms = (
|
|||
|
|
round(key.total_response_time_ms / key.success_count, 2)
|
|||
|
|
if key.success_count > 0
|
|||
|
|
else None
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 将 capabilities dict 转换为启用的能力简短名称列表
|
|||
|
|
caps_list = []
|
|||
|
|
if key.capabilities:
|
|||
|
|
for cap_name, enabled in key.capabilities.items():
|
|||
|
|
if enabled:
|
|||
|
|
cap_def = get_capability(cap_name)
|
|||
|
|
caps_list.append(cap_def.short_name if cap_def else cap_name)
|
|||
|
|
|
|||
|
|
grouped[api_format].append(
|
|||
|
|
{
|
|||
|
|
"id": key.id,
|
|||
|
|
"name": key.name,
|
|||
|
|
"api_key_masked": masked_key,
|
|||
|
|
"internal_priority": key.internal_priority,
|
|||
|
|
"global_priority": key.global_priority,
|
|||
|
|
"rate_multiplier": key.rate_multiplier,
|
|||
|
|
"is_active": key.is_active,
|
|||
|
|
"circuit_breaker_open": key.circuit_breaker_open,
|
|||
|
|
"provider_name": provider.display_name or provider.name,
|
|||
|
|
"endpoint_base_url": endpoint.base_url,
|
|||
|
|
"api_format": api_format,
|
|||
|
|
"capabilities": caps_list,
|
|||
|
|
"success_rate": success_rate,
|
|||
|
|
"avg_response_time_ms": avg_response_time_ms,
|
|||
|
|
"request_count": key.request_count,
|
|||
|
|
}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 直接返回分组对象,供前端使用
|
|||
|
|
return grouped
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
|
|||
|
|
endpoint_id: str
|
|||
|
|
priority_data: BatchUpdateKeyPriorityRequest
|
|||
|
|
|
|||
|
|
async def handle(self, context): # type: ignore[override]
|
|||
|
|
db = context.db
|
|||
|
|
endpoint = (
|
|||
|
|
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
|||
|
|
)
|
|||
|
|
if not endpoint:
|
|||
|
|
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
|||
|
|
|
|||
|
|
# 获取所有需要更新的 Key ID
|
|||
|
|
key_ids = [item.key_id for item in self.priority_data.priorities]
|
|||
|
|
|
|||
|
|
# 验证所有 Key 都属于该 Endpoint
|
|||
|
|
keys = (
|
|||
|
|
db.query(ProviderAPIKey)
|
|||
|
|
.filter(
|
|||
|
|
ProviderAPIKey.id.in_(key_ids),
|
|||
|
|
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
|||
|
|
)
|
|||
|
|
.all()
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if len(keys) != len(key_ids):
|
|||
|
|
found_ids = {k.id for k in keys}
|
|||
|
|
missing_ids = set(key_ids) - found_ids
|
|||
|
|
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
|
|||
|
|
|
|||
|
|
# 批量更新优先级
|
|||
|
|
key_map = {k.id: k for k in keys}
|
|||
|
|
updated_count = 0
|
|||
|
|
for item in self.priority_data.priorities:
|
|||
|
|
key = key_map.get(item.key_id)
|
|||
|
|
if key and key.internal_priority != item.internal_priority:
|
|||
|
|
key.internal_priority = item.internal_priority
|
|||
|
|
key.updated_at = datetime.now(timezone.utc)
|
|||
|
|
updated_count += 1
|
|||
|
|
|
|||
|
|
db.commit()
|
|||
|
|
|
|||
|
|
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
|
|||
|
|
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}
|