Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
"""Public-facing API routers."""
from fastapi import APIRouter
from .capabilities import router as capabilities_router
from .catalog import router as catalog_router
from .claude import router as claude_router
from .gemini import router as gemini_router
from .openai import router as openai_router
from .system_catalog import router as system_catalog_router
# Aggregate router that mounts every public-facing sub-router of this package.
router = APIRouter()
# Sub-routers are registered in the order below; keep the order stable because
# route registration order is the order in which paths are matched.
# NOTE(review): the claude, gemini and system_catalog routers already declare
# the same tag on their own APIRouter, so the tag passed here repeats it —
# confirm this duplication is intentional.
router.include_router(claude_router, tags=["Claude API"])
router.include_router(openai_router)
router.include_router(gemini_router, tags=["Gemini API"])
router.include_router(system_catalog_router, tags=["System Catalog"])
router.include_router(catalog_router)
router.include_router(capabilities_router)
# Only the aggregate router is part of this package's public API.
__all__ = ["router"]

View File

@@ -0,0 +1,104 @@
"""
能力配置公共 API
提供系统支持的能力列表,供前端展示和配置使用。
"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from src.core.key_capabilities import (
get_all_capabilities,
get_user_configurable_capabilities,
)
from src.database import get_db
# Router for the capability endpoints; all routes in this module live under /api/capabilities.
router = APIRouter(prefix="/api/capabilities", tags=["Capabilities"])
@router.get("")
async def list_capabilities():
    """Return every capability definition.

    Returns:
        A dict with a single ``capabilities`` key holding one entry per
        capability yielded by ``get_all_capabilities()``.
    """
    definitions = []
    for capability in get_all_capabilities():
        definitions.append(
            {
                "name": capability.name,
                "display_name": capability.display_name,
                "short_name": capability.short_name,
                "description": capability.description,
                # match_mode / config_mode are exposed through their ``.value``.
                "match_mode": capability.match_mode.value,
                "config_mode": capability.config_mode.value,
            }
        )
    return {"capabilities": definitions}
@router.get("/user-configurable")
async def list_user_configurable_capabilities():
    """Return the capabilities a user may configure (used by the frontend to render options).

    Returns:
        A dict with a single ``capabilities`` key holding one entry per
        capability yielded by ``get_user_configurable_capabilities()``.
    """
    definitions = []
    for capability in get_user_configurable_capabilities():
        definitions.append(
            {
                "name": capability.name,
                "display_name": capability.display_name,
                "short_name": capability.short_name,
                "description": capability.description,
                # match_mode / config_mode are exposed through their ``.value``.
                "match_mode": capability.match_mode.value,
                "config_mode": capability.config_mode.value,
            }
        )
    return {"capabilities": definitions}
@router.get("/model/{model_name}")
async def get_model_supported_capabilities(
    model_name: str,
    db: Session = Depends(get_db),
):
    """Return the capabilities supported by the given model.

    Args:
        model_name: Model name taken from the path (e.g. ``claude-sonnet-4-20250514``).
        db: Database session injected by FastAPI.

    Returns:
        A dict with the requested model name, the resolved global model's id
        and name, the supported capability names, and the definition of each
        supported capability that is registered. When no global model
        resolves, both lists are empty and an ``error`` message is included;
        this is still a normal response, not an HTTP error.
    """
    # Function-scope import, as in the original code.
    # NOTE(review): presumably avoids an import cycle — confirm.
    from src.services.model.mapping_resolver import get_model_mapping_resolver

    resolver = get_model_mapping_resolver()
    resolved_model = await resolver.get_global_model_by_request(db, model_name, None)
    if not resolved_model:
        return {
            "model": model_name,
            "supported_capabilities": [],
            "capability_details": [],
            "error": "模型不存在",
        }

    supported_names = resolved_model.supported_capabilities or []

    # Look up the definition of each supported capability; names without a
    # registered definition are silently skipped.
    registry = {definition.name: definition for definition in get_all_capabilities()}
    details = [
        {
            "name": definition.name,
            "display_name": definition.display_name,
            "description": definition.description,
            "match_mode": definition.match_mode.value,
            "config_mode": definition.config_mode.value,
        }
        for definition in (registry[name] for name in supported_names if name in registry)
    ]

    return {
        "model": model_name,
        "global_model_id": str(resolved_model.id),
        "global_model_name": resolved_model.name,
        "supported_capabilities": supported_names,
        "capability_details": details,
    }

643
src/api/public/catalog.py Normal file
View File

@@ -0,0 +1,643 @@
"""
公开API端点 - 用户可查看的提供商和模型信息
不包含敏感信息,普通用户可访问
"""
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session, joinedload
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.api import (
ProviderStatsResponse,
PublicGlobalModelListResponse,
PublicGlobalModelResponse,
PublicModelMappingResponse,
PublicModelResponse,
PublicProviderResponse,
)
from src.models.database import (
GlobalModel,
Model,
ModelMapping,
Provider,
ProviderEndpoint,
RequestCandidate,
)
from src.models.endpoint_models import (
PublicApiFormatHealthMonitor,
PublicApiFormatHealthMonitorResponse,
PublicHealthEvent,
)
from src.services.health.endpoint import EndpointHealthService
# Router for the public catalog endpoints; all routes in this module live under /api/public.
router = APIRouter(prefix="/api/public", tags=["Public Catalog"])
# Shared request pipeline; every endpoint below hands its adapter to pipeline.run().
pipeline = ApiRequestPipeline()
@router.get("/providers", response_model=List[PublicProviderResponse])
async def get_public_providers(
    request: Request,
    is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回记录数限制"),
    db: Session = Depends(get_db),
):
    """List providers (user-facing view).

    Args:
        request: Incoming HTTP request, forwarded to the pipeline.
        is_active: Optional active-state filter, passed to the adapter.
        skip: Number of records to skip; must be >= 0.
        limit: Maximum number of records to return; bounded to 1..1000, the
            same bounds ``get_public_global_models`` uses, so this public
            endpoint cannot be asked for a negative or unbounded page.
        db: Database session injected by FastAPI.

    Returns:
        Whatever ``pipeline.run`` returns for ``PublicProvidersAdapter``.
    """
    adapter = PublicProvidersAdapter(is_active=is_active, skip=skip, limit=limit)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/models", response_model=List[PublicModelResponse])
async def get_public_models(
    request: Request,
    provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
    is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回记录数限制"),
    db: Session = Depends(get_db),
):
    """List models (user-facing view).

    Args:
        request: Incoming HTTP request, forwarded to the pipeline.
        provider_id: Optional provider id to restrict the listing to.
        is_active: Optional active-state filter, passed to the adapter.
        skip: Number of records to skip; must be >= 0.
        limit: Maximum number of records to return; bounded to 1..1000, the
            same bounds ``get_public_global_models`` uses, so this public
            endpoint cannot be asked for a negative or unbounded page.
        db: Database session injected by FastAPI.

    Returns:
        Whatever ``pipeline.run`` returns for ``PublicModelsAdapter``.
    """
    adapter = PublicModelsAdapter(
        provider_id=provider_id, is_active=is_active, skip=skip, limit=limit
    )
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
async def get_public_model_mappings(
    request: Request,
    provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
    alias: Optional[str] = Query(None, description="别名过滤原source_model"),
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回记录数限制"),
    db: Session = Depends(get_db),
):
    """List model mappings (user-facing view).

    Args:
        request: Incoming HTTP request, forwarded to the pipeline.
        provider_id: Optional provider id filter.
        alias: Optional alias filter (formerly ``source_model``).
        skip: Number of records to skip; must be >= 0.
        limit: Maximum number of records to return; bounded to 1..1000, the
            same bounds ``get_public_global_models`` uses, so this public
            endpoint cannot be asked for a negative or unbounded page.
        db: Database session injected by FastAPI.

    Returns:
        Whatever ``pipeline.run`` returns for ``PublicModelMappingsAdapter``.
    """
    adapter = PublicModelMappingsAdapter(
        provider_id=provider_id,
        alias=alias,
        skip=skip,
        limit=limit,
    )
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/stats", response_model=ProviderStatsResponse)
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
    """Return the public system statistics produced by ``PublicStatsAdapter``."""
    stats_adapter = PublicStatsAdapter()
    return await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=ApiMode.PUBLIC,
    )
@router.get("/search/models")
async def search_models(
    request: Request,
    q: str = Query(..., description="搜索关键词"),
    provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
    limit: int = Query(20, ge=1, le=1000, description="返回记录数限制"),
    db: Session = Depends(get_db),
):
    """Search models by keyword.

    Args:
        request: Incoming HTTP request, forwarded to the pipeline.
        q: Search keyword (required).
        provider_id: Optional provider id filter. Typed as ``str`` to match
            the other endpoints in this module, which all treat provider ids
            as strings; numeric-looking values are still accepted.
        limit: Maximum number of results; bounded to 1..1000 so this public
            endpoint cannot be asked for a negative or unbounded result set.
        db: Database session injected by FastAPI.

    Returns:
        Whatever ``pipeline.run`` returns for ``PublicSearchModelsAdapter``.
    """
    adapter = PublicSearchModelsAdapter(query=q, provider_id=provider_id, limit=limit)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/health/api-formats", response_model=PublicApiFormatHealthMonitorResponse)
async def get_public_api_format_health(
    request: Request,
    lookback_hours: int = Query(6, ge=1, le=168, description="回溯小时数"),
    per_format_limit: int = Query(100, ge=10, le=500, description="每个格式的事件数限制"),
    db: Session = Depends(get_db),
):
    """Return health-monitoring data per API format (public version, no sensitive fields).

    Args:
        request: Incoming HTTP request, forwarded to the pipeline.
        lookback_hours: How many hours to look back (1..168).
        per_format_limit: Maximum number of events per API format (10..500).
        db: Database session injected by FastAPI.
    """
    monitor_adapter = PublicApiFormatHealthMonitorAdapter(
        lookback_hours=lookback_hours,
        per_format_limit=per_format_limit,
    )
    return await pipeline.run(
        adapter=monitor_adapter,
        http_request=request,
        db=db,
        mode=ApiMode.PUBLIC,
    )
@router.get("/global-models", response_model=PublicGlobalModelListResponse)
async def get_public_global_models(
    request: Request,
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回记录数限制"),
    is_active: Optional[bool] = Query(None, description="过滤活跃状态"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    db: Session = Depends(get_db),
):
    """Return the GlobalModel list (user-facing, read-only view).

    Args:
        request: Incoming HTTP request, forwarded to the pipeline.
        skip: Number of records to skip (>= 0).
        limit: Maximum number of records to return (1..1000).
        is_active: Optional active-state filter.
        search: Optional search keyword.
        db: Database session injected by FastAPI.
    """
    listing_adapter = PublicGlobalModelsAdapter(
        skip=skip,
        limit=limit,
        is_active=is_active,
        search=search,
    )
    return await pipeline.run(
        adapter=listing_adapter,
        http_request=request,
        db=db,
        mode=ApiMode.PUBLIC,
    )
# -------- 公共适配器 --------
class PublicApiAdapter(ApiAdapter):
    """Base class for the public catalog adapters; runs in ``ApiMode.PUBLIC``."""

    mode = ApiMode.PUBLIC

    def authorize(self, context):  # type: ignore[override]
        """Perform no authorization check; always returns ``None``."""
        return None
@dataclass
class PublicProvidersAdapter(PublicApiAdapter):
    """Adapter behind the public provider listing.

    Returns one ``PublicProviderResponse`` dict per provider, including model,
    mapping and endpoint counts.
    """

    # Explicit active-state filter; None means "active providers only".
    is_active: Optional[bool]
    # Pagination offset.
    skip: int
    # Maximum number of providers to return.
    limit: int

    @staticmethod
    def _count_by_provider(db, provider_column, id_column, provider_ids, *criteria):
        """Return ``{provider_id: row_count}`` for ``provider_ids`` using one grouped query."""
        rows = (
            db.query(provider_column, func.count(id_column))
            .filter(provider_column.in_(provider_ids), *criteria)
            .group_by(provider_column)
            .all()
        )
        return {provider_id: count for provider_id, count in rows}

    async def handle(self, context):  # type: ignore[override]
        """Query the providers page and serialize it.

        Returns:
            A list of ``PublicProviderResponse.model_dump()`` dicts.
        """
        db = context.db
        logger.debug("公共API请求提供商列表")

        query = db.query(Provider)
        if self.is_active is not None:
            query = query.filter(Provider.is_active == self.is_active)
        else:
            query = query.filter(Provider.is_active.is_(True))

        # OFFSET/LIMIT without ORDER BY yields non-deterministic pages, so the
        # listing is ordered by priority and then name.
        providers = (
            query.order_by(Provider.provider_priority.asc(), Provider.name.asc())
            .offset(self.skip)
            .limit(self.limit)
            .all()
        )

        # Three grouped COUNT queries for the whole page instead of three
        # COUNT queries per provider. Providers without rows are absent from
        # the dicts and default to 0 below.
        provider_ids = [provider.id for provider in providers]
        models_counts = {}
        active_models_counts = {}
        mappings_counts = {}
        if provider_ids:
            models_counts = self._count_by_provider(
                db, Model.provider_id, Model.id, provider_ids
            )
            active_models_counts = self._count_by_provider(
                db, Model.provider_id, Model.id, provider_ids, Model.is_active.is_(True)
            )
            mappings_counts = self._count_by_provider(
                db,
                ModelMapping.provider_id,
                ModelMapping.id,
                provider_ids,
                ModelMapping.is_active.is_(True),
            )

        result = []
        for provider in providers:
            endpoints = provider.endpoints or []
            provider_data = PublicProviderResponse(
                id=provider.id,
                name=provider.name,
                display_name=provider.display_name,
                description=provider.description,
                is_active=provider.is_active,
                provider_priority=provider.provider_priority,
                models_count=models_counts.get(provider.id, 0),
                active_models_count=active_models_counts.get(provider.id, 0),
                mappings_count=mappings_counts.get(provider.id, 0),
                endpoints_count=len(endpoints),
                active_endpoints_count=sum(1 for ep in endpoints if ep.is_active),
            )
            result.append(provider_data.model_dump())

        logger.debug(f"返回 {len(result)} 个提供商信息")
        return result
@dataclass
class PublicModelsAdapter(PublicApiAdapter):
    """Adapter behind the public model listing (active models of active providers)."""

    # Optional provider id to restrict the listing to.
    provider_id: Optional[str]
    # NOTE(review): accepted but never read by handle(); the query always
    # requires Model.is_active to be true — confirm this is intended.
    is_active: Optional[bool]
    # Pagination offset.
    skip: int
    # Maximum number of models to return.
    limit: int

    async def handle(self, context):  # type: ignore[override]
        """Query the models page and serialize it.

        Returns:
            A list of ``PublicModelResponse.model_dump()`` dicts.
        """
        session = context.db
        logger.debug("公共API请求模型列表")

        both_active = and_(Model.is_active.is_(True), Provider.is_active.is_(True))
        stmt = (
            session.query(Model, Provider)
            .options(joinedload(Model.global_model))
            .join(Provider)
            .filter(both_active)
        )
        if self.provider_id is not None:
            stmt = stmt.filter(Model.provider_id == self.provider_id)

        payload = []
        for model, provider in stmt.offset(self.skip).limit(self.limit).all():
            linked = model.global_model
            if linked:
                # Prefer the unified GlobalModel metadata when the model is linked.
                unified_name = linked.name
                shown_name = linked.display_name
                description = linked.description
                icon_url = linked.icon_url
            else:
                # Otherwise fall back to the provider-side model name.
                unified_name = model.provider_model_name
                shown_name = model.provider_model_name
                description = None
                icon_url = None

            item = PublicModelResponse(
                id=model.id,
                provider_id=model.provider_id,
                provider_name=provider.name,
                provider_display_name=provider.display_name,
                name=unified_name,
                display_name=shown_name,
                description=description,
                tags=None,
                icon_url=icon_url,
                input_price_per_1m=model.get_effective_input_price(),
                output_price_per_1m=model.get_effective_output_price(),
                cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
                cache_read_price_per_1m=model.get_effective_cache_read_price(),
                supports_vision=model.get_effective_supports_vision(),
                supports_function_calling=model.get_effective_supports_function_calling(),
                supports_streaming=model.get_effective_supports_streaming(),
                is_active=model.is_active,
            )
            payload.append(item.model_dump())

        logger.debug(f"返回 {len(payload)} 个模型信息")
        return payload
@dataclass
class PublicModelMappingsAdapter(PublicApiAdapter):
    """Adapter behind the public model-mapping listing (active mappings to active global models)."""

    # Optional provider id; selects provider-scoped plus applicable global mappings.
    provider_id: Optional[str]
    # Formerly "source_model"; matched as a substring of ModelMapping.source_model.
    alias: Optional[str]
    # Pagination offset.
    skip: int
    # Maximum number of mappings to return.
    limit: int

    async def handle(self, context):  # type: ignore[override]
        """Query the mappings page and serialize it.

        Returns:
            A list of ``PublicModelMappingResponse.model_dump()`` dicts.
        """
        session = context.db
        logger.debug("公共API请求模型映射列表")

        stmt = (
            session.query(ModelMapping, GlobalModel, Provider)
            .join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
            .outerjoin(Provider, ModelMapping.provider_id == Provider.id)
            .filter(
                and_(
                    ModelMapping.is_active.is_(True),
                    GlobalModel.is_active.is_(True),
                )
            )
        )

        if self.provider_id is None:
            # No provider requested: only global (provider-less) mappings.
            stmt = stmt.filter(ModelMapping.provider_id.is_(None))
        else:
            # Global models that the requested active provider offers through
            # at least one active model.
            offered_global_model_ids = (
                session.query(Model.global_model_id)
                .join(Provider, Model.provider_id == Provider.id)
                .filter(
                    Provider.id == self.provider_id,
                    Model.is_active.is_(True),
                    Provider.is_active.is_(True),
                    Model.global_model_id.isnot(None),
                )
                .distinct()
            )
            provider_scoped = ModelMapping.provider_id == self.provider_id
            global_for_provider = and_(
                ModelMapping.provider_id.is_(None),
                ModelMapping.target_global_model_id.in_(offered_global_model_ids),
            )
            stmt = stmt.filter(or_(provider_scoped, global_for_provider))

        if self.alias is not None:
            stmt = stmt.filter(ModelMapping.source_model.ilike(f"%{self.alias}%"))

        payload = []
        for mapping, global_model, _provider in stmt.offset(self.skip).limit(self.limit).all():
            if global_model:
                target_name = global_model.name
                target_display_name = global_model.display_name
            else:
                target_name = None
                target_display_name = None

            item = PublicModelMappingResponse(
                id=mapping.id,
                source_model=mapping.source_model,
                target_global_model_id=mapping.target_global_model_id,
                target_global_model_name=target_name,
                target_global_model_display_name=target_display_name,
                provider_id=mapping.provider_id,
                # A mapping bound to a provider is "provider"-scoped, otherwise "global".
                scope="provider" if mapping.provider_id else "global",
                is_active=mapping.is_active,
            )
            payload.append(item.model_dump())

        logger.debug(f"返回 {len(payload)} 个模型映射")
        return payload
class PublicStatsAdapter(PublicApiAdapter):
    """Adapter behind the public statistics endpoint."""

    async def handle(self, context):  # type: ignore[override]
        """Collect provider/model/mapping counts and the supported API formats.

        Returns:
            A ``ProviderStatsResponse.model_dump()`` dict.
        """
        session = context.db
        logger.debug("公共API请求系统统计信息")

        provider_is_active = Provider.is_active.is_(True)

        provider_count = session.query(Provider).filter(provider_is_active).count()
        model_count = (
            session.query(Model)
            .join(Provider)
            .filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
            .count()
        )
        # ModelMapping comes from the module-level import; the original
        # function-local re-import of the same name was redundant.
        mapping_count = (
            session.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
        )

        # NOTE(review): elsewhere in this module api_format is read from
        # ProviderEndpoint, not Provider — confirm Provider.api_format exists.
        format_rows = (
            session.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
        )
        supported_formats = [row.api_format for row in format_rows if row.api_format]

        # NOTE(review): the "total" fields are filled with the active counts —
        # confirm this is intended for the public view.
        summary = ProviderStatsResponse(
            total_providers=provider_count,
            active_providers=provider_count,
            total_models=model_count,
            active_models=model_count,
            total_mappings=mapping_count,
            supported_formats=supported_formats,
        )
        logger.debug("返回系统统计信息")
        return summary.model_dump()
@dataclass
class PublicSearchModelsAdapter(PublicApiAdapter):
    """Adapter behind the public model search (active models of active providers)."""

    # Search keyword, matched case-insensitively as a substring.
    query: str
    # Optional provider id filter. Provider ids are handled as strings by the
    # other adapters in this module, so this is typed as ``str`` (was ``int``).
    provider_id: Optional[str]
    # Maximum number of results to return.
    limit: int

    async def handle(self, context):  # type: ignore[override]
        """Run the search and serialize the matches.

        Returns:
            A list of ``PublicModelResponse.model_dump()`` dicts.
        """
        db = context.db
        logger.debug(f"公共API搜索模型: {self.query}")

        query_stmt = (
            db.query(Model, Provider)
            .options(joinedload(Model.global_model))
            .join(Provider)
            .outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id)
            .filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
        )

        # Build the LIKE pattern once; it matches the provider-side model name
        # or the global model's name / display name / description.
        pattern = f"%{self.query}%"
        query_stmt = query_stmt.filter(
            or_(
                Model.provider_model_name.ilike(pattern),
                GlobalModel.name.ilike(pattern),
                GlobalModel.display_name.ilike(pattern),
                GlobalModel.description.ilike(pattern),
            )
        )

        if self.provider_id is not None:
            # Normalise to str so a caller that still passes an int does not
            # compare an integer against the string id column.
            query_stmt = query_stmt.filter(Model.provider_id == str(self.provider_id))

        results = query_stmt.limit(self.limit).all()

        response = []
        for model, provider in results:
            global_model = model.global_model
            # Prefer GlobalModel metadata; fall back to the provider-side name.
            display_name = global_model.display_name if global_model else model.provider_model_name
            unified_name = global_model.name if global_model else model.provider_model_name
            model_data = PublicModelResponse(
                id=model.id,
                provider_id=model.provider_id,
                provider_name=provider.name,
                provider_display_name=provider.display_name,
                name=unified_name,
                display_name=display_name,
                description=global_model.description if global_model else None,
                tags=None,
                icon_url=global_model.icon_url if global_model else None,
                input_price_per_1m=model.get_effective_input_price(),
                output_price_per_1m=model.get_effective_output_price(),
                cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
                cache_read_price_per_1m=model.get_effective_cache_read_price(),
                supports_vision=model.get_effective_supports_vision(),
                supports_function_calling=model.get_effective_supports_function_calling(),
                supports_streaming=model.get_effective_supports_streaming(),
                is_active=model.is_active,
            )
            response.append(model_data.model_dump())

        logger.debug(f"搜索 '{self.query}' 返回 {len(response)} 个结果")
        return response
@dataclass
class PublicApiFormatHealthMonitorAdapter(PublicApiAdapter):
    """Public API-format health monitor adapter.

    Returns an ``events`` array per API format so the frontend can reuse its
    EndpointHealthTimeline component.
    """

    # How many hours of history to inspect.
    lookback_hours: int
    # Maximum number of events kept per API format.
    per_format_limit: int

    @staticmethod
    def _format_name(api_format) -> str:
        """Return the string form of an API format (``.value`` if present, else ``str()``)."""
        return api_format.value if hasattr(api_format, "value") else str(api_format)

    @staticmethod
    def _event_time(candidate):
        """Return the best available timestamp: finished, else started, else created."""
        return candidate.finished_at or candidate.started_at or candidate.created_at

    async def handle(self, context):  # type: ignore[override]
        """Build the per-format health report.

        Returns:
            A ``PublicApiFormatHealthMonitorResponse`` with one monitor per
            active API format.
        """
        # Function-scope imports as in the original code, hoisted out of the
        # per-format loop because they do not depend on the loop variable.
        from src.core.api_format_metadata import get_local_path
        from src.core.enums import APIFormat

        db = context.db
        now = datetime.now(timezone.utc)
        since = now - timedelta(hours=self.lookback_hours)

        # 1. Collect every API format that has an active endpoint on an active provider.
        active_formats = (
            db.query(ProviderEndpoint.api_format)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .distinct()
            .all()
        )
        all_formats: List[str] = [
            self._format_name(api_format_enum) for (api_format_enum,) in active_formats
        ]

        # API format -> endpoint ids, used to build the usage timeline.
        endpoint_rows = (
            db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .all()
        )
        endpoint_map: Dict[str, List[str]] = defaultdict(list)
        for api_format_enum, endpoint_id in endpoint_rows:
            endpoint_map[self._format_name(api_format_enum)].append(endpoint_id)

        # 2. Fetch recent RequestCandidate rows (bounded), restricted to the
        #    final statuses: success, failed, skipped.
        final_statuses = ["success", "failed", "skipped"]
        limit_rows = max(500, self.per_format_limit * 10)
        rows = (
            db.query(
                RequestCandidate,
                ProviderEndpoint.api_format,
            )
            .join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
            .filter(
                RequestCandidate.created_at >= since,
                RequestCandidate.status.in_(final_statuses),
            )
            .order_by(RequestCandidate.created_at.desc())
            .limit(limit_rows)
            .all()
        )

        # Rows arrive newest first, so each bucket keeps the newest
        # ``per_format_limit`` candidates of its format.
        grouped_candidates: Dict[str, List[RequestCandidate]] = defaultdict(list)
        for candidate, api_format_enum in rows:
            bucket = grouped_candidates[self._format_name(api_format_enum)]
            if len(bucket) < self.per_format_limit:
                bucket.append(candidate)

        # 3. Build monitoring data for every active format.
        monitors: List[PublicApiFormatHealthMonitor] = []
        for api_format in all_formats:
            candidates = grouped_candidates.get(api_format, [])

            # Counters.
            success_count = sum(1 for c in candidates if c.status == "success")
            failed_count = sum(1 for c in candidates if c.status == "failed")
            skipped_count = sum(1 for c in candidates if c.status == "skipped")
            total_attempts = len(candidates)

            # success rate = success / (success + failed); skipped attempts are
            # excluded, and a format with no completed attempts reports 1.0.
            actual_completed = success_count + failed_count
            success_rate = success_count / actual_completed if actual_completed > 0 else 1.0

            # Public event list: no sensitive fields such as provider_id or key_id.
            events: List[PublicHealthEvent] = [
                PublicHealthEvent(
                    timestamp=self._event_time(c),
                    status=c.status,
                    status_code=c.status_code,
                    latency_ms=c.latency_ms,
                    error_type=c.error_type,
                )
                for c in candidates
            ]

            # Time of the most recent event (candidates are ordered newest first).
            last_event_at = self._event_time(candidates[0]) if candidates else None

            timeline_data = EndpointHealthService._generate_timeline_from_usage(
                db=db,
                endpoint_ids=endpoint_map.get(api_format, []),
                now=now,
                lookback_hours=self.lookback_hours,
            )

            # Resolve this site's entry path for the format; unknown format
            # names fall back to "/".
            try:
                api_format_enum = APIFormat(api_format)
                local_path = get_local_path(api_format_enum)
            except ValueError:
                local_path = "/"

            monitors.append(
                PublicApiFormatHealthMonitor(
                    api_format=api_format,
                    api_path=local_path,
                    total_attempts=total_attempts,
                    success_count=success_count,
                    failed_count=failed_count,
                    skipped_count=skipped_count,
                    success_rate=success_rate,
                    last_event_at=last_event_at,
                    events=events,
                    timeline=timeline_data.get("timeline", []),
                    time_range_start=timeline_data.get("time_range_start"),
                    time_range_end=timeline_data.get("time_range_end"),
                )
            )

        response = PublicApiFormatHealthMonitorResponse(
            generated_at=now,
            formats=monitors,
        )
        logger.debug(f"公开健康监控: 返回 {len(monitors)} 个 API 格式的健康数据")
        return response
@dataclass
class PublicGlobalModelsAdapter(PublicApiAdapter):
    """Adapter behind the public GlobalModel listing."""

    # Pagination offset.
    skip: int
    # Maximum number of models to return.
    limit: int
    # Explicit active-state filter; None means "active models only".
    is_active: Optional[bool]
    # Optional keyword matched against name, display name and description.
    search: Optional[str]

    async def handle(self, context):  # type: ignore[override]
        """Query the GlobalModel page and wrap it with the total match count.

        Returns:
            A ``PublicGlobalModelListResponse`` holding the page and the
            total number of rows matching the filters.
        """
        session = context.db
        logger.debug("公共API请求 GlobalModel 列表")

        # Only active models are returned unless a filter is given explicitly.
        if self.is_active is None:
            active_filter = GlobalModel.is_active.is_(True)
        else:
            active_filter = GlobalModel.is_active == self.is_active
        stmt = session.query(GlobalModel).filter(active_filter)

        # Keyword filter.
        if self.search:
            pattern = f"%{self.search}%"
            stmt = stmt.filter(
                or_(
                    GlobalModel.name.ilike(pattern),
                    GlobalModel.display_name.ilike(pattern),
                    GlobalModel.description.ilike(pattern),
                )
            )

        # Total is counted before pagination is applied.
        total = stmt.count()

        page = stmt.order_by(GlobalModel.name).offset(self.skip).limit(self.limit).all()

        # Convert to response objects; unset capability flags default to
        # False, except streaming which defaults to True when unset.
        items = [
            PublicGlobalModelResponse(
                id=gm.id,
                name=gm.name,
                display_name=gm.display_name,
                description=gm.description,
                icon_url=gm.icon_url,
                is_active=gm.is_active,
                default_price_per_request=gm.default_price_per_request,
                default_tiered_pricing=gm.default_tiered_pricing,
                default_supports_vision=gm.default_supports_vision or False,
                default_supports_function_calling=gm.default_supports_function_calling or False,
                default_supports_streaming=(
                    True if gm.default_supports_streaming is None else gm.default_supports_streaming
                ),
                default_supports_extended_thinking=gm.default_supports_extended_thinking or False,
                supported_capabilities=gm.supported_capabilities,
            )
            for gm in page
        ]

        logger.debug(f"返回 {len(items)} 个 GlobalModel")
        return PublicGlobalModelListResponse(models=items, total=total)

52
src/api/public/claude.py Normal file
View File

@@ -0,0 +1,52 @@
"""
Claude API 端点
- /v1/messages - Claude Messages API
- /v1/messages/count_tokens - Token Count API
"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.pipeline import ApiRequestPipeline
from src.api.handlers.claude import (
ClaudeTokenCountAdapter,
build_claude_adapter,
)
from src.core.api_format_metadata import get_api_format_definition
from src.core.enums import APIFormat
from src.database import get_db
# API-format definition for Claude; its path_prefix becomes the router prefix.
_claude_def = get_api_format_definition(APIFormat.CLAUDE)
router = APIRouter(tags=["Claude API"], prefix=_claude_def.path_prefix)
# Shared request pipeline; both endpoints below hand their adapter to pipeline.run().
pipeline = ApiRequestPipeline()
@router.post("/v1/messages")
async def create_message(
    http_request: Request,
    db: Session = Depends(get_db),
):
    """Unified Messages entry point.

    The ``x-app`` header is handed to ``build_claude_adapter``, which picks
    between the standard and the Claude Code adapter.
    """
    app_header = http_request.headers.get("x-app", "")
    adapter = build_claude_adapter(app_header)
    run_mode = adapter.mode
    # The adapter's first allowed API format is used as the format hint.
    format_hint = adapter.allowed_api_formats[0]
    return await pipeline.run(
        adapter=adapter,
        http_request=http_request,
        db=db,
        mode=run_mode,
        api_format_hint=format_hint,
    )
