mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
411 lines
12 KiB
Python
411 lines
12 KiB
Python
|
|
"""
|
|||
|
|
Provider Query API endpoints.

Used to query a provider's balance, usage records, and related information.
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
|
|||
|
|
from src.core.logger import logger
|
|||
|
|
from src.database.database import get_db
|
|||
|
|
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
|||
|
|
|
|||
|
|
# Initialize adapter registration
|
|||
|
|
from src.plugins.provider_query import init # noqa
|
|||
|
|
from src.plugins.provider_query import get_query_registry
|
|||
|
|
from src.plugins.provider_query.base import QueryCapability
|
|||
|
|
from src.utils.auth_utils import get_current_user
|
|||
|
|
|
|||
|
|
# Router for the endpoints defined in this module; routes are registered under
# the "/provider-query" prefix and tagged "Provider Query".
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============ Request/Response Models ============
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BalanceQueryRequest(BaseModel):
    """Request body for a provider balance query."""

    # ID of the provider to query.
    provider_id: str
    # Optional API key ID. When omitted, the provider's first available
    # API key is used.
    api_key_id: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UsageSummaryQueryRequest(BaseModel):
    """Request body for a provider usage-summary query."""

    # ID of the provider to query.
    provider_id: str
    # Optional API key ID to query with.
    api_key_id: Optional[str] = None
    # Summary period: day, week, month, or year. Declared as a plain string;
    # this model does not validate the value.
    period: str = "month"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModelsQueryRequest(BaseModel):
    """Request body for a provider model-list query."""

    # ID of the provider to query.
    provider_id: str
    # Optional API key ID to query with.
    api_key_id: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ============ API Endpoints ============
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/adapters")
async def list_adapters(
    current_user: User = Depends(get_current_user),
):
    """
    List all available query adapters.

    Args:
        current_user: Resolved by the ``get_current_user`` dependency; not
            otherwise used in the body.

    Returns:
        A dict with ``success`` set to True and ``data`` holding the value
        returned by the registry's ``list_adapters()``.
    """
    return {"success": True, "data": get_query_registry().list_adapters()}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/capabilities/{provider_id}")
async def get_provider_capabilities(
    provider_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Report which query capabilities a provider supports.

    Args:
        provider_id: ID of the provider to inspect.
        db: Database session supplied by the ``get_db`` dependency.
        current_user: Resolved by the ``get_current_user`` dependency; not
            otherwise used in the body.

    Returns:
        A dict with ``success`` and a ``data`` payload containing the provider
        ID and name, the capability names, and ``has_adapter``. When the
        registry returns ``None`` for the provider, ``capabilities`` is empty,
        ``has_adapter`` is False and an explanatory ``message`` is included.

    Raises:
        HTTPException: 404 if no provider with ``provider_id`` exists.
    """
    from sqlalchemy import select

    # Look up the provider.
    lookup = select(Provider).where(Provider.id == provider_id)
    provider = (await db.execute(lookup)).scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    capabilities = get_query_registry().get_capabilities_for_provider(provider.name)
    has_adapter = capabilities is not None

    payload = {
        "provider_id": provider_id,
        "provider_name": provider.name,
        "capabilities": [c.name for c in capabilities] if has_adapter else [],
        "has_adapter": has_adapter,
    }
    if not has_adapter:
        # Only the "no adapter" response carries a message.
        payload["message"] = "No query adapter available for this provider"

    return {"success": True, "data": payload}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/balance")
