mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-11 03:58:28 +08:00
Initial commit
This commit is contained in:
17
src/services/request/__init__.py
Normal file
17
src/services/request/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
Request handling service module.

Provides candidate selection and request execution.

Note:
- RequestBuilder has moved to src.api.handlers.base.request_builder; import it from there.
- record_failed_request has moved to src.services.usage.recorder; import it from there.
"""

from src.services.request.candidate import RequestCandidateService
from src.services.request.executor import RequestExecutor

__all__ = [
    "RequestCandidateService",
    "RequestExecutor",
]
|
||||
291
src/services/request/candidate.py
Normal file
291
src/services/request/candidate.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
请求候选记录服务 - 管理候选队列
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.batch_committer import get_batch_committer
|
||||
from src.models.database import RequestCandidate
|
||||
|
||||
|
||||
class RequestCandidateService:
    """Service managing a request's candidate records.

    A candidate is one scheduling option (provider/endpoint/key combination)
    for a request.  Status lifecycle:
    'available' -> 'pending' -> 'streaming' | 'success' | 'failed' | 'skipped'.

    Commit policy:
    - User-visible transitions (started / streaming / success / failed) are
      committed immediately so the frontend sees them in real time.
    - Non-critical writes (creation, skipped) are flushed only and handed to
      the batch committer for a deferred commit.
    """

    @staticmethod
    def _get_candidate(db: Session, candidate_id: str) -> Optional[RequestCandidate]:
        """Look up a candidate row by id; returns None when it does not exist."""
        return db.query(RequestCandidate).filter(RequestCandidate.id == candidate_id).first()

    @staticmethod
    def create_candidate(
        db: Session,
        request_id: str,
        candidate_index: int,
        retry_index: int = 0,
        user_id: Optional[str] = None,
        api_key_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        endpoint_id: Optional[str] = None,
        key_id: Optional[str] = None,
        status: str = "available",
        skip_reason: Optional[str] = None,
        is_cached: bool = False,
        extra_data: Optional[dict] = None,
        required_capabilities: Optional[dict] = None,
    ) -> RequestCandidate:
        """
        Create a candidate record.

        Args:
            db: database session
            request_id: request id
            candidate_index: position of this candidate in the queue
            retry_index: retry sequence number (starts at 0)
            user_id: user id
            api_key_id: user API key id
            provider_id: provider id
            endpoint_id: endpoint id
            key_id: provider API key id
            status: candidate status ('available', 'used', 'skipped', 'success', 'failed')
            skip_reason: reason the candidate was skipped
            is_cached: whether this is a cache-affinity candidate
            extra_data: additional data
            required_capabilities: capability tags required by the request
        """
        candidate = RequestCandidate(
            id=str(uuid.uuid4()),
            request_id=request_id,
            candidate_index=candidate_index,
            retry_index=retry_index,
            user_id=user_id,
            api_key_id=api_key_id,
            provider_id=provider_id,
            endpoint_id=endpoint_id,
            key_id=key_id,
            status=status,
            skip_reason=skip_reason,
            is_cached=is_cached,
            extra_data=extra_data or {},
            required_capabilities=required_capabilities,
            created_at=datetime.now(timezone.utc),
        )
        db.add(candidate)
        db.flush()  # flush only; the commit is deferred
        # Non-critical data: mark for batch commit.
        get_batch_committer().mark_dirty(db)
        return candidate

    @staticmethod
    def mark_candidate_started(db: Session, candidate_id: str) -> None:
        """
        Mark a candidate as started.

        Args:
            db: database session
            candidate_id: candidate id
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "pending"
            candidate.started_at = datetime.now(timezone.utc)
            # Critical status update: commit immediately (no batch commit)
            # so the frontend can see the request start in real time.
            db.commit()

    @staticmethod
    def update_candidate_status(db: Session, candidate_id: str, status: str) -> None:
        """
        Update a candidate's status (generic setter).

        Args:
            db: database session
            candidate_id: candidate id
            status: new status (pending, available, success, failed, skipped)
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = status
            # Record the start time on the transition to 'pending'.
            if status == "pending" and not candidate.started_at:
                candidate.started_at = datetime.now(timezone.utc)
            # Commit immediately so the frontend sees the change in real time.
            db.commit()

    @staticmethod
    def mark_candidate_streaming(
        db: Session,
        candidate_id: str,
        status_code: int = 200,
        concurrent_requests: Optional[int] = None,
    ) -> None:
        """
        Mark a candidate as streaming.

        For streaming requests: called once the connection is established and
        the stream starts flowing.  The request is not finished yet; call
        mark_candidate_success after the stream completes.

        Args:
            db: database session
            candidate_id: candidate id
            status_code: HTTP status code (usually 200)
            concurrent_requests: current concurrency
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "streaming"
            candidate.status_code = status_code
            candidate.concurrent_requests = concurrent_requests
            # No finished_at for 'streaming': the request is still in flight.
            db.commit()

    @staticmethod
    def mark_candidate_success(
        db: Session,
        candidate_id: str,
        status_code: int,
        latency_ms: int,
        concurrent_requests: Optional[int] = None,
        extra_data: Optional[dict] = None,
    ) -> None:
        """
        Mark a candidate as successfully executed.

        Args:
            db: database session
            candidate_id: candidate id
            status_code: HTTP status code
            latency_ms: latency in milliseconds
            concurrent_requests: current concurrency
            extra_data: additional data merged into the existing extra_data
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "success"
            candidate.status_code = status_code
            candidate.latency_ms = latency_ms
            candidate.concurrent_requests = concurrent_requests
            candidate.finished_at = datetime.now(timezone.utc)
            if extra_data:
                candidate.extra_data = {**(candidate.extra_data or {}), **extra_data}
            # Critical status update: commit immediately so the frontend
            # sees success/failure in real time.
            db.commit()

    @staticmethod
    def mark_candidate_failed(
        db: Session,
        candidate_id: str,
        error_type: str,
        error_message: str,
        status_code: Optional[int] = None,
        latency_ms: Optional[int] = None,
        concurrent_requests: Optional[int] = None,
        extra_data: Optional[dict] = None,
    ) -> None:
        """
        Mark a candidate as failed.

        Args:
            db: database session
            candidate_id: candidate id
            error_type: error type
            error_message: error message
            status_code: HTTP status code (if any)
            latency_ms: latency in milliseconds
            concurrent_requests: current concurrency
            extra_data: additional data merged into the existing extra_data
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "failed"
            candidate.error_type = error_type
            candidate.error_message = error_message
            candidate.status_code = status_code
            candidate.latency_ms = latency_ms
            candidate.concurrent_requests = concurrent_requests
            candidate.finished_at = datetime.now(timezone.utc)
            if extra_data:
                candidate.extra_data = {**(candidate.extra_data or {}), **extra_data}
            # Critical status update: commit immediately so the frontend
            # sees success/failure in real time.
            db.commit()

    @staticmethod
    def mark_candidate_skipped(
        db: Session, candidate_id: str, skip_reason: Optional[str] = None
    ) -> None:
        """
        Mark a candidate as skipped.

        Args:
            db: database session
            candidate_id: candidate id
            skip_reason: why the candidate was skipped
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "skipped"
            candidate.skip_reason = skip_reason
            candidate.finished_at = datetime.now(timezone.utc)
            db.flush()  # flush only; the commit is deferred
            get_batch_committer().mark_dirty(db)

    @staticmethod
    def get_candidates_by_request_id(db: Session, request_id: str) -> List[RequestCandidate]:
        """
        Get all candidate records of a request.

        Args:
            db: database session
            request_id: request id

        Returns:
            Candidate records ordered by candidate_index.
        """
        return (
            db.query(RequestCandidate)
            .filter(RequestCandidate.request_id == request_id)
            .order_by(RequestCandidate.candidate_index)
            .all()
        )

    @staticmethod
    def get_candidate_stats_by_provider(db: Session, provider_id: str, limit: int = 100) -> dict:
        """
        Get candidate statistics for a provider.

        Args:
            db: database session
            provider_id: provider id
            limit: cap on the number of most recent records considered

        Returns:
            Statistics dictionary.
        """
        candidates = (
            db.query(RequestCandidate)
            .filter(RequestCandidate.provider_id == provider_id)
            .order_by(RequestCandidate.created_at.desc())
            .limit(limit)
            .all()
        )

        total_candidates = len(candidates)
        success_count = sum(1 for c in candidates if c.status == "success")
        failed_count = sum(1 for c in candidates if c.status == "failed")
        skipped_count = sum(1 for c in candidates if c.status == "skipped")
        pending_count = sum(1 for c in candidates if c.status == "pending")
        available_count = sum(1 for c in candidates if c.status == "available")

        # Failure rate only counts completed candidates (success or failed).
        completed_count = success_count + failed_count
        failure_rate = (failed_count / completed_count * 100) if completed_count > 0 else 0

        return {
            "total_attempts": total_candidates,  # the frontend reads total_attempts
            "success_count": success_count,
            "failed_count": failed_count,
            "skipped_count": skipped_count,
            "pending_count": pending_count,
            "available_count": available_count,  # candidates not yet scheduled
            "failure_rate": round(failure_rate, 2),
        }
|
||||
193
src/services/request/executor.py
Normal file
193
src/services/request/executor.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
封装请求执行逻辑,包含并发控制与链路追踪。
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.exceptions import ConcurrencyLimitError
|
||||
from src.core.logger import logger
|
||||
from src.services.health.monitor import health_monitor
|
||||
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
||||
from src.services.request.candidate import RequestCandidateService
|
||||
|
||||
|
||||
|
||||
@dataclass
class ExecutionContext:
    """Tracing context for one candidate execution attempt.

    Identity fields are set at construction; timing/concurrency fields are
    filled in while the request runs.
    """

    candidate_id: str       # candidate DB record id
    candidate_index: int    # position of the candidate in the queue
    provider_id: str
    endpoint_id: str
    key_id: str
    user_id: Optional[str]
    api_key_id: Optional[str]
    is_cached_user: bool    # whether this is a cache-affinity candidate
    start_time: Optional[float] = None    # time.time() when the provider call began
    elapsed_ms: Optional[int] = None      # wall-clock latency in milliseconds
    concurrent_requests: Optional[int] = None  # key-level concurrency observed at start
|
||||
|
||||
|
||||
@dataclass
class ExecutionResult:
    """Successful outcome of a candidate execution."""

    response: Any              # value returned by request_func
    context: ExecutionContext  # tracing context captured during execution
|
||||
|
||||
|
||||
class ExecutionError(Exception):
    """Wraps a failure raised while executing a candidate.

    Carries both the original exception (`cause`) and the execution
    context captured at the moment of failure (`context`).
    """

    def __init__(self, cause: Exception, context: ExecutionContext):
        message = str(cause)
        self.cause = cause
        self.context = context
        super().__init__(message)
|
||||
|
||||
|
||||
class RequestExecutor:
    """Executes a single candidate's provider request under concurrency control.

    Responsibilities, in order:
    1. mark the candidate as started,
    2. compute a dynamic cache-reservation ratio for the key,
    3. run the provider call inside the concurrency guard,
    4. on success, feed the health monitor / adaptive concurrency manager and
       mark the candidate 'streaming' (stream) or 'success' (non-stream),
    5. wrap any failure in ExecutionError together with the ExecutionContext.
    """

    def __init__(self, db: Session, concurrency_manager, adaptive_manager):
        # db: SQLAlchemy session used for candidate/health bookkeeping.
        # concurrency_manager: exposes concurrency_guard() and
        #   get_current_concurrency() (exact type declared elsewhere).
        # adaptive_manager: learns per-key concurrency limits; consulted when
        #   key.max_concurrent is None.
        self.db = db
        self.concurrency_manager = concurrency_manager
        self.adaptive_manager = adaptive_manager

    async def execute(
        self,
        *,
        candidate,
        candidate_id: str,
        candidate_index: int,
        user_api_key,
        request_func: Callable,
        request_id: Optional[str],
        api_format: Union[str, APIFormat],
        model_name: str,
        is_stream: bool = False,
    ) -> ExecutionResult:
        """Run ``request_func(provider, endpoint, key)`` for one candidate.

        Args:
            candidate: scheduling candidate; provides .provider/.endpoint/.key/.is_cached
            candidate_id: id of the candidate DB record
            candidate_index: position of the candidate in the queue
            user_api_key: caller's API key record (provides user_id and id)
            request_func: awaitable performing the actual provider call
            request_id: id of the overall request (not used in this method)
            api_format: API format of the request (string or APIFormat)
            model_name: model being requested
            is_stream: True for streaming requests

        Returns:
            ExecutionResult holding the provider response and the context.

        Raises:
            ExecutionError: wraps the underlying exception plus the context.
        """
        provider = candidate.provider
        endpoint = candidate.endpoint
        key = candidate.key
        is_cached_user = bool(candidate.is_cached)

        # Mark the candidate as started (committed immediately).
        RequestCandidateService.mark_candidate_started(
            db=self.db,
            candidate_id=candidate_id,
        )

        context = ExecutionContext(
            candidate_id=candidate_id,
            candidate_index=candidate_index,
            provider_id=provider.id,
            endpoint_id=endpoint.id,
            key_id=key.id,
            user_id=user_api_key.user_id,
            api_key_id=user_api_key.id,
            is_cached_user=is_cached_user,
        )

        try:
            # Compute the dynamic reservation ratio.
            reservation_manager = get_adaptive_reservation_manager()
            # Fetch current concurrency to estimate load; best-effort only.
            try:
                _, current_key_concurrent = await self.concurrency_manager.get_current_concurrency(
                    endpoint_id=endpoint.id,
                    key_id=key.id,
                )
            except Exception as e:
                logger.debug(f"获取并发数失败(用于预留计算): {e}")
                current_key_concurrent = 0

            # Effective concurrency limit: learned (adaptive) or fixed.
            effective_key_limit = (
                key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
            )

            reservation_result = reservation_manager.calculate_reservation(
                key=key,
                current_concurrent=current_key_concurrent,
                effective_limit=effective_key_limit,
            )
            dynamic_reservation_ratio = reservation_result.ratio

            logger.debug(f"[Executor] 动态预留: key={key.id[:8]}..., "
                         f"ratio={dynamic_reservation_ratio:.0%}, phase={reservation_result.phase}, "
                         f"confidence={reservation_result.confidence:.0%}")

            async with self.concurrency_manager.concurrency_guard(
                endpoint_id=endpoint.id,
                endpoint_max_concurrent=endpoint.max_concurrent,
                key_id=key.id,
                key_max_concurrent=effective_key_limit,
                is_cached_user=is_cached_user,
                cache_reservation_ratio=dynamic_reservation_ratio,
            ):
                # Re-read concurrency inside the guard for accurate reporting.
                try:
                    _, key_concurrent = await self.concurrency_manager.get_current_concurrency(
                        endpoint_id=endpoint.id,
                        key_id=key.id,
                    )
                except Exception as e:
                    logger.debug(f"获取并发数失败(guard 内): {e}")
                    key_concurrent = None

                context.concurrent_requests = key_concurrent
                context.start_time = time.time()

                response = await request_func(provider, endpoint, key)

                context.elapsed_ms = int((time.time() - context.start_time) * 1000)

                health_monitor.record_success(
                    db=self.db,
                    key_id=key.id,
                    response_time_ms=context.elapsed_ms,
                )

                # Adaptive mode is signalled by max_concurrent = NULL.
                if key.max_concurrent is None and key_concurrent is not None:
                    self.adaptive_manager.handle_success(
                        db=self.db,
                        key=key,
                        current_concurrent=key_concurrent,
                    )

                # Mark a different status depending on streaming.
                if is_stream:
                    # Streaming: mark 'streaming' — the connection is up but the
                    # stream has not finished yet; 'success' is set later by the
                    # _record_stream_stats method after the stream completes.
                    RequestCandidateService.mark_candidate_streaming(
                        db=self.db,
                        candidate_id=candidate_id,
                        status_code=200,
                        concurrent_requests=key_concurrent,
                    )
                else:
                    # Non-streaming: mark 'success' right away.
                    RequestCandidateService.mark_candidate_success(
                        db=self.db,
                        candidate_id=candidate_id,
                        status_code=200,
                        latency_ms=context.elapsed_ms,
                        concurrent_requests=key_concurrent,
                        extra_data={
                            "is_cached_user": is_cached_user,
                            "model_name": model_name,
                            "api_format": (
                                api_format.value if isinstance(api_format, APIFormat) else api_format
                            ),
                        },
                    )

                return ExecutionResult(response=response, context=context)
        except ConcurrencyLimitError as exc:
            # Guard rejected the request; no timing to record.
            raise ExecutionError(exc, context) from exc
        except Exception as exc:
            # Record elapsed time if the call had started before failing.
            context.elapsed_ms = (
                int((time.time() - context.start_time) * 1000)
                if context.start_time is not None
                else None
            )
            raise ExecutionError(exc, context) from exc
|
||||
330
src/services/request/result.py
Normal file
330
src/services/request/result.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
统一的请求结果和元数据结构
|
||||
|
||||
设计原则:
|
||||
1. RequestMetadata: 描述请求执行的上下文(Provider、Endpoint、Key、API格式等)
|
||||
2. RequestResult: 封装请求的完整结果(成功/失败、响应、元数据、费用等)
|
||||
3. 确保 api_format 在整个链路中始终可用
|
||||
|
||||
使用场景:
|
||||
- ProviderService 创建 RequestMetadata
|
||||
- FallbackOrchestrator 在异常时补充 RequestMetadata
|
||||
- ChatHandlerBase 使用 RequestResult 记录 Usage
|
||||
- ChatAdapterBase 使用 RequestResult 处理异常响应
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
|
||||
class RequestStatus(Enum):
    """Request status."""

    SUCCESS = "success"
    FAILED = "failed"
    PARTIAL = "partial"  # streaming request partially succeeded
|
||||
|
||||
|
||||
@dataclass
class RequestMetadata:
    """
    Request metadata — describes the context in which a request executes.

    Required fields:
    - api_format: API format, must be determined when the request starts
    - provider: provider name
    - model: model name

    Optional fields:
    - provider_id, provider_endpoint_id, provider_api_key_id: provider tracing info
    - provider_request_headers, provider_response_headers: request/response headers
    - attempt_id: request attempt id
    - original_model: model name originally requested by the user (pre-mapping)
    """

    # Required fields — should be known when the request starts.
    api_format: str
    provider: str = "unknown"
    model: str = "unknown"

    # Provider tracing info.
    provider_id: Optional[str] = None
    provider_endpoint_id: Optional[str] = None
    provider_api_key_id: Optional[str] = None

    # Request/response headers.
    provider_request_headers: Dict[str, str] = field(default_factory=dict)
    provider_response_headers: Dict[str, str] = field(default_factory=dict)

    # Other metadata.
    attempt_id: Optional[str] = None
    original_model: Optional[str] = None  # original model requested by the user (for pricing)

    # Provider response metadata (extra info returned by the provider,
    # e.g. Gemini's modelVersion).
    response_metadata: Dict[str, Any] = field(default_factory=dict)

    def with_provider_info(
        self,
        provider: str,
        provider_id: str,
        provider_endpoint_id: str,
        provider_api_key_id: str,
    ) -> "RequestMetadata":
        """Return a new RequestMetadata with the provider info filled in."""
        from dataclasses import replace

        # dataclasses.replace copies every field automatically, so fields
        # added to RequestMetadata later cannot be silently dropped by a
        # hand-written field-by-field copy.
        return replace(
            self,
            provider=provider,
            provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id,
        )

    def with_response_headers(self, headers: Dict[str, str]) -> "RequestMetadata":
        """Return a new RequestMetadata with the response headers replaced."""
        from dataclasses import replace

        return replace(self, provider_response_headers=headers)
|
||||
|
||||
|
||||
@dataclass
class UsageInfo:
    """Token usage information."""

    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
class CostInfo:
    """Cost information (USD)."""

    # Base costs, before the rate multiplier is applied.
    input_cost_usd: float = 0.0
    output_cost_usd: float = 0.0
    cache_creation_cost_usd: float = 0.0
    cache_read_cost_usd: float = 0.0
    cache_cost_usd: float = 0.0
    total_cost_usd: float = 0.0

    # Actual costs (after multiplying by rate_multiplier).
    actual_input_cost_usd: float = 0.0
    actual_output_cost_usd: float = 0.0
    actual_cache_creation_cost_usd: float = 0.0
    actual_cache_read_cost_usd: float = 0.0
    actual_total_cost_usd: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class RequestResult:
    """
    Request result — wraps the complete outcome of a request.

    Used for:
    - successful requests: response data, usage, cost
    - failed requests: error info and status code
    - streaming requests: stream generator plus metadata
    """

    # Status.
    status: RequestStatus

    # Metadata (always present).
    metadata: RequestMetadata

    # Response payload.
    response_data: Optional[Any] = None  # response data on success
    stream: Optional[AsyncIterator[str]] = None  # streaming response

    # Usage and cost.
    usage: UsageInfo = field(default_factory=UsageInfo)
    cost: CostInfo = field(default_factory=CostInfo)

    # Error info.
    status_code: int = 200
    error_message: Optional[str] = None
    error_type: Optional[str] = None

    # Timing.
    response_time_ms: int = 0

    # Request info (for recording).
    is_stream: bool = False
    request_headers: Dict[str, str] = field(default_factory=dict)
    request_body: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_success(self) -> bool:
        """True when the status is RequestStatus.SUCCESS."""
        return self.status == RequestStatus.SUCCESS

    @property
    def is_failed(self) -> bool:
        """True when the status is RequestStatus.FAILED."""
        return self.status == RequestStatus.FAILED

    @classmethod
    def success(
        cls,
        metadata: RequestMetadata,
        response_data: Any,
        usage: UsageInfo,
        response_time_ms: int,
        is_stream: bool = False,
    ) -> "RequestResult":
        """Create a successful request result."""
        return cls(
            status=RequestStatus.SUCCESS,
            metadata=metadata,
            response_data=response_data,
            usage=usage,
            status_code=200,
            response_time_ms=response_time_ms,
            is_stream=is_stream,
        )

    @classmethod
    def failed(
        cls,
        metadata: RequestMetadata,
        status_code: int,
        error_message: str,
        error_type: str,
        response_time_ms: int,
        is_stream: bool = False,
    ) -> "RequestResult":
        """Create a failed request result."""
        return cls(
            status=RequestStatus.FAILED,
            metadata=metadata,
            status_code=status_code,
            error_message=error_message,
            error_type=error_type,
            response_time_ms=response_time_ms,
            is_stream=is_stream,
        )

    @classmethod
    def from_exception(
        cls,
        exception: Exception,
        api_format: str,
        model: str,
        response_time_ms: int,
        is_stream: bool = False,
    ) -> "RequestResult":
        """Create a failed request result from an exception."""
        # Try to extract metadata carried on the exception, if any.
        existing_metadata = getattr(exception, "request_metadata", None)

        def get_meta_value(meta, key, default=None):
            """Read a value from metadata; supports both dicts and objects."""
            if meta is None:
                return default
            if isinstance(meta, dict):
                return meta.get(key, default)
            return getattr(meta, key, default)

        if existing_metadata:
            # The exception already carries metadata: reuse it, but make sure
            # api_format is always populated.
            metadata = RequestMetadata(
                api_format=get_meta_value(existing_metadata, "api_format") or api_format,
                provider=get_meta_value(existing_metadata, "provider", "unknown") or "unknown",
                model=get_meta_value(existing_metadata, "model", model) or model,
                provider_id=get_meta_value(existing_metadata, "provider_id"),
                provider_endpoint_id=get_meta_value(existing_metadata, "provider_endpoint_id"),
                provider_api_key_id=get_meta_value(existing_metadata, "provider_api_key_id"),
                provider_request_headers=get_meta_value(existing_metadata, "provider_request_headers", {}),
                provider_response_headers=get_meta_value(
                    existing_metadata, "provider_response_headers", {}
                ),
                attempt_id=get_meta_value(existing_metadata, "attempt_id"),
                original_model=get_meta_value(existing_metadata, "original_model"),
                response_metadata=get_meta_value(existing_metadata, "response_metadata", {}),
            )
        else:
            # Build a minimal metadata object.
            metadata = RequestMetadata(
                api_format=api_format,
                provider="unknown",
                model=model,
            )

        # Map the exception type to a status code and error type.
        from src.core.exceptions import (
            ProviderAuthException,
            ProviderNotAvailableException,
            ProviderRateLimitException,
            ProviderTimeoutException,
        )

        if isinstance(exception, ProviderAuthException):
            status_code = 503
            error_type = "provider_auth_error"
        elif isinstance(exception, ProviderRateLimitException):
            status_code = 429
            error_type = "rate_limit_exceeded"
        elif isinstance(exception, ProviderTimeoutException):
            status_code = 504
            error_type = "timeout_error"
        elif isinstance(exception, ProviderNotAvailableException):
            status_code = 503
            error_type = "provider_unavailable"
        else:
            status_code = 500
            error_type = "internal_error"

        return cls(
            status=RequestStatus.FAILED,
            metadata=metadata,
            status_code=status_code,
            error_message=str(exception),
            error_type=error_type,
            response_time_ms=response_time_ms,
            is_stream=is_stream,
        )
|
||||
|
||||
|
||||
class StreamWithMetadata:
    """Async-stream wrapper that carries request metadata alongside the body."""

    def __init__(
        self,
        stream: AsyncIterator[str],
        metadata: RequestMetadata,
        response_headers_container: Optional[Dict[str, Any]] = None,
    ):
        # The container may be filled in later, once response headers arrive.
        self.stream = stream
        self.metadata = metadata
        self.response_headers_container = response_headers_container
        self._metadata_updated = False

    def update_metadata_with_response_headers(self):
        """Fold the actual response headers into the metadata (at most once)."""
        container = self.response_headers_container
        if not container or "headers" not in container:
            return
        if self._metadata_updated:
            return
        self.metadata = self.metadata.with_response_headers(container["headers"])
        self._metadata_updated = True

    def __aiter__(self):
        return self.stream

    async def __anext__(self):
        return await self.stream.__anext__()
|
||||
Reference in New Issue
Block a user