@router.post("/v1/messages/count_tokens")
async def count_tokens(
    http_request: Request,
    db: Session = Depends(get_db),
):
    """Token-count endpoint; runs ``ClaudeTokenCountAdapter`` through the pipeline."""
    token_adapter = ClaudeTokenCountAdapter()
    run_mode = token_adapter.mode
    return await pipeline.run(
        adapter=token_adapter,
        http_request=http_request,
        db=db,
        mode=run_mode,
    )

130
src/api/public/gemini.py Normal file
View File

@@ -0,0 +1,130 @@
"""
Gemini API 专属端点
托管 Gemini API 相关路由:
- /v1beta/models/{model}:generateContent
- /v1beta/models/{model}:streamGenerateContent
注意: Gemini API 的 model 在 URL 路径中,而不是请求体中
路径配置来源: src.core.api_format_metadata.APIFormat.GEMINI
- path_prefix: 本站路径前缀(如 /gemini),通过 router prefix 配置
- default_path: 标准 API 路径模板
"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.pipeline import ApiRequestPipeline
from src.api.handlers.gemini import build_gemini_adapter
from src.api.handlers.gemini_cli import build_gemini_cli_adapter
from src.core.api_format_metadata import get_api_format_definition
from src.core.enums import APIFormat
from src.database import get_db
# Read the path prefix from the Gemini API-format definition.
_gemini_def = get_api_format_definition(APIFormat.GEMINI)
router = APIRouter(tags=["Gemini API"], prefix=_gemini_def.path_prefix)
# Shared request pipeline; the endpoints below hand their adapter to pipeline.run().
pipeline = ApiRequestPipeline()
def _is_cli_request(request: Request) -> bool:
    """Tell whether the request comes from a CLI client.

    Checks, in order:
        1. the ``x-app`` header contains "cli";
        2. the ``user-agent`` header contains "GeminiCLI" or "gemini-cli".
    Both checks are case-insensitive.
    """
    app_header = request.headers.get("x-app", "").lower()
    if "cli" in app_header:
        return True

    agent = request.headers.get("user-agent", "").lower()
    return "geminicli" in agent or "gemini-cli" in agent
@router.post("/v1beta/models/{model}:generateContent")
async def generate_content(
    model: str,
    http_request: Request,
    db: Session = Depends(get_db),
):
    """Gemini generateContent endpoint (non-streaming content generation).

    Args:
        model: Model name taken from the URL path.
        http_request: Incoming HTTP request.
        db: Database session injected by FastAPI.
    """
    # CLI clients (detected from the x-app / user-agent headers) get the CLI adapter.
    adapter = (
        build_gemini_cli_adapter() if _is_cli_request(http_request) else build_gemini_adapter()
    )
    run_mode = adapter.mode
    format_hint = adapter.allowed_api_formats[0]
    # The path model is passed along as a path parameter; "stream" is an
    # internal flag marking this as a non-streaming request.
    route_params = {"model": model, "stream": False}
    return await pipeline.run(
        adapter=adapter,
        http_request=http_request,
        db=db,
        mode=run_mode,
        api_format_hint=format_hint,
        path_params=route_params,
    )
@router.post("/v1beta/models/{model}:streamGenerateContent")
async def stream_generate_content(
    model: str,
    http_request: Request,
    db: Session = Depends(get_db),
):
    """Gemini streamGenerateContent endpoint (streaming content generation).

    The Gemini API distinguishes streaming from non-streaming by URL, so no
    ``stream`` field is needed in the request body.

    Args:
        model: Model name taken from the URL path.
        http_request: Incoming HTTP request.
        db: Database session injected by FastAPI.
    """
    # CLI clients (detected from the x-app / user-agent headers) get the CLI adapter.
    adapter = (
        build_gemini_cli_adapter() if _is_cli_request(http_request) else build_gemini_adapter()
    )
    run_mode = adapter.mode
    format_hint = adapter.allowed_api_formats[0]
    # The path model is passed along as a path parameter; "stream" is an
    # internal flag marking this as a streaming request (per the original
    # comment it is not sent to the upstream API).
    route_params = {"model": model, "stream": True}
    return await pipeline.run(
        adapter=adapter,
        http_request=http_request,
        db=db,
        mode=run_mode,
        api_format_hint=format_hint,
        path_params=route_params,
    )
# Compatibility routes for the v1 path (some SDKs may use it).
@router.post("/v1/models/{model}:generateContent")
async def generate_content_v1(
    model: str,
    http_request: Request,
    db: Session = Depends(get_db),
):
    """v1-compatible endpoint; delegates to ``generate_content``."""
    return await generate_content(model=model, http_request=http_request, db=db)
@router.post("/v1/models/{model}:streamGenerateContent")
async def stream_generate_content_v1(
    model: str,
    http_request: Request,
    db: Session = Depends(get_db),
):
    """v1-compatible endpoint; delegates to ``stream_generate_content``."""
    return await stream_generate_content(model=model, http_request=http_request, db=db)

50
src/api/public/openai.py Normal file
View File

@@ -0,0 +1,50 @@
"""
OpenAI API 端点
- /v1/chat/completions - OpenAI Chat API
- /v1/responses - OpenAI Responses API (CLI)
"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.pipeline import ApiRequestPipeline
from src.api.handlers.openai import OpenAIChatAdapter
from src.api.handlers.openai_cli import OpenAICliAdapter
from src.core.api_format_metadata import get_api_format_definition
from src.core.enums import APIFormat
from src.database import get_db
# API-format definition for OpenAI; its path_prefix becomes the router prefix.
_openai_def = get_api_format_definition(APIFormat.OPENAI)
router = APIRouter(tags=["OpenAI API"], prefix=_openai_def.path_prefix)
# Shared request pipeline; both endpoints below hand their adapter to pipeline.run().
pipeline = ApiRequestPipeline()
@router.post("/v1/chat/completions")
async def create_chat_completion(
    http_request: Request,
    db: Session = Depends(get_db),
):
    """OpenAI Chat Completions endpoint; runs ``OpenAIChatAdapter`` through the pipeline."""
    chat_adapter = OpenAIChatAdapter()
    run_mode = chat_adapter.mode
    # The adapter's first allowed API format is used as the format hint.
    format_hint = chat_adapter.allowed_api_formats[0]
    return await pipeline.run(
        adapter=chat_adapter,
        http_request=http_request,
        db=db,
        mode=run_mode,
        api_format_hint=format_hint,
    )