async def query_balance(
    request: BalanceQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query a provider's balance.

    Args:
        request: Balance query request (provider ID and optional API key ID).
        db: Database session supplied by the ``get_db`` dependency.
        current_user: Resolved by the ``get_current_user`` dependency; not
            otherwise used in the body.

    Returns:
        A dict with the query's ``success`` flag, the query result as a dict
        under ``data``, and the provider's ``id``, ``name`` and
        ``display_name`` under ``provider``.

    Raises:
        HTTPException: 404 if the provider does not exist or the requested
            API key is not found among the provider's endpoints; 400 if no
            API key ID was given and no active key could be selected.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # NOTE(review): this endpoint only requires an authenticated user; confirm
    # whether it should be restricted further, since it uses stored provider keys.

    # Load the provider together with its endpoints and their API keys.
    stmt = (
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = (await db.execute(stmt)).scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    # Resolve the API key value and the endpoint settings that go with it.
    api_key_value = None
    endpoint_config = None

    if request.api_key_id:
        # Find the explicitly requested API key.
        for endpoint in provider.endpoints:
            wanted = next(
                (key for key in endpoint.api_keys if key.id == request.api_key_id),
                None,
            )
            if wanted is None:
                continue
            api_key_value = wanted.api_key
            endpoint_config = {
                "base_url": endpoint.base_url,
                "api_format": endpoint.api_format or None,
            }
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        # Fall back to the first active key on an active endpoint.
        for endpoint in provider.endpoints:
            if not (endpoint.is_active and endpoint.api_keys):
                continue
            first_active = next((key for key in endpoint.api_keys if key.is_active), None)
            if first_active is None:
                continue
            api_key_value = first_active.api_key
            endpoint_config = {
                "base_url": endpoint.base_url,
                "api_format": endpoint.api_format or None,
            }
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")

    # Run the balance query through the registry.
    query_result = await get_query_registry().query_provider_balance(
        provider_type=provider.name,
        api_key=api_key_value,
        endpoint_config=endpoint_config,
    )

    if not query_result.success:
        logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")

    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/usage-summary")
async def query_usage_summary(
    request: UsageSummaryQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query a provider's usage summary.

    Args:
        request: Usage summary request (provider ID, optional API key ID,
            and period).
        db: Database session supplied by the ``get_db`` dependency.
        current_user: Resolved by the ``get_current_user`` dependency; not
            otherwise used in the body.

    Returns:
        A dict with the query's ``success`` flag, the query result as a dict
        under ``data``, and the provider's ``id``, ``name`` and
        ``display_name`` under ``provider``.

    Raises:
        HTTPException: 404 if the provider does not exist or the requested
            API key is not found among the provider's endpoints; 400 if no
            API key ID was given and no active key could be selected.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load the provider together with its endpoints and their API keys.
    stmt = (
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = (await db.execute(stmt)).scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    # Resolve the API key (same selection rules as the balance endpoint).
    api_key_value = None
    endpoint_config = None

    if request.api_key_id:
        for endpoint in provider.endpoints:
            wanted = next(
                (key for key in endpoint.api_keys if key.id == request.api_key_id),
                None,
            )
            if wanted is None:
                continue
            api_key_value = wanted.api_key
            endpoint_config = {"base_url": endpoint.base_url}
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        for endpoint in provider.endpoints:
            if not (endpoint.is_active and endpoint.api_keys):
                continue
            first_active = next((key for key in endpoint.api_keys if key.is_active), None)
            if first_active is None:
                continue
            api_key_value = first_active.api_key
            endpoint_config = {"base_url": endpoint.base_url}
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")

    # Run the usage query through the registry.
    query_result = await get_query_registry().query_provider_usage(
        provider_type=provider.name,
        api_key=api_key_value,
        period=request.period,
        endpoint_config=endpoint_config,
    )

    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/models")
async def query_available_models(
    request: ModelsQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query the models available from a provider.

    Args:
        request: Model list request (provider ID and optional API key ID).
        db: Database session supplied by the ``get_db`` dependency.
        current_user: Resolved by the ``get_current_user`` dependency; not
            otherwise used in the body.

    Returns:
        A dict with the query's ``success`` flag, the query result as a dict
        under ``data``, and the provider's ``id``, ``name`` and
        ``display_name`` under ``provider``.

    Raises:
        HTTPException: 404 if the provider does not exist or the requested
            API key is not found among the provider's endpoints; 400 if no
            API key ID was given and no active key could be selected, or if
            the registry has no adapter for the provider.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load the provider together with its endpoints and their API keys.
    stmt = (
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = (await db.execute(stmt)).scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    # Resolve the API key value and the endpoint it belongs to.
    api_key_value = None
    endpoint_config = None

    if request.api_key_id:
        for endpoint in provider.endpoints:
            wanted = next(
                (key for key in endpoint.api_keys if key.id == request.api_key_id),
                None,
            )
            if wanted is None:
                continue
            api_key_value = wanted.api_key
            endpoint_config = {"base_url": endpoint.base_url}
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        for endpoint in provider.endpoints:
            if not (endpoint.is_active and endpoint.api_keys):
                continue
            first_active = next((key for key in endpoint.api_keys if key.is_active), None)
            if first_active is None:
                continue
            api_key_value = first_active.api_key
            endpoint_config = {"base_url": endpoint.base_url}
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")

    # Query the models through the provider's adapter.
    adapter = get_query_registry().get_adapter_for_provider(provider.name)
    if not adapter:
        raise HTTPException(
            status_code=400, detail=f"No query adapter available for provider: {provider.name}"
        )

    query_result = await adapter.query_available_models(
        api_key=api_key_value,
        endpoint_config=endpoint_config,
    )

    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
    provider_id: str,
    api_key_id: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Clear the query cache of a provider's adapter.

    Args:
        provider_id: ID of the provider whose adapter cache should be cleared.
        api_key_id: Optional API key ID. When given, only the cache for that
            key is cleared; otherwise the adapter's cache is cleared without
            a key argument.
        db: Database session supplied by the ``get_db`` dependency.
        current_user: Resolved by the ``get_current_user`` dependency; not
            otherwise used in the body.

    Returns:
        A dict with ``success`` set to True and a confirmation ``message``.
        This is also returned when the provider has no adapter, in which case
        there is nothing to clear.

    Raises:
        HTTPException: 404 if the provider does not exist, or if
            ``api_key_id`` is given but no API key with that ID exists.
    """
    from sqlalchemy import select

    cleared = {"success": True, "message": "Cache cleared successfully"}

    # Look up the provider.
    result = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = result.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    adapter = get_query_registry().get_adapter_for_provider(provider.name)
    if not adapter:
        # No adapter for this provider, so there is no cache to clear.
        return cleared

    if not api_key_id:
        adapter.clear_cache()
        return cleared

    # The cache is cleared by API key value, so resolve the ID to its value.
    # NOTE(review): the key is looked up by ID only and is not checked to
    # belong to this provider - confirm whether that should be enforced.
    result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
    api_key = result.scalar_one_or_none()
    if not api_key:
        # Previously this case fell through and reported success without
        # clearing anything; report it like the other endpoints do.
        raise HTTPException(status_code=404, detail="API Key not found")

    adapter.clear_cache(api_key.api_key)
    return cleared
|