Files
Aether/src/api/handlers/base/base_handler.py

364 lines
12 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
基础消息处理器封装通用的编排转换遥测逻辑
接口约定
- process_stream: 处理流式请求返回 StreamingResponse
- process_sync: 处理非流式请求返回 JSONResponse
签名规范推荐
async def process_stream(
self,
request: Any, # 解析后的请求模型
http_request: Request, # FastAPI Request 对象
original_headers: Dict[str, str], # 原始请求头
original_request_body: Dict[str, Any], # 原始请求体
query_params: Optional[Dict[str, str]] = None, # 查询参数
) -> StreamingResponse: ...
async def process_sync(
self,
request: Any,
http_request: Request,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> JSONResponse: ...
"""
from __future__ import annotations
import time
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from src.clients.redis_client import get_redis_client_sync
from src.core.api_format_metadata import resolve_api_format
from src.core.enums import APIFormat
from src.core.logger import logger
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
from src.services.provider.format import normalize_api_format
from src.services.system.audit import audit_service
from src.services.usage.service import UsageService
class MessageTelemetry:
    """
    Records usage and audit entries for a single request so that concrete
    handlers do not repeat telemetry boilerplate.
    """

    def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
        # Per-request context shared by every telemetry call below.
        self.db = db
        self.user = user
        self.api_key = api_key
        self.request_id = request_id
        self.client_ip = client_ip

    async def calculate_cost(
        self,
        provider: str,
        model: str,
        *,
        input_tokens: int,
        output_tokens: int,
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
    ) -> float:
        """Compute the total cost of one request from the provider/model price table."""
        input_price, output_price = await UsageService.get_model_price_async(
            self.db, provider, model
        )
        cache_prices = await UsageService.get_cache_prices_async(
            self.db, provider, model, input_price
        )
        # UsageService.calculate_cost returns a 7-tuple; only the final
        # element (the total) is needed here.
        _, _, _, _, _, _, total_cost = UsageService.calculate_cost(
            input_tokens,
            output_tokens,
            input_price,
            output_price,
            cache_creation_tokens,
            cache_read_tokens,
            *cache_prices,
        )
        return total_cost

    async def record_success(
        self,
        *,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        response_time_ms: int,
        status_code: int,
        request_body: Dict[str, Any],
        request_headers: Dict[str, Any],
        response_body: Any,
        response_headers: Dict[str, Any],
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
        is_stream: bool = False,
        provider_request_headers: Optional[Dict[str, Any]] = None,
        # Provider-side tracing info (used to record the real cost)
        provider_id: Optional[str] = None,
        provider_endpoint_id: Optional[str] = None,
        provider_api_key_id: Optional[str] = None,
        api_format: Optional[str] = None,
        # Model-mapping info
        target_model: Optional[str] = None,
        # Provider response metadata (e.g. Gemini's modelVersion)
        response_metadata: Optional[Dict[str, Any]] = None,
    ) -> float:
        """Record a successful request as usage (and audit, when attributable).

        Returns:
            The computed total cost for the request.
        """
        total_cost = await self.calculate_cost(
            provider,
            model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_tokens=cache_creation_tokens,
            cache_read_tokens=cache_read_tokens,
        )
        usage_fields = dict(
            db=self.db,
            user=self.user,
            api_key=self.api_key,
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_tokens,
            cache_read_input_tokens=cache_read_tokens,
            request_type="chat",
            api_format=api_format,
            is_stream=is_stream,
            response_time_ms=response_time_ms,
            status_code=status_code,
            request_headers=request_headers,
            request_body=request_body,
            provider_request_headers=provider_request_headers or {},
            response_headers=response_headers,
            response_body=response_body,
            request_id=self.request_id,
            # Provider-side tracing info (used to record the real cost)
            provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id,
            # Model-mapping info
            target_model=target_model,
            # Provider response metadata
            metadata=response_metadata,
        )
        await UsageService.record_usage(**usage_fields)
        # Audit logging requires both a user and an API key to attribute the call.
        if self.user and self.api_key:
            audit_service.log_api_request(
                db=self.db,
                user_id=self.user.id,
                api_key_id=self.api_key.id,
                request_id=self.request_id,
                model=model,
                provider=provider,
                success=True,
                ip_address=self.client_ip,
                status_code=status_code,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cost_usd=total_cost,
            )
        return total_cost

    async def record_failure(
        self,
        *,
        provider: str,
        model: str,
        response_time_ms: int,
        status_code: int,
        error_message: str,
        request_body: Dict[str, Any],
        request_headers: Dict[str, Any],
        is_stream: bool,
        api_format: Optional[str] = None,
        provider_request_headers: Optional[Dict[str, Any]] = None,
        # Estimated token info (from the message_start event; used for cost
        # estimation on interrupted requests)
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
        response_body: Optional[Dict[str, Any]] = None,
        # Model-mapping info
        target_model: Optional[str] = None,
    ):
        """Record a failed request.

        Note: provider chain info (provider_id, endpoint_id, key_id) is not
        recorded here, because the RequestCandidate table already stores the
        full request-chain trace.

        Args:
            input_tokens: estimated input tokens (from message_start; used to
                estimate cost for interrupted requests)
            output_tokens: estimated output tokens (from content already received)
            cache_creation_tokens: cache-creation tokens
            cache_read_tokens: cache-read tokens
            response_body: response body, if a partial response exists
            target_model: mapped target model name, if mapping occurred
        """
        provider_name = provider or "unknown"
        if provider_name == "unknown":
            logger.warning(f"[Telemetry] Recording failure with unknown provider (request_id={self.request_id})")
        failure_fields = dict(
            db=self.db,
            user=self.user,
            api_key=self.api_key,
            provider=provider_name,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_tokens,
            cache_read_input_tokens=cache_read_tokens,
            request_type="chat",
            api_format=api_format,
            is_stream=is_stream,
            response_time_ms=response_time_ms,
            status_code=status_code,
            error_message=error_message,
            request_headers=request_headers,
            request_body=request_body,
            provider_request_headers=provider_request_headers or {},
            response_headers={},
            response_body=response_body or {"error": error_message},
            request_id=self.request_id,
            # Model-mapping info
            target_model=target_model,
        )
        await UsageService.record_usage(**failure_fields)
@runtime_checkable
class MessageHandlerProtocol(Protocol):
    """
    Message-handler protocol - defines the standard interface.

    ChatHandlerBase uses the full signature (request, http_request);
    CliMessageHandlerBase uses the simplified signature
    (original_request_body, original_headers).
    """

    async def process_stream(
        self,
        request: Any,
        http_request: Request,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        query_params: Optional[Dict[str, str]] = None,
    ) -> StreamingResponse:
        """Handle a streaming request."""
        ...

    async def process_sync(
        self,
        request: Any,
        http_request: Request,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        query_params: Optional[Dict[str, str]] = None,
    ) -> JSONResponse:
        """Handle a non-streaming request."""
        ...
class BaseMessageHandler:
"""
消息处理器基类所有具体格式的 handler 可以继承它
子类需要实现
- process_stream: 处理流式请求
- process_sync: 处理非流式请求
推荐使用 MessageHandlerProtocol 中定义的签名
"""
# Adapter 检测器类型
AdapterDetectorType = Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
def __init__(
self,
*,
db: Session,
user,
api_key,
request_id: str,
client_ip: str,
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list[str]] = None,
adapter_detector: Optional[AdapterDetectorType] = None,
):
self.db = db
self.user = user
self.api_key = api_key
self.request_id = request_id
self.client_ip = client_ip
self.user_agent = user_agent
self.start_time = start_time
self.allowed_api_formats = allowed_api_formats or [APIFormat.CLAUDE.value]
self.primary_api_format = normalize_api_format(self.allowed_api_formats[0])
self.adapter_detector = adapter_detector
redis_client = get_redis_client_sync()
self.orchestrator = FallbackOrchestrator(db, redis_client)
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
def elapsed_ms(self) -> int:
return int((time.time() - self.start_time) * 1000)
def _resolve_capability_requirements(
self,
model_name: str,
request_headers: Optional[Dict[str, str]] = None,
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""
解析请求的能力需求
来源:
1. 用户模型级配置 (User.model_capability_settings)
2. 用户 API Key 强制配置 (ApiKey.force_capabilities)
3. 请求头 X-Require-Capability
4. Adapter detect_capability_requirements Claude anthropic-beta
Args:
model_name: 模型名称
request_headers: 请求头
request_body: 请求体可选
Returns:
能力需求字典
"""
from src.services.capability.resolver import CapabilityResolver
return CapabilityResolver.resolve_requirements(
user=self.user,
user_api_key=self.api_key,
model_name=model_name,
request_headers=request_headers,
request_body=request_body,
adapter_detector=self.adapter_detector,
)
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
if provider_type:
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
return self.primary_api_format
def build_provider_payload(
self,
original_body: Dict[str, Any],
*,
mapped_model: Optional[str] = None,
) -> Dict[str, Any]:
"""构建发送给 Provider 的请求体,替换 model 名称"""
payload = dict(original_body)
if mapped_model:
payload["model"] = mapped_model
return payload