@router.post("/v1/responses")
async def create_responses(
    http_request: Request,
    db: Session = Depends(get_db),
):
    """OpenAI Responses endpoint (CLI); runs ``OpenAICliAdapter`` through the pipeline."""
    cli_adapter = OpenAICliAdapter()
    run_mode = cli_adapter.mode
    # The adapter's first allowed API format is used as the format hint.
    format_hint = cli_adapter.allowed_api_formats[0]
    return await pipeline.run(
        adapter=cli_adapter,
        http_request=http_request,
        db=db,
        mode=run_mode,
        api_format_hint=format_hint,
    )

View File

@@ -0,0 +1,306 @@
"""
System Catalog / 健康检查相关端点
这些是系统工具端点,不需要复杂的 Adapter 抽象。
"""
from datetime import datetime, timezone
from typing import Any, Dict, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session, selectinload
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.clients.redis_client import get_redis_client, get_redis_client_sync
from src.core.logger import logger
from src.database import get_db
from src.database.database import get_pool_status
from src.models.database import Model, Provider
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
from src.services.provider.transport import build_provider_url
# Router for the system catalog / health-check endpoints (no path prefix).
router = APIRouter(tags=["System Catalog"])
# ============== 辅助函数 ==============
def _as_bool(value: Optional[str], default: bool) -> bool:
"""将字符串转换为布尔值"""
if value is None:
return default
return value.lower() in {"1", "true", "yes", "on"}
def _serialize_provider(
    provider: Provider,
    include_models: bool,
    include_endpoints: bool,
) -> Dict[str, Any]:
    """Serialize a Provider object into a plain dict.

    Args:
        provider: Provider to serialize.
        include_models: When true, add a ``models`` list (active models only).
        include_endpoints: When true, add an ``endpoints`` list.

    Returns:
        Dict with the provider's basic fields plus the optional lists.
    """
    serialized: Dict[str, Any] = {
        "id": provider.id,
        "name": provider.name,
        "display_name": provider.display_name,
        "is_active": provider.is_active,
        "provider_priority": provider.provider_priority,
    }

    if include_endpoints:
        endpoint_items = []
        for endpoint in provider.endpoints or []:
            endpoint_items.append(
                {
                    "id": endpoint.id,
                    "base_url": endpoint.base_url,
                    # A falsy api_format is reported as None.
                    "api_format": endpoint.api_format or None,
                    "is_active": endpoint.is_active,
                }
            )
        serialized["endpoints"] = endpoint_items

    if include_models:
        model_items = []
        for model in provider.models or []:
            if not model.is_active:
                continue
            # Prefer the linked GlobalModel's naming; otherwise fall back to
            # the provider-side model name.
            if model.global_model:
                name = model.global_model.name
                display_name = model.global_model.display_name
            else:
                name = model.provider_model_name
                display_name = model.provider_model_name
            model_items.append(
                {
                    "id": model.id,
                    "name": name,
                    "display_name": display_name,
                    "is_active": model.is_active,
                    "supports_streaming": model.supports_streaming,
                }
            )
        serialized["models"] = model_items

    return serialized
