mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
Initial commit
This commit is contained in:
20
src/api/public/__init__.py
Normal file
20
src/api/public/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Public-facing API routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .capabilities import router as capabilities_router
|
||||
from .catalog import router as catalog_router
|
||||
from .claude import router as claude_router
|
||||
from .gemini import router as gemini_router
|
||||
from .openai import router as openai_router
|
||||
from .system_catalog import router as system_catalog_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(claude_router, tags=["Claude API"])
|
||||
router.include_router(openai_router)
|
||||
router.include_router(gemini_router, tags=["Gemini API"])
|
||||
router.include_router(system_catalog_router, tags=["System Catalog"])
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(capabilities_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
104
src/api/public/capabilities.py
Normal file
104
src/api/public/capabilities.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
能力配置公共 API
|
||||
|
||||
提供系统支持的能力列表,供前端展示和配置使用。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.key_capabilities import (
|
||||
get_all_capabilities,
|
||||
get_user_configurable_capabilities,
|
||||
)
|
||||
from src.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/capabilities", tags=["Capabilities"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_capabilities():
|
||||
"""获取所有能力定义"""
|
||||
return {
|
||||
"capabilities": [
|
||||
{
|
||||
"name": cap.name,
|
||||
"display_name": cap.display_name,
|
||||
"short_name": cap.short_name,
|
||||
"description": cap.description,
|
||||
"match_mode": cap.match_mode.value,
|
||||
"config_mode": cap.config_mode.value,
|
||||
}
|
||||
for cap in get_all_capabilities()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/user-configurable")
|
||||
async def list_user_configurable_capabilities():
|
||||
"""获取用户可配置的能力列表(用于前端展示配置选项)"""
|
||||
return {
|
||||
"capabilities": [
|
||||
{
|
||||
"name": cap.name,
|
||||
"display_name": cap.display_name,
|
||||
"short_name": cap.short_name,
|
||||
"description": cap.description,
|
||||
"match_mode": cap.match_mode.value,
|
||||
"config_mode": cap.config_mode.value,
|
||||
}
|
||||
for cap in get_user_configurable_capabilities()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/model/{model_name}")
|
||||
async def get_model_supported_capabilities(
|
||||
model_name: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定模型支持的能力列表
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514)
|
||||
|
||||
Returns:
|
||||
模型支持的能力列表,以及每个能力的详细定义
|
||||
"""
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
||||
|
||||
if not global_model:
|
||||
return {
|
||||
"model": model_name,
|
||||
"supported_capabilities": [],
|
||||
"capability_details": [],
|
||||
"error": "模型不存在",
|
||||
}
|
||||
|
||||
supported_caps = global_model.supported_capabilities or []
|
||||
|
||||
# 获取支持的能力详情
|
||||
all_caps = {cap.name: cap for cap in get_all_capabilities()}
|
||||
capability_details = []
|
||||
for cap_name in supported_caps:
|
||||
if cap_name in all_caps:
|
||||
cap = all_caps[cap_name]
|
||||
capability_details.append({
|
||||
"name": cap.name,
|
||||
"display_name": cap.display_name,
|
||||
"description": cap.description,
|
||||
"match_mode": cap.match_mode.value,
|
||||
"config_mode": cap.config_mode.value,
|
||||
})
|
||||
|
||||
return {
|
||||
"model": model_name,
|
||||
"global_model_id": str(global_model.id),
|
||||
"global_model_name": global_model.name,
|
||||
"supported_capabilities": supported_caps,
|
||||
"capability_details": capability_details,
|
||||
}
|
||||
643
src/api/public/catalog.py
Normal file
643
src/api/public/catalog.py
Normal file
@@ -0,0 +1,643 @@
|
||||
"""
|
||||
公开API端点 - 用户可查看的提供商和模型信息
|
||||
不包含敏感信息,普通用户可访问
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ProviderStatsResponse,
|
||||
PublicGlobalModelListResponse,
|
||||
PublicGlobalModelResponse,
|
||||
PublicModelMappingResponse,
|
||||
PublicModelResponse,
|
||||
PublicProviderResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
)
|
||||
from src.models.endpoint_models import (
|
||||
PublicApiFormatHealthMonitor,
|
||||
PublicApiFormatHealthMonitorResponse,
|
||||
PublicHealthEvent,
|
||||
)
|
||||
from src.services.health.endpoint import EndpointHealthService
|
||||
|
||||
router = APIRouter(prefix="/api/public", tags=["Public Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/providers", response_model=List[PublicProviderResponse])
|
||||
async def get_public_providers(
|
||||
request: Request,
|
||||
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
|
||||
skip: int = Query(0, description="跳过记录数"),
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取提供商列表(用户视图)。"""
|
||||
|
||||
adapter = PublicProvidersAdapter(is_active=is_active, skip=skip, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/models", response_model=List[PublicModelResponse])
|
||||
async def get_public_models(
|
||||
request: Request,
|
||||
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
|
||||
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
|
||||
skip: int = Query(0, description="跳过记录数"),
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = PublicModelsAdapter(
|
||||
provider_id=provider_id, is_active=is_active, skip=skip, limit=limit
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
|
||||
async def get_public_model_mappings(
|
||||
request: Request,
|
||||
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
|
||||
alias: Optional[str] = Query(None, description="别名过滤(原source_model)"),
|
||||
skip: int = Query(0, description="跳过记录数"),
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = PublicModelMappingsAdapter(
|
||||
provider_id=provider_id,
|
||||
alias=alias,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ProviderStatsResponse)
|
||||
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = PublicStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/search/models")
|
||||
async def search_models(
|
||||
request: Request,
|
||||
q: str = Query(..., description="搜索关键词"),
|
||||
provider_id: Optional[int] = Query(None, description="提供商ID过滤"),
|
||||
limit: int = Query(20, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = PublicSearchModelsAdapter(query=q, provider_id=provider_id, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/health/api-formats", response_model=PublicApiFormatHealthMonitorResponse)
|
||||
async def get_public_api_format_health(
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=168, description="回溯小时数"),
|
||||
per_format_limit: int = Query(100, ge=10, le=500, description="每个格式的事件数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取各 API 格式的健康监控数据(公开版,不含敏感信息)"""
|
||||
adapter = PublicApiFormatHealthMonitorAdapter(
|
||||
lookback_hours=lookback_hours,
|
||||
per_format_limit=per_format_limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/global-models", response_model=PublicGlobalModelListResponse)
|
||||
async def get_public_global_models(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回记录数限制"),
|
||||
is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取 GlobalModel 列表(用户视图,只读)"""
|
||||
adapter = PublicGlobalModelsAdapter(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
search=search,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
# -------- 公共适配器 --------
|
||||
|
||||
|
||||
class PublicApiAdapter(ApiAdapter):
|
||||
mode = ApiMode.PUBLIC
|
||||
|
||||
def authorize(self, context): # type: ignore[override]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicProvidersAdapter(PublicApiAdapter):
|
||||
is_active: Optional[bool]
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求提供商列表")
|
||||
query = db.query(Provider)
|
||||
if self.is_active is not None:
|
||||
query = query.filter(Provider.is_active == self.is_active)
|
||||
else:
|
||||
query = query.filter(Provider.is_active.is_(True))
|
||||
|
||||
providers = query.offset(self.skip).limit(self.limit).all()
|
||||
result = []
|
||||
for provider in providers:
|
||||
models_count = db.query(Model).filter(Model.provider_id == provider.id).count()
|
||||
active_models_count = (
|
||||
db.query(Model)
|
||||
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
mappings_count = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True))
|
||||
)
|
||||
.count()
|
||||
)
|
||||
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
|
||||
active_endpoints_count = (
|
||||
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
|
||||
)
|
||||
provider_data = PublicProviderResponse(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
display_name=provider.display_name,
|
||||
description=provider.description,
|
||||
is_active=provider.is_active,
|
||||
provider_priority=provider.provider_priority,
|
||||
models_count=models_count,
|
||||
active_models_count=active_models_count,
|
||||
mappings_count=mappings_count,
|
||||
endpoints_count=endpoints_count,
|
||||
active_endpoints_count=active_endpoints_count,
|
||||
)
|
||||
result.append(provider_data.model_dump())
|
||||
|
||||
logger.debug(f"返回 {len(result)} 个提供商信息")
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicModelsAdapter(PublicApiAdapter):
|
||||
provider_id: Optional[str]
|
||||
is_active: Optional[bool]
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求模型列表")
|
||||
query = (
|
||||
db.query(Model, Provider)
|
||||
.options(joinedload(Model.global_model))
|
||||
.join(Provider)
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
)
|
||||
if self.provider_id is not None:
|
||||
query = query.filter(Model.provider_id == self.provider_id)
|
||||
results = query.offset(self.skip).limit(self.limit).all()
|
||||
|
||||
response = []
|
||||
for model, provider in results:
|
||||
global_model = model.global_model
|
||||
display_name = global_model.display_name if global_model else model.provider_model_name
|
||||
unified_name = global_model.name if global_model else model.provider_model_name
|
||||
model_data = PublicModelResponse(
|
||||
id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
provider_name=provider.name,
|
||||
provider_display_name=provider.display_name,
|
||||
name=unified_name,
|
||||
display_name=display_name,
|
||||
description=global_model.description if global_model else None,
|
||||
tags=None,
|
||||
icon_url=global_model.icon_url if global_model else None,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
cache_read_price_per_1m=model.get_effective_cache_read_price(),
|
||||
supports_vision=model.get_effective_supports_vision(),
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=model.is_active,
|
||||
)
|
||||
response.append(model_data.model_dump())
|
||||
|
||||
logger.debug(f"返回 {len(response)} 个模型信息")
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicModelMappingsAdapter(PublicApiAdapter):
|
||||
provider_id: Optional[str]
|
||||
alias: Optional[str] # 原 source_model,改为 alias
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求模型映射列表")
|
||||
|
||||
query = (
|
||||
db.query(ModelMapping, GlobalModel, Provider)
|
||||
.join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.outerjoin(Provider, ModelMapping.provider_id == Provider.id)
|
||||
.filter(
|
||||
and_(
|
||||
ModelMapping.is_active.is_(True),
|
||||
GlobalModel.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if self.provider_id is not None:
|
||||
provider_global_model_ids = (
|
||||
db.query(Model.global_model_id)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.filter(
|
||||
Provider.id == self.provider_id,
|
||||
Model.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
Model.global_model_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelMapping.provider_id == self.provider_id,
|
||||
and_(
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.target_global_model_id.in_(provider_global_model_ids),
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
|
||||
if self.alias is not None:
|
||||
query = query.filter(ModelMapping.source_model.ilike(f"%{self.alias}%"))
|
||||
|
||||
results = query.offset(self.skip).limit(self.limit).all()
|
||||
response = []
|
||||
for mapping, global_model, provider in results:
|
||||
scope = "provider" if mapping.provider_id else "global"
|
||||
mapping_data = PublicModelMappingResponse(
|
||||
id=mapping.id,
|
||||
source_model=mapping.source_model,
|
||||
target_global_model_id=mapping.target_global_model_id,
|
||||
target_global_model_name=global_model.name if global_model else None,
|
||||
target_global_model_display_name=(
|
||||
global_model.display_name if global_model else None
|
||||
),
|
||||
provider_id=mapping.provider_id,
|
||||
scope=scope,
|
||||
is_active=mapping.is_active,
|
||||
)
|
||||
response.append(mapping_data.model_dump())
|
||||
|
||||
logger.debug(f"返回 {len(response)} 个模型映射")
|
||||
return response
|
||||
|
||||
|
||||
class PublicStatsAdapter(PublicApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求系统统计信息")
|
||||
active_providers = db.query(Provider).filter(Provider.is_active.is_(True)).count()
|
||||
active_models = (
|
||||
db.query(Model)
|
||||
.join(Provider)
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
from ...models.database import ModelMapping
|
||||
|
||||
active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
|
||||
formats = (
|
||||
db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
|
||||
)
|
||||
supported_formats = [f.api_format for f in formats if f.api_format]
|
||||
stats = ProviderStatsResponse(
|
||||
total_providers=active_providers,
|
||||
active_providers=active_providers,
|
||||
total_models=active_models,
|
||||
active_models=active_models,
|
||||
total_mappings=active_mappings,
|
||||
supported_formats=supported_formats,
|
||||
)
|
||||
logger.debug("返回系统统计信息")
|
||||
return stats.model_dump()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicSearchModelsAdapter(PublicApiAdapter):
|
||||
query: str
|
||||
provider_id: Optional[int]
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug(f"公共API搜索模型: {self.query}")
|
||||
query_stmt = (
|
||||
db.query(Model, Provider)
|
||||
.options(joinedload(Model.global_model))
|
||||
.join(Provider)
|
||||
.outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
)
|
||||
search_filter = (
|
||||
Model.provider_model_name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.display_name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.description.ilike(f"%{self.query}%")
|
||||
)
|
||||
query_stmt = query_stmt.filter(search_filter)
|
||||
if self.provider_id is not None:
|
||||
query_stmt = query_stmt.filter(Model.provider_id == self.provider_id)
|
||||
results = query_stmt.limit(self.limit).all()
|
||||
|
||||
response = []
|
||||
for model, provider in results:
|
||||
global_model = model.global_model
|
||||
display_name = global_model.display_name if global_model else model.provider_model_name
|
||||
unified_name = global_model.name if global_model else model.provider_model_name
|
||||
model_data = PublicModelResponse(
|
||||
id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
provider_name=provider.name,
|
||||
provider_display_name=provider.display_name,
|
||||
name=unified_name,
|
||||
display_name=display_name,
|
||||
description=global_model.description if global_model else None,
|
||||
tags=None,
|
||||
icon_url=global_model.icon_url if global_model else None,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
cache_read_price_per_1m=model.get_effective_cache_read_price(),
|
||||
supports_vision=model.get_effective_supports_vision(),
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=model.is_active,
|
||||
)
|
||||
response.append(model_data.model_dump())
|
||||
|
||||
logger.debug(f"搜索 '{self.query}' 返回 {len(response)} 个结果")
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicApiFormatHealthMonitorAdapter(PublicApiAdapter):
|
||||
"""公开版 API 格式健康监控适配器(返回 events 数组,前端复用 EndpointHealthTimeline 组件)"""
|
||||
|
||||
lookback_hours: int
|
||||
per_format_limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
now = datetime.now(timezone.utc)
|
||||
since = now - timedelta(hours=self.lookback_hours)
|
||||
|
||||
# 1. 获取所有活跃的 API 格式
|
||||
active_formats = (
|
||||
db.query(ProviderEndpoint.api_format)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
|
||||
all_formats: List[str] = []
|
||||
for (api_format_enum,) in active_formats:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
all_formats.append(api_format)
|
||||
|
||||
# API 格式 -> Endpoint ID 映射(用于 Usage 时间线)
|
||||
endpoint_rows = (
|
||||
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
|
||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||
.filter(
|
||||
ProviderEndpoint.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
endpoint_map: Dict[str, List[str]] = defaultdict(list)
|
||||
for api_format_enum, endpoint_id in endpoint_rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
endpoint_map[api_format].append(endpoint_id)
|
||||
|
||||
# 2. 获取最近一段时间的 RequestCandidate(限制数量)
|
||||
# 只查询最终状态的记录:success, failed, skipped
|
||||
final_statuses = ["success", "failed", "skipped"]
|
||||
limit_rows = max(500, self.per_format_limit * 10)
|
||||
rows = (
|
||||
db.query(
|
||||
RequestCandidate,
|
||||
ProviderEndpoint.api_format,
|
||||
)
|
||||
.join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
|
||||
.filter(
|
||||
RequestCandidate.created_at >= since,
|
||||
RequestCandidate.status.in_(final_statuses),
|
||||
)
|
||||
.order_by(RequestCandidate.created_at.desc())
|
||||
.limit(limit_rows)
|
||||
.all()
|
||||
)
|
||||
|
||||
grouped_candidates: Dict[str, List[RequestCandidate]] = {}
|
||||
|
||||
for candidate, api_format_enum in rows:
|
||||
api_format = (
|
||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||
)
|
||||
if api_format not in grouped_candidates:
|
||||
grouped_candidates[api_format] = []
|
||||
|
||||
if len(grouped_candidates[api_format]) < self.per_format_limit:
|
||||
grouped_candidates[api_format].append(candidate)
|
||||
|
||||
# 3. 为所有活跃格式生成监控数据
|
||||
monitors: List[PublicApiFormatHealthMonitor] = []
|
||||
for api_format in all_formats:
|
||||
candidates = grouped_candidates.get(api_format, [])
|
||||
|
||||
# 统计
|
||||
success_count = sum(1 for c in candidates if c.status == "success")
|
||||
failed_count = sum(1 for c in candidates if c.status == "failed")
|
||||
skipped_count = sum(1 for c in candidates if c.status == "skipped")
|
||||
total_attempts = len(candidates)
|
||||
|
||||
# 计算成功率 = success / (success + failed)
|
||||
actual_completed = success_count + failed_count
|
||||
success_rate = success_count / actual_completed if actual_completed > 0 else 1.0
|
||||
|
||||
# 转换为公开版事件列表(不含敏感信息如 provider_id, key_id)
|
||||
events: List[PublicHealthEvent] = []
|
||||
for c in candidates:
|
||||
event_time = c.finished_at or c.started_at or c.created_at
|
||||
events.append(
|
||||
PublicHealthEvent(
|
||||
timestamp=event_time,
|
||||
status=c.status,
|
||||
status_code=c.status_code,
|
||||
latency_ms=c.latency_ms,
|
||||
error_type=c.error_type,
|
||||
)
|
||||
)
|
||||
|
||||
# 最后事件时间
|
||||
last_event_at = None
|
||||
if candidates:
|
||||
last_event_at = (
|
||||
candidates[0].finished_at
|
||||
or candidates[0].started_at
|
||||
or candidates[0].created_at
|
||||
)
|
||||
|
||||
timeline_data = EndpointHealthService._generate_timeline_from_usage(
|
||||
db=db,
|
||||
endpoint_ids=endpoint_map.get(api_format, []),
|
||||
now=now,
|
||||
lookback_hours=self.lookback_hours,
|
||||
)
|
||||
|
||||
# 获取本站入口路径
|
||||
from src.core.api_format_metadata import get_local_path
|
||||
from src.core.enums import APIFormat
|
||||
|
||||
try:
|
||||
api_format_enum = APIFormat(api_format)
|
||||
local_path = get_local_path(api_format_enum)
|
||||
except ValueError:
|
||||
local_path = "/"
|
||||
|
||||
monitors.append(
|
||||
PublicApiFormatHealthMonitor(
|
||||
api_format=api_format,
|
||||
api_path=local_path,
|
||||
total_attempts=total_attempts,
|
||||
success_count=success_count,
|
||||
failed_count=failed_count,
|
||||
skipped_count=skipped_count,
|
||||
success_rate=success_rate,
|
||||
last_event_at=last_event_at,
|
||||
events=events,
|
||||
timeline=timeline_data.get("timeline", []),
|
||||
time_range_start=timeline_data.get("time_range_start"),
|
||||
time_range_end=timeline_data.get("time_range_end"),
|
||||
)
|
||||
)
|
||||
|
||||
response = PublicApiFormatHealthMonitorResponse(
|
||||
generated_at=now,
|
||||
formats=monitors,
|
||||
)
|
||||
|
||||
logger.debug(f"公开健康监控: 返回 {len(monitors)} 个 API 格式的健康数据")
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicGlobalModelsAdapter(PublicApiAdapter):
|
||||
"""公开的 GlobalModel 列表适配器"""
|
||||
|
||||
skip: int
|
||||
limit: int
|
||||
is_active: Optional[bool]
|
||||
search: Optional[str]
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求 GlobalModel 列表")
|
||||
|
||||
query = db.query(GlobalModel)
|
||||
|
||||
# 默认只返回活跃的模型
|
||||
if self.is_active is not None:
|
||||
query = query.filter(GlobalModel.is_active == self.is_active)
|
||||
else:
|
||||
query = query.filter(GlobalModel.is_active.is_(True))
|
||||
|
||||
# 搜索过滤
|
||||
if self.search:
|
||||
search_term = f"%{self.search}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
GlobalModel.name.ilike(search_term),
|
||||
GlobalModel.display_name.ilike(search_term),
|
||||
GlobalModel.description.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
# 统计总数
|
||||
total = query.count()
|
||||
|
||||
# 分页
|
||||
models = query.order_by(GlobalModel.name).offset(self.skip).limit(self.limit).all()
|
||||
|
||||
# 转换为响应格式
|
||||
model_responses = []
|
||||
for gm in models:
|
||||
model_responses.append(
|
||||
PublicGlobalModelResponse(
|
||||
id=gm.id,
|
||||
name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
icon_url=gm.icon_url,
|
||||
is_active=gm.is_active,
|
||||
default_price_per_request=gm.default_price_per_request,
|
||||
default_tiered_pricing=gm.default_tiered_pricing,
|
||||
default_supports_vision=gm.default_supports_vision or False,
|
||||
default_supports_function_calling=gm.default_supports_function_calling or False,
|
||||
default_supports_streaming=(
|
||||
gm.default_supports_streaming
|
||||
if gm.default_supports_streaming is not None
|
||||
else True
|
||||
),
|
||||
default_supports_extended_thinking=gm.default_supports_extended_thinking
|
||||
or False,
|
||||
supported_capabilities=gm.supported_capabilities,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"返回 {len(model_responses)} 个 GlobalModel")
|
||||
return PublicGlobalModelListResponse(models=model_responses, total=total)
|
||||
52
src/api/public/claude.py
Normal file
52
src/api/public/claude.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Claude API 端点
|
||||
|
||||
- /v1/messages - Claude Messages API
|
||||
- /v1/messages/count_tokens - Token Count API
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.api.handlers.claude import (
|
||||
ClaudeTokenCountAdapter,
|
||||
build_claude_adapter,
|
||||
)
|
||||
from src.core.api_format_metadata import get_api_format_definition
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
|
||||
_claude_def = get_api_format_definition(APIFormat.CLAUDE)
|
||||
router = APIRouter(tags=["Claude API"], prefix=_claude_def.path_prefix)
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.post("/v1/messages")
|
||||
async def create_message(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""统一入口:根据 x-app 自动在标准/Claude Code 之间切换。"""
|
||||
adapter = build_claude_adapter(http_request.headers.get("x-app", ""))
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/messages/count_tokens")
|
||||
async def count_tokens(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = ClaudeTokenCountAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
)
|
||||
130
src/api/public/gemini.py
Normal file
130
src/api/public/gemini.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Gemini API 专属端点
|
||||
|
||||
托管 Gemini API 相关路由:
|
||||
- /v1beta/models/{model}:generateContent
|
||||
- /v1beta/models/{model}:streamGenerateContent
|
||||
|
||||
注意: Gemini API 的 model 在 URL 路径中,而不是请求体中
|
||||
|
||||
路径配置来源: src.core.api_format_metadata.APIFormat.GEMINI
|
||||
- path_prefix: 本站路径前缀(如 /gemini),通过 router prefix 配置
|
||||
- default_path: 标准 API 路径模板
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.api.handlers.gemini import build_gemini_adapter
|
||||
from src.api.handlers.gemini_cli import build_gemini_cli_adapter
|
||||
from src.core.api_format_metadata import get_api_format_definition
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
|
||||
# 从配置获取路径前缀
|
||||
_gemini_def = get_api_format_definition(APIFormat.GEMINI)
|
||||
|
||||
router = APIRouter(tags=["Gemini API"], prefix=_gemini_def.path_prefix)
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
def _is_cli_request(request: Request) -> bool:
|
||||
"""
|
||||
判断是否为 CLI 请求
|
||||
|
||||
检查顺序:
|
||||
1. x-app header 包含 "cli"
|
||||
2. user-agent 包含 "GeminiCLI" 或 "gemini-cli"
|
||||
"""
|
||||
# 检查 x-app header
|
||||
x_app = request.headers.get("x-app", "")
|
||||
if "cli" in x_app.lower():
|
||||
return True
|
||||
|
||||
# 检查 user-agent
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
user_agent_lower = user_agent.lower()
|
||||
if "geminicli" in user_agent_lower or "gemini-cli" in user_agent_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/v1beta/models/{model}:generateContent")
|
||||
async def generate_content(
|
||||
model: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Gemini generateContent 端点
|
||||
|
||||
非流式生成内容请求
|
||||
"""
|
||||
# 根据 user-agent 或 x-app header 选择适配器
|
||||
if _is_cli_request(http_request):
|
||||
adapter = build_gemini_cli_adapter()
|
||||
else:
|
||||
adapter = build_gemini_adapter()
|
||||
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
# 将 model 注入到请求体中,stream 用于内部判断流式模式
|
||||
path_params={"model": model, "stream": False},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1beta/models/{model}:streamGenerateContent")
|
||||
async def stream_generate_content(
|
||||
model: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Gemini streamGenerateContent 端点
|
||||
|
||||
流式生成内容请求
|
||||
|
||||
注意: Gemini API 通过 URL 端点区分流式/非流式,不需要在请求体中添加 stream 字段
|
||||
"""
|
||||
# 根据 user-agent 或 x-app header 选择适配器
|
||||
if _is_cli_request(http_request):
|
||||
adapter = build_gemini_cli_adapter()
|
||||
else:
|
||||
adapter = build_gemini_adapter()
|
||||
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
# model 注入到请求体,stream 用于内部判断流式模式(不发送到 API)
|
||||
path_params={"model": model, "stream": True},
|
||||
)
|
||||
|
||||
|
||||
# 兼容 v1 路径(部分 SDK 可能使用)
|
||||
@router.post("/v1/models/{model}:generateContent")
|
||||
async def generate_content_v1(
|
||||
model: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""v1 兼容端点"""
|
||||
return await generate_content(model, http_request, db)
|
||||
|
||||
|
||||
@router.post("/v1/models/{model}:streamGenerateContent")
|
||||
async def stream_generate_content_v1(
|
||||
model: str,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""v1 兼容端点"""
|
||||
return await stream_generate_content(model, http_request, db)
|
||||
50
src/api/public/openai.py
Normal file
50
src/api/public/openai.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
OpenAI API 端点
|
||||
|
||||
- /v1/chat/completions - OpenAI Chat API
|
||||
- /v1/responses - OpenAI Responses API (CLI)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.api.handlers.openai import OpenAIChatAdapter
|
||||
from src.api.handlers.openai_cli import OpenAICliAdapter
|
||||
from src.core.api_format_metadata import get_api_format_definition
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
|
||||
_openai_def = get_api_format_definition(APIFormat.OPENAI)
|
||||
router = APIRouter(tags=["OpenAI API"], prefix=_openai_def.path_prefix)
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def create_chat_completion(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = OpenAIChatAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/responses")
|
||||
async def create_responses(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = OpenAICliAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
http_request=http_request,
|
||||
db=db,
|
||||
mode=adapter.mode,
|
||||
api_format_hint=adapter.allowed_api_formats[0],
|
||||
)
|
||||
306
src/api/public/system_catalog.py
Normal file
306
src/api/public/system_catalog.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
System Catalog / 健康检查相关端点
|
||||
|
||||
这些是系统工具端点,不需要复杂的 Adapter 抽象。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||
from src.clients.redis_client import get_redis_client, get_redis_client_sync
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.database.database import get_pool_status
|
||||
from src.models.database import Model, Provider
|
||||
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
|
||||
from src.services.provider.transport import build_provider_url
|
||||
|
||||
router = APIRouter(tags=["System Catalog"])
|
||||
|
||||
|
||||
# ============== 辅助函数 ==============
|
||||
|
||||
|
||||
def _as_bool(value: Optional[str], default: bool) -> bool:
|
||||
"""将字符串转换为布尔值"""
|
||||
if value is None:
|
||||
return default
|
||||
return value.lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _serialize_provider(
|
||||
provider: Provider,
|
||||
include_models: bool,
|
||||
include_endpoints: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""序列化 Provider 对象"""
|
||||
provider_data: Dict[str, Any] = {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
"is_active": provider.is_active,
|
||||
"provider_priority": provider.provider_priority,
|
||||
}
|
||||
|
||||
if include_endpoints:
|
||||
provider_data["endpoints"] = [
|
||||
{
|
||||
"id": endpoint.id,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
"is_active": endpoint.is_active,
|
||||
}
|
||||
for endpoint in provider.endpoints or []
|
||||
]
|
||||
|
||||
if include_models:
|
||||
provider_data["models"] = [
|
||||
{
|
||||
"id": model.id,
|
||||
"name": (
|
||||
model.global_model.name if model.global_model else model.provider_model_name
|
||||
),
|
||||
"display_name": (
|
||||
model.global_model.display_name
|
||||
if model.global_model
|
||||
else model.provider_model_name
|
||||
),
|
||||
"is_active": model.is_active,
|
||||
"supports_streaming": model.supports_streaming,
|
||||
}
|
||||
for model in provider.models or []
|
||||
if model.is_active
|
||||
]
|
||||
|
||||
return provider_data
|
||||
|
||||
|
||||
def _select_provider(db: Session, provider_name: Optional[str]) -> Optional[Provider]:
|
||||
"""选择 Provider(按 provider_priority 优先级选择)"""
|
||||
query = db.query(Provider).filter(Provider.is_active == True)
|
||||
if provider_name:
|
||||
provider = query.filter(Provider.name == provider_name).first()
|
||||
if provider:
|
||||
return provider
|
||||
|
||||
# 按优先级选择(provider_priority 最小的优先)
|
||||
return query.order_by(Provider.provider_priority.asc()).first()
|
||||
|
||||
|
||||
# ============== 端点 ==============
|
||||
|
||||
|
||||
@router.get("/v1/health")
|
||||
async def service_health(db: Session = Depends(get_db)):
|
||||
"""返回服务健康状态与依赖信息"""
|
||||
active_providers = (
|
||||
db.query(func.count(Provider.id)).filter(Provider.is_active == True).scalar() or 0
|
||||
)
|
||||
active_models = db.query(func.count(Model.id)).filter(Model.is_active == True).scalar() or 0
|
||||
|
||||
redis_info: Dict[str, Any] = {"status": "unknown"}
|
||||
try:
|
||||
redis = await get_redis_client()
|
||||
if redis:
|
||||
await redis.ping()
|
||||
redis_info = {"status": "ok"}
|
||||
else:
|
||||
redis_info = {"status": "degraded", "message": "Redis client not initialized"}
|
||||
except Exception as exc:
|
||||
redis_info = {"status": "error", "message": str(exc)}
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"stats": {
|
||||
"active_providers": active_providers,
|
||||
"active_models": active_models,
|
||||
},
|
||||
"dependencies": {
|
||||
"database": {"status": "ok"},
|
||||
"redis": redis_info,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""简单健康检查端点(无需认证)"""
|
||||
try:
|
||||
pool_status = get_pool_status()
|
||||
pool_health = {
|
||||
"checked_out": pool_status["checked_out"],
|
||||
"pool_size": pool_status["pool_size"],
|
||||
"overflow": pool_status["overflow"],
|
||||
"max_capacity": pool_status["max_capacity"],
|
||||
"usage_rate": (
|
||||
f"{(pool_status['checked_out'] / pool_status['max_capacity'] * 100):.1f}%"
|
||||
if pool_status["max_capacity"] > 0
|
||||
else "0.0%"
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
pool_health = {"error": str(e)}
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"database_pool": pool_health,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root(db: Session = Depends(get_db)):
|
||||
"""Root endpoint - 服务信息概览"""
|
||||
# 按优先级选择最高优先级的提供商
|
||||
top_provider = (
|
||||
db.query(Provider)
|
||||
.filter(Provider.is_active == True)
|
||||
.order_by(Provider.provider_priority.asc())
|
||||
.first()
|
||||
)
|
||||
active_providers = db.query(Provider).filter(Provider.is_active == True).count()
|
||||
|
||||
return {
|
||||
"message": "AI Proxy with Modular Architecture v4.0.0",
|
||||
"status": "running",
|
||||
"current_provider": top_provider.name if top_provider else "None",
|
||||
"available_providers": active_providers,
|
||||
"config": {},
|
||||
"endpoints": {
|
||||
"messages": "/v1/messages",
|
||||
"count_tokens": "/v1/messages/count_tokens",
|
||||
"health": "/v1/health",
|
||||
"providers": "/v1/providers",
|
||||
"test_connection": "/v1/test-connection",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/providers")
|
||||
async def list_providers(
|
||||
db: Session = Depends(get_db),
|
||||
include_models: bool = Query(False),
|
||||
include_endpoints: bool = Query(False),
|
||||
active_only: bool = Query(True),
|
||||
):
|
||||
"""列出所有 Provider"""
|
||||
load_options = []
|
||||
if include_models:
|
||||
load_options.append(selectinload(Provider.models).selectinload(Model.global_model))
|
||||
if include_endpoints:
|
||||
load_options.append(selectinload(Provider.endpoints))
|
||||
|
||||
base_query = db.query(Provider)
|
||||
if load_options:
|
||||
base_query = base_query.options(*load_options)
|
||||
if active_only:
|
||||
base_query = base_query.filter(Provider.is_active == True)
|
||||
base_query = base_query.order_by(Provider.provider_priority.asc(), Provider.name.asc())
|
||||
|
||||
providers = base_query.all()
|
||||
return {
|
||||
"providers": [
|
||||
_serialize_provider(provider, include_models, include_endpoints)
|
||||
for provider in providers
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/v1/providers/{provider_identifier}")
|
||||
async def provider_detail(
|
||||
provider_identifier: str,
|
||||
db: Session = Depends(get_db),
|
||||
include_models: bool = Query(False),
|
||||
include_endpoints: bool = Query(False),
|
||||
):
|
||||
"""获取单个 Provider 详情"""
|
||||
load_options = []
|
||||
if include_models:
|
||||
load_options.append(selectinload(Provider.models).selectinload(Model.global_model))
|
||||
if include_endpoints:
|
||||
load_options.append(selectinload(Provider.endpoints))
|
||||
|
||||
base_query = db.query(Provider)
|
||||
if load_options:
|
||||
base_query = base_query.options(*load_options)
|
||||
|
||||
provider = base_query.filter(
|
||||
(Provider.id == provider_identifier) | (Provider.name == provider_identifier)
|
||||
).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
return _serialize_provider(provider, include_models, include_endpoints)
|
||||
|
||||
|
||||
@router.get("/v1/test-connection")
|
||||
@router.get("/test-connection")
|
||||
async def test_connection(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
provider: Optional[str] = Query(None),
|
||||
model: str = Query("claude-3-haiku-20240307"),
|
||||
api_format: Optional[str] = Query(None),
|
||||
):
|
||||
"""测试 Provider 连接"""
|
||||
selected_provider = _select_provider(db, provider)
|
||||
if not selected_provider:
|
||||
raise HTTPException(status_code=503, detail="No active provider available")
|
||||
|
||||
# 构建测试请求体
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Health check"}],
|
||||
"max_tokens": 5,
|
||||
}
|
||||
|
||||
# 确定 API 格式
|
||||
format_value = api_format or "CLAUDE"
|
||||
|
||||
# 创建 FallbackOrchestrator
|
||||
redis_client = get_redis_client_sync()
|
||||
orchestrator = FallbackOrchestrator(db, redis_client)
|
||||
|
||||
# 定义请求函数
|
||||
async def test_request_func(_prov, endpoint, key):
|
||||
request_builder = PassthroughRequestBuilder()
|
||||
provider_payload, provider_headers = request_builder.build(
|
||||
payload, {}, endpoint, key, is_stream=False
|
||||
)
|
||||
|
||||
url = build_provider_url(
|
||||
endpoint,
|
||||
query_params=dict(request.query_params),
|
||||
path_params={"model": model},
|
||||
is_stream=False,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(url, json=provider_payload, headers=provider_headers)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
try:
|
||||
response, actual_provider, *_ = await orchestrator.execute_with_fallback(
|
||||
api_format=format_value,
|
||||
model_name=model,
|
||||
user_api_key=None,
|
||||
request_func=test_request_func,
|
||||
request_id=None,
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"provider": actual_provider,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"response_id": response.get("id", "unknown"),
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error(f"API connectivity test failed: {exc}")
|
||||
raise HTTPException(status_code=503, detail=str(exc))
|
||||
Reference in New Issue
Block a user