2025-12-10 20:52:44 +08:00
|
|
|
|
"""
|
|
|
|
|
|
Chat Adapter 通用基类
|
|
|
|
|
|
|
|
|
|
|
|
提供 Chat 格式(进行请求验证和标准化)的通用适配器逻辑:
|
|
|
|
|
|
- 请求解析和验证
|
|
|
|
|
|
- 审计日志记录
|
|
|
|
|
|
- 错误处理和响应格式化
|
|
|
|
|
|
- Handler 创建和调用
|
|
|
|
|
|
- 计费策略(支持不同 API 格式的差异化计费)
|
|
|
|
|
|
|
|
|
|
|
|
子类只需提供:
|
|
|
|
|
|
- FORMAT_ID: API 格式标识
|
|
|
|
|
|
- HANDLER_CLASS: 对应的 ChatHandlerBase 子类
|
|
|
|
|
|
- _validate_request_body(): 可选覆盖请求验证逻辑
|
|
|
|
|
|
- _build_audit_metadata(): 可选覆盖审计元数据构建
|
|
|
|
|
|
- compute_total_input_context(): 可选覆盖总输入上下文计算(用于阶梯计费判定)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import time
|
|
|
|
|
|
import traceback
|
|
|
|
|
|
from abc import abstractmethod
|
2025-12-19 23:40:42 +08:00
|
|
|
|
from typing import Any, Dict, Optional, Tuple, Type
|
2025-12-10 20:52:44 +08:00
|
|
|
|
|
2025-12-19 23:40:42 +08:00
|
|
|
|
import httpx
|
2025-12-10 20:52:44 +08:00
|
|
|
|
from fastapi import HTTPException
|
|
|
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
|
|
|
|
|
|
|
|
from src.api.base.adapter import ApiAdapter, ApiMode
|
|
|
|
|
|
from src.api.base.context import ApiRequestContext
|
|
|
|
|
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
|
|
|
|
|
from src.core.exceptions import (
|
|
|
|
|
|
InvalidRequestException,
|
|
|
|
|
|
ModelNotSupportedException,
|
|
|
|
|
|
ProviderAuthException,
|
|
|
|
|
|
ProviderNotAvailableException,
|
|
|
|
|
|
ProviderRateLimitException,
|
|
|
|
|
|
ProviderTimeoutException,
|
|
|
|
|
|
ProxyException,
|
|
|
|
|
|
QuotaExceededException,
|
|
|
|
|
|
UpstreamClientException,
|
|
|
|
|
|
)
|
|
|
|
|
|
from src.core.logger import logger
|
|
|
|
|
|
from src.services.request.result import RequestResult
|
|
|
|
|
|
from src.services.usage.recorder import UsageRecorder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatAdapterBase(ApiAdapter):
|
|
|
|
|
|
"""
|
|
|
|
|
|
Chat Adapter 通用基类
|
|
|
|
|
|
|
|
|
|
|
|
提供 Chat 格式的通用适配器逻辑,子类只需配置:
|
|
|
|
|
|
- FORMAT_ID: API 格式标识
|
|
|
|
|
|
- HANDLER_CLASS: ChatHandlerBase 子类
|
|
|
|
|
|
- name: 适配器名称
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
# 子类必须覆盖
|
|
|
|
|
|
FORMAT_ID: str = "UNKNOWN"
|
|
|
|
|
|
HANDLER_CLASS: Type[ChatHandlerBase]
|
|
|
|
|
|
|
|
|
|
|
|
# 适配器配置
|
|
|
|
|
|
name: str = "chat.base"
|
|
|
|
|
|
mode = ApiMode.STANDARD
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
|
|
|
|
|
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
|
|
|
|
|
|
|
|
|
|
|
async def handle(self, context: ApiRequestContext):
|
|
|
|
|
|
"""处理 Chat API 请求"""
|
|
|
|
|
|
http_request = context.request
|
|
|
|
|
|
user = context.user
|
|
|
|
|
|
api_key = context.api_key
|
|
|
|
|
|
db = context.db
|
|
|
|
|
|
request_id = context.request_id
|
|
|
|
|
|
quota_remaining_value = context.quota_remaining
|
|
|
|
|
|
start_time = context.start_time
|
|
|
|
|
|
client_ip = context.client_ip
|
|
|
|
|
|
user_agent = context.user_agent
|
|
|
|
|
|
original_headers = context.original_headers
|
|
|
|
|
|
query_params = context.query_params
|
|
|
|
|
|
|
|
|
|
|
|
original_request_body = context.ensure_json_body()
|
|
|
|
|
|
|
|
|
|
|
|
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
|
|
|
|
|
|
if context.path_params:
|
|
|
|
|
|
original_request_body = self._merge_path_params(
|
|
|
|
|
|
original_request_body, context.path_params
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 验证和解析请求
|
|
|
|
|
|
request_obj = self._validate_request_body(original_request_body, context.path_params)
|
|
|
|
|
|
if isinstance(request_obj, JSONResponse):
|
|
|
|
|
|
return request_obj
|
|
|
|
|
|
|
|
|
|
|
|
stream = getattr(request_obj, "stream", False)
|
|
|
|
|
|
model = getattr(request_obj, "model", "unknown")
|
|
|
|
|
|
|
|
|
|
|
|
# 添加审计元数据
|
|
|
|
|
|
audit_metadata = self._build_audit_metadata(original_request_body, request_obj)
|
|
|
|
|
|
context.add_audit_metadata(**audit_metadata)
|
|
|
|
|
|
|
|
|
|
|
|
# 格式化额度显示
|
|
|
|
|
|
quota_display = (
|
|
|
|
|
|
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 请求开始日志
|
|
|
|
|
|
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
|
|
|
|
|
|
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 检查客户端连接
|
|
|
|
|
|
if await http_request.is_disconnected():
|
|
|
|
|
|
logger.warning("客户端连接断开")
|
|
|
|
|
|
raise HTTPException(status_code=499, detail="Client disconnected")
|
|
|
|
|
|
|
|
|
|
|
|
# 创建 Handler
|
|
|
|
|
|
handler = self._create_handler(
|
|
|
|
|
|
db=db,
|
|
|
|
|
|
user=user,
|
|
|
|
|
|
api_key=api_key,
|
|
|
|
|
|
request_id=request_id,
|
|
|
|
|
|
client_ip=client_ip,
|
|
|
|
|
|
user_agent=user_agent,
|
|
|
|
|
|
start_time=start_time,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 处理请求
|
|
|
|
|
|
if stream:
|
|
|
|
|
|
return await handler.process_stream(
|
|
|
|
|
|
request=request_obj,
|
|
|
|
|
|
http_request=http_request,
|
|
|
|
|
|
original_headers=original_headers,
|
|
|
|
|
|
original_request_body=original_request_body,
|
|
|
|
|
|
query_params=query_params,
|
|
|
|
|
|
)
|
|
|
|
|
|
return await handler.process_sync(
|
|
|
|
|
|
request=request_obj,
|
|
|
|
|
|
http_request=http_request,
|
|
|
|
|
|
original_headers=original_headers,
|
|
|
|
|
|
original_request_body=original_request_body,
|
|
|
|
|
|
query_params=query_params,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
except HTTPException:
|
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
except (
|
|
|
|
|
|
ModelNotSupportedException,
|
|
|
|
|
|
QuotaExceededException,
|
|
|
|
|
|
InvalidRequestException,
|
|
|
|
|
|
) as e:
|
|
|
|
|
|
logger.info(f"客户端请求错误: {e.error_type}")
|
|
|
|
|
|
return self._error_response(
|
|
|
|
|
|
status_code=e.status_code,
|
|
|
|
|
|
error_type=(
|
|
|
|
|
|
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
|
|
|
|
|
|
),
|
|
|
|
|
|
message=e.message,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
except (
|
|
|
|
|
|
ProviderAuthException,
|
|
|
|
|
|
ProviderRateLimitException,
|
|
|
|
|
|
ProviderNotAvailableException,
|
|
|
|
|
|
ProviderTimeoutException,
|
|
|
|
|
|
UpstreamClientException,
|
|
|
|
|
|
) as e:
|
|
|
|
|
|
return await self._handle_provider_exception(
|
|
|
|
|
|
e,
|
|
|
|
|
|
db=db,
|
|
|
|
|
|
user=user,
|
|
|
|
|
|
api_key=api_key,
|
|
|
|
|
|
model=model,
|
|
|
|
|
|
stream=stream,
|
|
|
|
|
|
start_time=start_time,
|
|
|
|
|
|
original_headers=original_headers,
|
|
|
|
|
|
original_request_body=original_request_body,
|
|
|
|
|
|
client_ip=client_ip,
|
|
|
|
|
|
request_id=request_id,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
return await self._handle_unexpected_exception(
|
|
|
|
|
|
e,
|
|
|
|
|
|
db=db,
|
|
|
|
|
|
user=user,
|
|
|
|
|
|
api_key=api_key,
|
|
|
|
|
|
model=model,
|
|
|
|
|
|
stream=stream,
|
|
|
|
|
|
start_time=start_time,
|
|
|
|
|
|
original_headers=original_headers,
|
|
|
|
|
|
original_request_body=original_request_body,
|
|
|
|
|
|
client_ip=client_ip,
|
|
|
|
|
|
request_id=request_id,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _create_handler(
|
|
|
|
|
|
self,
|
|
|
|
|
|
*,
|
|
|
|
|
|
db,
|
|
|
|
|
|
user,
|
|
|
|
|
|
api_key,
|
|
|
|
|
|
request_id: str,
|
|
|
|
|
|
client_ip: str,
|
|
|
|
|
|
user_agent: str,
|
|
|
|
|
|
start_time: float,
|
|
|
|
|
|
):
|
|
|
|
|
|
"""创建 Handler 实例 - 子类可覆盖"""
|
|
|
|
|
|
return self.HANDLER_CLASS(
|
|
|
|
|
|
db=db,
|
|
|
|
|
|
user=user,
|
|
|
|
|
|
api_key=api_key,
|
|
|
|
|
|
request_id=request_id,
|
|
|
|
|
|
client_ip=client_ip,
|
|
|
|
|
|
user_agent=user_agent,
|
|
|
|
|
|
start_time=start_time,
|
|
|
|
|
|
allowed_api_formats=self.allowed_api_formats,
|
|
|
|
|
|
adapter_detector=self.detect_capability_requirements,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _merge_path_params(
|
|
|
|
|
|
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
|
|
|
|
|
|
) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
合并 URL 路径参数到请求体 - 子类可覆盖
|
|
|
|
|
|
|
|
|
|
|
|
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
original_request_body: 原始请求体字典
|
|
|
|
|
|
path_params: URL 路径参数字典
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
合并后的请求体字典
|
|
|
|
|
|
"""
|
|
|
|
|
|
merged = original_request_body.copy()
|
|
|
|
|
|
for key, value in path_params.items():
|
|
|
|
|
|
if key not in merged:
|
|
|
|
|
|
merged[key] = value
|
|
|
|
|
|
return merged
|
|
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
|
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
|
|
|
|
|
"""
|
|
|
|
|
|
验证请求体 - 子类必须实现
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
original_request_body: 原始请求体字典
|
|
|
|
|
|
path_params: URL 路径参数(如 Gemini 的 stream 通过 URL 端点传入)
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
验证后的请求对象,或 JSONResponse 错误响应
|
|
|
|
|
|
"""
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
提取消息数量 - 子类可覆盖
|
|
|
|
|
|
|
|
|
|
|
|
默认实现:从 messages 字段提取
|
|
|
|
|
|
"""
|
|
|
|
|
|
messages = payload.get("messages", [])
|
|
|
|
|
|
if hasattr(request_obj, "messages"):
|
|
|
|
|
|
messages = request_obj.messages
|
|
|
|
|
|
return len(messages) if isinstance(messages, list) else 0
|
|
|
|
|
|
|
|
|
|
|
|
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
构建审计日志元数据 - 子类可覆盖
|
|
|
|
|
|
"""
|
|
|
|
|
|
model = getattr(request_obj, "model", payload.get("model", "unknown"))
|
|
|
|
|
|
stream = getattr(request_obj, "stream", payload.get("stream", False))
|
|
|
|
|
|
messages_count = self._extract_message_count(payload, request_obj)
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
"action": f"{self.FORMAT_ID.lower()}_request",
|
|
|
|
|
|
"model": model,
|
|
|
|
|
|
"stream": bool(stream),
|
|
|
|
|
|
"max_tokens": getattr(request_obj, "max_tokens", payload.get("max_tokens")),
|
|
|
|
|
|
"messages_count": messages_count,
|
|
|
|
|
|
"temperature": getattr(request_obj, "temperature", payload.get("temperature")),
|
|
|
|
|
|
"top_p": getattr(request_obj, "top_p", payload.get("top_p")),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_provider_exception(
|
|
|
|
|
|
self,
|
|
|
|
|
|
e: Exception,
|
|
|
|
|
|
*,
|
|
|
|
|
|
db,
|
|
|
|
|
|
user,
|
|
|
|
|
|
api_key,
|
|
|
|
|
|
model: str,
|
|
|
|
|
|
stream: bool,
|
|
|
|
|
|
start_time: float,
|
|
|
|
|
|
original_headers: Dict[str, str],
|
|
|
|
|
|
original_request_body: Dict[str, Any],
|
|
|
|
|
|
client_ip: str,
|
|
|
|
|
|
request_id: str,
|
|
|
|
|
|
) -> JSONResponse:
|
|
|
|
|
|
"""处理 Provider 相关异常"""
|
|
|
|
|
|
logger.debug(f"Caught provider exception: {type(e).__name__}")
|
|
|
|
|
|
|
|
|
|
|
|
response_time = int((time.time() - start_time) * 1000)
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 RequestResult.from_exception 创建统一的失败结果
|
|
|
|
|
|
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
|
|
|
|
|
result = RequestResult.from_exception(
|
|
|
|
|
|
exception=e,
|
|
|
|
|
|
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
|
|
|
|
|
model=model,
|
|
|
|
|
|
response_time_ms=response_time,
|
|
|
|
|
|
is_stream=stream,
|
|
|
|
|
|
)
|
|
|
|
|
|
result.request_headers = original_headers
|
|
|
|
|
|
result.request_body = original_request_body
|
|
|
|
|
|
|
|
|
|
|
|
# 确定错误消息
|
|
|
|
|
|
if isinstance(e, ProviderAuthException):
|
|
|
|
|
|
error_message = (
|
|
|
|
|
|
f"提供商认证失败: {str(e)}"
|
|
|
|
|
|
if result.metadata.provider != "unknown"
|
|
|
|
|
|
else "服务端错误: 无可用提供商"
|
|
|
|
|
|
)
|
|
|
|
|
|
result.error_message = error_message
|
|
|
|
|
|
|
|
|
|
|
|
# 处理上游客户端错误(如图片处理失败)
|
|
|
|
|
|
if isinstance(e, UpstreamClientException):
|
|
|
|
|
|
# 返回 400 状态码和清晰的错误消息
|
|
|
|
|
|
result.status_code = e.status_code
|
|
|
|
|
|
result.error_message = e.message
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 UsageRecorder 记录失败
|
|
|
|
|
|
recorder = UsageRecorder(
|
|
|
|
|
|
db=db,
|
|
|
|
|
|
user=user,
|
|
|
|
|
|
api_key=api_key,
|
|
|
|
|
|
client_ip=client_ip,
|
|
|
|
|
|
request_id=request_id,
|
|
|
|
|
|
)
|
|
|
|
|
|
await recorder.record_failure(result, original_headers, original_request_body)
|
|
|
|
|
|
|
|
|
|
|
|
# 根据异常类型确定错误类型
|
|
|
|
|
|
if isinstance(e, UpstreamClientException):
|
|
|
|
|
|
error_type = "invalid_request_error"
|
|
|
|
|
|
elif result.status_code == 503:
|
|
|
|
|
|
error_type = "internal_server_error"
|
|
|
|
|
|
else:
|
|
|
|
|
|
error_type = "rate_limit_exceeded"
|
|
|
|
|
|
|
|
|
|
|
|
return self._error_response(
|
|
|
|
|
|
status_code=result.status_code,
|
|
|
|
|
|
error_type=error_type,
|
|
|
|
|
|
message=result.error_message or str(e),
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_unexpected_exception(
|
|
|
|
|
|
self,
|
|
|
|
|
|
e: Exception,
|
|
|
|
|
|
*,
|
|
|
|
|
|
db,
|
|
|
|
|
|
user,
|
|
|
|
|
|
api_key,
|
|
|
|
|
|
model: str,
|
|
|
|
|
|
stream: bool,
|
|
|
|
|
|
start_time: float,
|
|
|
|
|
|
original_headers: Dict[str, str],
|
|
|
|
|
|
original_request_body: Dict[str, Any],
|
|
|
|
|
|
client_ip: str,
|
|
|
|
|
|
request_id: str,
|
|
|
|
|
|
) -> JSONResponse:
|
|
|
|
|
|
"""处理未预期的异常"""
|
|
|
|
|
|
if isinstance(e, ProxyException):
|
|
|
|
|
|
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
|
|
|
|
|
|
else:
|
|
|
|
|
|
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
|
|
|
|
|
|
exception=e,
|
|
|
|
|
|
extra_data={
|
|
|
|
|
|
"exception_class": e.__class__.__name__,
|
|
|
|
|
|
"processing_stage": "request_processing",
|
|
|
|
|
|
"model": model,
|
|
|
|
|
|
"stream": stream,
|
|
|
|
|
|
"traceback_preview": str(traceback.format_exc())[:500],
|
|
|
|
|
|
},
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
response_time = int((time.time() - start_time) * 1000)
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 RequestResult.from_exception 创建统一的失败结果
|
|
|
|
|
|
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
|
|
|
|
|
result = RequestResult.from_exception(
|
|
|
|
|
|
exception=e,
|
|
|
|
|
|
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
|
|
|
|
|
model=model,
|
|
|
|
|
|
response_time_ms=response_time,
|
|
|
|
|
|
is_stream=stream,
|
|
|
|
|
|
)
|
|
|
|
|
|
# 对于未预期的异常,强制设置状态码为 500
|
|
|
|
|
|
result.status_code = 500
|
|
|
|
|
|
result.error_type = "internal_error"
|
|
|
|
|
|
result.request_headers = original_headers
|
|
|
|
|
|
result.request_body = original_request_body
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 使用 UsageRecorder 记录失败
|
|
|
|
|
|
recorder = UsageRecorder(
|
|
|
|
|
|
db=db,
|
|
|
|
|
|
user=user,
|
|
|
|
|
|
api_key=api_key,
|
|
|
|
|
|
client_ip=client_ip,
|
|
|
|
|
|
request_id=request_id,
|
|
|
|
|
|
)
|
|
|
|
|
|
await recorder.record_failure(result, original_headers, original_request_body)
|
|
|
|
|
|
except Exception as record_error:
|
|
|
|
|
|
logger.error(f"记录失败请求时出错: {record_error}")
|
|
|
|
|
|
|
|
|
|
|
|
return self._error_response(
|
|
|
|
|
|
status_code=500,
|
|
|
|
|
|
error_type="internal_server_error",
|
|
|
|
|
|
message="处理请求时发生内部错误")
|
|
|
|
|
|
|
|
|
|
|
|
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
|
|
|
|
|
"""生成错误响应 - 子类可覆盖以自定义格式"""
|
|
|
|
|
|
return JSONResponse(
|
|
|
|
|
|
status_code=status_code,
|
|
|
|
|
|
content={
|
|
|
|
|
|
"error": {
|
|
|
|
|
|
"type": error_type,
|
|
|
|
|
|
"message": message,
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# =========================================================================
|
|
|
|
|
|
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
|
|
|
|
|
|
# =========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
def compute_total_input_context(
|
|
|
|
|
|
self,
|
|
|
|
|
|
input_tokens: int,
|
|
|
|
|
|
cache_read_input_tokens: int,
|
|
|
|
|
|
cache_creation_input_tokens: int = 0,
|
|
|
|
|
|
) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
计算总输入上下文(用于阶梯计费判定)
|
|
|
|
|
|
|
|
|
|
|
|
默认实现:input_tokens + cache_read_input_tokens
|
|
|
|
|
|
子类可覆盖此方法实现不同的计算逻辑
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
input_tokens: 输入 token 数
|
|
|
|
|
|
cache_read_input_tokens: 缓存读取 token 数
|
|
|
|
|
|
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
总输入上下文 token 数
|
|
|
|
|
|
"""
|
|
|
|
|
|
return input_tokens + cache_read_input_tokens
|
|
|
|
|
|
|
|
|
|
|
|
def get_cache_read_price_for_ttl(
|
|
|
|
|
|
self,
|
|
|
|
|
|
tier: dict,
|
|
|
|
|
|
cache_ttl_minutes: Optional[int] = None,
|
|
|
|
|
|
) -> Optional[float]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据缓存 TTL 获取缓存读取价格
|
|
|
|
|
|
|
|
|
|
|
|
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
|
|
|
|
|
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
tier: 当前阶梯配置
|
|
|
|
|
|
cache_ttl_minutes: 缓存时长(分钟)
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
缓存读取价格(每 1M tokens)
|
|
|
|
|
|
"""
|
|
|
|
|
|
ttl_pricing = tier.get("cache_ttl_pricing")
|
|
|
|
|
|
if ttl_pricing and cache_ttl_minutes is not None:
|
|
|
|
|
|
matched_price = None
|
|
|
|
|
|
for ttl_config in ttl_pricing:
|
|
|
|
|
|
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
|
|
|
|
|
if cache_ttl_minutes <= ttl_limit:
|
|
|
|
|
|
matched_price = ttl_config.get("cache_read_price_per_1m")
|
|
|
|
|
|
break
|
|
|
|
|
|
if matched_price is not None:
|
|
|
|
|
|
return matched_price
|
|
|
|
|
|
# 超过所有配置的 TTL,使用最后一个
|
|
|
|
|
|
if ttl_pricing:
|
|
|
|
|
|
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
|
|
|
|
|
|
|
|
|
|
|
return tier.get("cache_read_price_per_1m")
|
|
|
|
|
|
|
|
|
|
|
|
def compute_cost(
|
|
|
|
|
|
self,
|
|
|
|
|
|
input_tokens: int,
|
|
|
|
|
|
output_tokens: int,
|
|
|
|
|
|
cache_creation_input_tokens: int,
|
|
|
|
|
|
cache_read_input_tokens: int,
|
|
|
|
|
|
input_price_per_1m: float,
|
|
|
|
|
|
output_price_per_1m: float,
|
|
|
|
|
|
cache_creation_price_per_1m: Optional[float],
|
|
|
|
|
|
cache_read_price_per_1m: Optional[float],
|
|
|
|
|
|
price_per_request: Optional[float],
|
|
|
|
|
|
tiered_pricing: Optional[dict] = None,
|
|
|
|
|
|
cache_ttl_minutes: Optional[int] = None,
|
|
|
|
|
|
) -> Dict[str, Any]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
计算请求成本
|
|
|
|
|
|
|
|
|
|
|
|
默认实现:支持固定价格和阶梯计费
|
|
|
|
|
|
子类可覆盖此方法实现完全不同的计费逻辑
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
input_tokens: 输入 token 数
|
|
|
|
|
|
output_tokens: 输出 token 数
|
|
|
|
|
|
cache_creation_input_tokens: 缓存创建 token 数
|
|
|
|
|
|
cache_read_input_tokens: 缓存读取 token 数
|
|
|
|
|
|
input_price_per_1m: 输入价格(每 1M tokens)
|
|
|
|
|
|
output_price_per_1m: 输出价格(每 1M tokens)
|
|
|
|
|
|
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
|
|
|
|
|
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
|
|
|
|
|
price_per_request: 按次计费价格
|
|
|
|
|
|
tiered_pricing: 阶梯计费配置
|
|
|
|
|
|
cache_ttl_minutes: 缓存时长(分钟)
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
包含各项成本的字典:
|
|
|
|
|
|
{
|
|
|
|
|
|
"input_cost": float,
|
|
|
|
|
|
"output_cost": float,
|
|
|
|
|
|
"cache_creation_cost": float,
|
|
|
|
|
|
"cache_read_cost": float,
|
|
|
|
|
|
"cache_cost": float,
|
|
|
|
|
|
"request_cost": float,
|
|
|
|
|
|
"total_cost": float,
|
|
|
|
|
|
"tier_index": Optional[int], # 命中的阶梯索引
|
|
|
|
|
|
}
|
|
|
|
|
|
"""
|
|
|
|
|
|
tier_index = None
|
|
|
|
|
|
effective_input_price = input_price_per_1m
|
|
|
|
|
|
effective_output_price = output_price_per_1m
|
|
|
|
|
|
effective_cache_creation_price = cache_creation_price_per_1m
|
|
|
|
|
|
effective_cache_read_price = cache_read_price_per_1m
|
|
|
|
|
|
|
|
|
|
|
|
# 检查阶梯计费
|
|
|
|
|
|
if tiered_pricing and tiered_pricing.get("tiers"):
|
|
|
|
|
|
total_input_context = self.compute_total_input_context(
|
|
|
|
|
|
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
|
|
|
|
|
)
|
|
|
|
|
|
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
|
|
|
|
|
|
|
|
|
|
|
if tier:
|
|
|
|
|
|
tier_index = tiered_pricing["tiers"].index(tier)
|
|
|
|
|
|
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
|
|
|
|
|
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
|
|
|
|
|
effective_cache_creation_price = tier.get(
|
|
|
|
|
|
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
|
|
|
|
|
)
|
|
|
|
|
|
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
|
|
|
|
|
tier, cache_ttl_minutes
|
|
|
|
|
|
)
|
|
|
|
|
|
if effective_cache_read_price is None:
|
|
|
|
|
|
effective_cache_read_price = cache_read_price_per_1m
|
|
|
|
|
|
|
|
|
|
|
|
# 计算各项成本
|
|
|
|
|
|
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
|
|
|
|
|
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
|
|
|
|
|
|
|
|
|
|
|
cache_creation_cost = 0.0
|
|
|
|
|
|
cache_read_cost = 0.0
|
|
|
|
|
|
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
|
|
|
|
|
cache_creation_cost = (
|
|
|
|
|
|
cache_creation_input_tokens / 1_000_000
|
|
|
|
|
|
) * effective_cache_creation_price
|
|
|
|
|
|
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
|
|
|
|
|
cache_read_cost = (
|
|
|
|
|
|
cache_read_input_tokens / 1_000_000
|
|
|
|
|
|
) * effective_cache_read_price
|
|
|
|
|
|
|
|
|
|
|
|
cache_cost = cache_creation_cost + cache_read_cost
|
|
|
|
|
|
request_cost = price_per_request if price_per_request else 0.0
|
|
|
|
|
|
total_cost = input_cost + output_cost + cache_cost + request_cost
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
"input_cost": input_cost,
|
|
|
|
|
|
"output_cost": output_cost,
|
|
|
|
|
|
"cache_creation_cost": cache_creation_cost,
|
|
|
|
|
|
"cache_read_cost": cache_read_cost,
|
|
|
|
|
|
"cache_cost": cache_cost,
|
|
|
|
|
|
"request_cost": request_cost,
|
|
|
|
|
|
"total_cost": total_cost,
|
|
|
|
|
|
"tier_index": tier_index,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据总输入 token 数确定价格阶梯
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
|
|
|
|
|
total_input_tokens: 总输入 token 数
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
匹配的阶梯配置
|
|
|
|
|
|
"""
|
|
|
|
|
|
if not tiered_pricing or "tiers" not in tiered_pricing:
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
tiers = tiered_pricing.get("tiers", [])
|
|
|
|
|
|
if not tiers:
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
for tier in tiers:
|
|
|
|
|
|
up_to = tier.get("up_to")
|
|
|
|
|
|
if up_to is None or total_input_tokens <= up_to:
|
|
|
|
|
|
return tier
|
|
|
|
|
|
|
|
|
|
|
|
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
|
|
|
|
|
return tiers[-1] if tiers else None
|
|
|
|
|
|
|
2025-12-19 23:40:42 +08:00
|
|
|
|
# =========================================================================
|
|
|
|
|
|
# 模型列表查询 - 子类应覆盖此方法
|
|
|
|
|
|
# =========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
async def fetch_models(
|
|
|
|
|
|
cls,
|
|
|
|
|
|
client: httpx.AsyncClient,
|
|
|
|
|
|
base_url: str,
|
|
|
|
|
|
api_key: str,
|
|
|
|
|
|
extra_headers: Optional[Dict[str, str]] = None,
|
|
|
|
|
|
) -> Tuple[list, Optional[str]]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
查询上游 API 支持的模型列表
|
|
|
|
|
|
|
|
|
|
|
|
这是 Aether 内部发起的请求(非用户透传),用于:
|
|
|
|
|
|
- 管理后台查询提供商支持的模型
|
|
|
|
|
|
- 自动发现可用模型
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
client: httpx 异步客户端
|
|
|
|
|
|
base_url: API 基础 URL
|
|
|
|
|
|
api_key: API 密钥(已解密)
|
|
|
|
|
|
extra_headers: 端点配置的额外请求头
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
(models, error): 模型列表和错误信息
|
|
|
|
|
|
- models: 模型信息列表,每个模型至少包含 id 字段
|
|
|
|
|
|
- error: 错误信息,成功时为 None
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 默认实现返回空列表,子类应覆盖
|
|
|
|
|
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
|
|
|
|
|
|
2025-12-10 20:52:44 +08:00
|
|
|
|
|
|
|
|
|
|
# =========================================================================
|
|
|
|
|
|
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
|
|
|
|
|
# =========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
_ADAPTER_REGISTRY: Dict[str, Type["ChatAdapterBase"]] = {}
|
|
|
|
|
|
_ADAPTERS_LOADED = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def register_adapter(adapter_class: Type["ChatAdapterBase"]) -> Type["ChatAdapterBase"]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
注册 Adapter 类到注册表
|
|
|
|
|
|
|
|
|
|
|
|
用法:
|
|
|
|
|
|
@register_adapter
|
|
|
|
|
|
class ClaudeChatAdapter(ChatAdapterBase):
|
|
|
|
|
|
FORMAT_ID = "CLAUDE"
|
|
|
|
|
|
...
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
adapter_class: Adapter 类
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
注册的 Adapter 类(支持作为装饰器使用)
|
|
|
|
|
|
"""
|
|
|
|
|
|
format_id = adapter_class.FORMAT_ID
|
|
|
|
|
|
if format_id and format_id != "UNKNOWN":
|
|
|
|
|
|
_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
|
|
|
|
|
|
return adapter_class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_adapters_loaded():
|
|
|
|
|
|
"""确保所有 Adapter 已被加载(触发注册)"""
|
|
|
|
|
|
global _ADAPTERS_LOADED
|
|
|
|
|
|
if _ADAPTERS_LOADED:
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# 导入各个 Adapter 模块以触发 @register_adapter 装饰器
|
|
|
|
|
|
try:
|
|
|
|
|
|
from src.api.handlers.claude import adapter as _ # noqa: F401
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
pass
|
|
|
|
|
|
try:
|
|
|
|
|
|
from src.api.handlers.openai import adapter as _ # noqa: F401
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
pass
|
|
|
|
|
|
try:
|
|
|
|
|
|
from src.api.handlers.gemini import adapter as _ # noqa: F401
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
_ADAPTERS_LOADED = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_adapter_class(api_format: str) -> Optional[Type["ChatAdapterBase"]]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据 API format 获取 Adapter 类
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
api_format: API 格式标识(如 "CLAUDE", "OPENAI", "GEMINI")
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
对应的 Adapter 类,如果未找到返回 None
|
|
|
|
|
|
"""
|
|
|
|
|
|
_ensure_adapters_loaded()
|
|
|
|
|
|
return _ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_adapter_instance(api_format: str) -> Optional["ChatAdapterBase"]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
根据 API format 获取 Adapter 实例
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
api_format: API 格式标识
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
Adapter 实例,如果未找到返回 None
|
|
|
|
|
|
"""
|
|
|
|
|
|
adapter_class = get_adapter_class(api_format)
|
|
|
|
|
|
if adapter_class:
|
|
|
|
|
|
return adapter_class()
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def list_registered_formats() -> list[str]:
|
|
|
|
|
|
"""返回所有已注册的 API 格式"""
|
|
|
|
|
|
_ensure_adapters_loaded()
|
|
|
|
|
|
return list(_ADAPTER_REGISTRY.keys())
|