def _select_provider(db: Session, provider_name: Optional[str]) -> Optional[Provider]:
    """Pick an active Provider.

    Args:
        db: Database session.
        provider_name: Preferred provider name, if any.

    Returns:
        The active provider with that name when it exists; otherwise the
        active provider with the smallest ``provider_priority`` (or ``None``
        when there is no active provider).
    """
    active_providers = db.query(Provider).filter(Provider.is_active == True)

    named_match = None
    if provider_name:
        named_match = active_providers.filter(Provider.name == provider_name).first()
    if named_match:
        return named_match

    # Fall back by priority: the smallest provider_priority wins.
    return active_providers.order_by(Provider.provider_priority.asc()).first()
# ============== 端点 ==============
@router.get("/v1/health")
async def service_health(db: Session = Depends(get_db)):
    """Return service health plus dependency information.

    Returns:
        Dict with a fixed ``"ok"`` status, a UTC timestamp, the active
        provider/model counts and the status of the database and Redis.
    """
    provider_total = (
        db.query(func.count(Provider.id)).filter(Provider.is_active == True).scalar() or 0
    )
    model_total = db.query(func.count(Model.id)).filter(Model.is_active == True).scalar() or 0

    # Redis is reported as ok / degraded (no client) / error (ping failed);
    # a Redis failure never fails the endpoint itself.
    redis_status: Dict[str, Any] = {"status": "unknown"}
    try:
        client = await get_redis_client()
        if not client:
            redis_status = {"status": "degraded", "message": "Redis client not initialized"}
        else:
            await client.ping()
            redis_status = {"status": "ok"}
    except Exception as exc:
        redis_status = {"status": "error", "message": str(exc)}

    stats = {
        "active_providers": provider_total,
        "active_models": model_total,
    }
    dependencies = {
        "database": {"status": "ok"},
        "redis": redis_status,
    }
    return {
        "status": "ok",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "stats": stats,
        "dependencies": dependencies,
    }
@router.get("/health")
async def health_check():
    """Simple health-check endpoint (no authentication required).

    Returns:
        Dict with a fixed ``"healthy"`` status, a UTC timestamp and the
        database pool figures; if reading the pool status fails, the pool
        entry carries the error text instead.
    """
    try:
        status = get_pool_status()
        checked_out = status["checked_out"]
        pool_size = status["pool_size"]
        overflow = status["overflow"]
        capacity = status["max_capacity"]
        # Usage is checked-out connections as a percentage of capacity; a
        # non-positive capacity is reported as "0.0%".
        if capacity > 0:
            usage = f"{(checked_out / capacity * 100):.1f}%"
        else:
            usage = "0.0%"
        pool_report = {
            "checked_out": checked_out,
            "pool_size": pool_size,
            "overflow": overflow,
            "max_capacity": capacity,
            "usage_rate": usage,
        }
    except Exception as e:
        pool_report = {"error": str(e)}

    return {
        "status": "healthy",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "database_pool": pool_report,
    }
@router.get("/")
async def root(db: Session = Depends(get_db)):
    """Root endpoint: service information overview.

    Returns:
        Dict with the service banner, the name of the highest-priority active
        provider ("None" when there is none), the number of active providers
        and a map of well-known endpoint paths.
    """
    active_query = db.query(Provider).filter(Provider.is_active == True)

    # Highest priority means the smallest provider_priority value.
    best_provider = active_query.order_by(Provider.provider_priority.asc()).first()
    provider_total = db.query(Provider).filter(Provider.is_active == True).count()

    if best_provider:
        current_name = best_provider.name
    else:
        current_name = "None"

    endpoint_paths = {
        "messages": "/v1/messages",
        "count_tokens": "/v1/messages/count_tokens",
        "health": "/v1/health",
        "providers": "/v1/providers",
        "test_connection": "/v1/test-connection",
    }
    return {
        "message": "AI Proxy with Modular Architecture v4.0.0",
        "status": "running",
        "current_provider": current_name,
        "available_providers": provider_total,
        "config": {},
        "endpoints": endpoint_paths,
    }
@router.get("/v1/providers")
async def list_providers(
    db: Session = Depends(get_db),
    include_models: bool = Query(False),
    include_endpoints: bool = Query(False),
    active_only: bool = Query(True),
):
    """List providers, ordered by priority and then name.

    Args:
        db: Database session injected by FastAPI.
        include_models: Include each provider's active models.
        include_endpoints: Include each provider's endpoints.
        active_only: Restrict the listing to active providers.

    Returns:
        Dict with a ``providers`` list of serialized providers.
    """
    # Eager-load only the relationships the serializer is going to read.
    eager_loads = []
    if include_models:
        eager_loads.append(selectinload(Provider.models).selectinload(Model.global_model))
    if include_endpoints:
        eager_loads.append(selectinload(Provider.endpoints))

    stmt = db.query(Provider)
    if eager_loads:
        stmt = stmt.options(*eager_loads)
    if active_only:
        stmt = stmt.filter(Provider.is_active == True)
    stmt = stmt.order_by(Provider.provider_priority.asc(), Provider.name.asc())

    serialized = []
    for provider in stmt.all():
        serialized.append(_serialize_provider(provider, include_models, include_endpoints))
    return {"providers": serialized}
@router.get("/v1/providers/{provider_identifier}")
async def provider_detail(
    provider_identifier: str,
    db: Session = Depends(get_db),
    include_models: bool = Query(False),
    include_endpoints: bool = Query(False),
):
    """Return the details of a single provider.

    Args:
        provider_identifier: Provider id or provider name.
        db: Database session injected by FastAPI.
        include_models: Include the provider's active models.
        include_endpoints: Include the provider's endpoints.

    Raises:
        HTTPException: 404 when no provider matches the identifier.
    """
    # Eager-load only the relationships the serializer is going to read.
    eager_loads = []
    if include_models:
        eager_loads.append(selectinload(Provider.models).selectinload(Model.global_model))
    if include_endpoints:
        eager_loads.append(selectinload(Provider.endpoints))

    stmt = db.query(Provider)
    if eager_loads:
        stmt = stmt.options(*eager_loads)

    # The identifier may be either the provider's id or its name.
    matches_identifier = (Provider.id == provider_identifier) | (
        Provider.name == provider_identifier
    )
    found = stmt.filter(matches_identifier).first()
    if found:
        return _serialize_provider(found, include_models, include_endpoints)
    raise HTTPException(status_code=404, detail="Provider not found")
@router.get("/v1/test-connection")
@router.get("/test-connection")
async def test_connection(
    request: Request,
    db: Session = Depends(get_db),
    provider: Optional[str] = Query(None),
    model: str = Query("claude-3-haiku-20240307"),
    api_format: Optional[str] = Query(None),
):
    """Test connectivity to a provider by sending a tiny request.

    Args:
        request: Incoming HTTP request; its query parameters are forwarded
            when the provider URL is built.
        db: Database session injected by FastAPI.
        provider: Preferred provider name.
        model: Model name used for the probe request.
        api_format: API format to test; defaults to "CLAUDE".

    Returns:
        Dict with status, the provider that actually served the request, a
        UTC timestamp and the upstream response id ("unknown" if absent).

    Raises:
        HTTPException: 503 when no active provider exists or the probe fails.
    """
    # NOTE(review): the selected provider is only used for this availability
    # check; the orchestrator below performs its own provider selection, so
    # the ``provider`` parameter does not pin the probe to that provider —
    # confirm this is intended.
    selected_provider = _select_provider(db, provider)
    if not selected_provider:
        raise HTTPException(status_code=503, detail="No active provider available")

    # Build the probe request body.
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": "Health check"}],
        "max_tokens": 5,
    }

    # Decide the API format.
    format_value = api_format or "CLAUDE"

    # Create the FallbackOrchestrator.
    redis_client = get_redis_client_sync()
    orchestrator = FallbackOrchestrator(db, redis_client)

    # Request function invoked by the orchestrator for each candidate.
    async def test_request_func(_prov, endpoint, key):
        request_builder = PassthroughRequestBuilder()
        provider_payload, provider_headers = request_builder.build(
            payload, {}, endpoint, key, is_stream=False
        )
        url = build_provider_url(
            endpoint,
            query_params=dict(request.query_params),
            path_params={"model": model},
            is_stream=False,
        )
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(url, json=provider_payload, headers=provider_headers)
            resp.raise_for_status()
            return resp.json()

    try:
        response, actual_provider, *_ = await orchestrator.execute_with_fallback(
            api_format=format_value,
            model_name=model,
            user_api_key=None,
            request_func=test_request_func,
            request_id=None,
        )
        return {
            "status": "success",
            "provider": actual_provider,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "response_id": response.get("id", "unknown"),
        }
    except Exception as exc:
        logger.error(f"API connectivity test failed: {exc}")
        # Chain the original exception so the root cause stays attached to
        # the HTTPException in tracebacks and logs.
        raise HTTPException(status_code=503, detail=str(exc)) from exc