Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
"""
请求处理服务模块
包含候选选择、执行等功能。
注意:
- RequestBuilder 已移至 src.api.handlers.base.request_builder请直接从该模块导入
- record_failed_request 已移至 src.services.usage.recorder请直接从该模块导入
"""
from src.services.request.candidate import RequestCandidateService
from src.services.request.executor import RequestExecutor
__all__ = [
"RequestCandidateService",
"RequestExecutor",
]

View File

@@ -0,0 +1,291 @@
"""
请求候选记录服务 - 管理候选队列
"""
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from sqlalchemy.orm import Session
from src.core.batch_committer import get_batch_committer
from src.models.database import RequestCandidate
class RequestCandidateService:
    """Request candidate record service - manages the per-request candidate queue.

    Commit policy (deliberate, do not "simplify"):
    - State changes the frontend must see in real time (pending / streaming /
      success / failed) are committed immediately.
    - Non-critical writes (creation, skips) are only flushed and handed to the
      batch committer to be persisted lazily.
    """

    @staticmethod
    def _get_candidate(db: Session, candidate_id: str) -> Optional[RequestCandidate]:
        """Fetch a candidate row by primary key, or None when it does not exist."""
        return db.query(RequestCandidate).filter(RequestCandidate.id == candidate_id).first()

    @staticmethod
    def create_candidate(
        db: Session,
        request_id: str,
        candidate_index: int,
        retry_index: int = 0,
        user_id: Optional[str] = None,
        api_key_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        endpoint_id: Optional[str] = None,
        key_id: Optional[str] = None,
        status: str = "available",
        skip_reason: Optional[str] = None,
        is_cached: bool = False,
        extra_data: Optional[dict] = None,
        required_capabilities: Optional[dict] = None,
    ) -> RequestCandidate:
        """
        Create a candidate record.

        Args:
            db: Database session
            request_id: Request ID
            candidate_index: Candidate index within the request
            retry_index: Retry index, starting from 0
            user_id: User ID
            api_key_id: API key ID of the calling user
            provider_id: Provider ID
            endpoint_id: Endpoint ID
            key_id: Provider API key ID
            status: Candidate status ('available', 'used', 'skipped', 'success', 'failed')
            skip_reason: Reason the candidate was skipped
            is_cached: Whether this is a cache-affinity candidate
            extra_data: Extra data payload
            required_capabilities: Capability tags required by the request
        """
        candidate = RequestCandidate(
            id=str(uuid.uuid4()),
            request_id=request_id,
            candidate_index=candidate_index,
            retry_index=retry_index,
            user_id=user_id,
            api_key_id=api_key_id,
            provider_id=provider_id,
            endpoint_id=endpoint_id,
            key_id=key_id,
            status=status,
            skip_reason=skip_reason,
            is_cached=is_cached,
            extra_data=extra_data or {},
            required_capabilities=required_capabilities,
            created_at=datetime.now(timezone.utc),
        )
        db.add(candidate)
        db.flush()  # Flush only; do not commit yet.
        # Non-critical data: mark for batch commit (may be persisted lazily).
        get_batch_committer().mark_dirty(db)
        return candidate

    @staticmethod
    def mark_candidate_started(db: Session, candidate_id: str) -> None:
        """
        Mark a candidate as started executing.

        Args:
            db: Database session
            candidate_id: Candidate ID
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "pending"
            candidate.started_at = datetime.now(timezone.utc)
            # Critical state change: commit immediately instead of batching,
            # so the frontend sees the request start in real time.
            db.commit()

    @staticmethod
    def update_candidate_status(db: Session, candidate_id: str, status: str) -> None:
        """
        Update a candidate's status (generic helper).

        Args:
            db: Database session
            candidate_id: Candidate ID
            status: New status (pending, available, success, failed, skipped)
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = status
            # Record the start time on the first transition to 'pending'.
            if status == "pending" and not candidate.started_at:
                candidate.started_at = datetime.now(timezone.utc)
            # Commit immediately so the frontend sees the change in real time.
            db.commit()

    @staticmethod
    def mark_candidate_streaming(
        db: Session,
        candidate_id: str,
        status_code: int = 200,
        concurrent_requests: Optional[int] = None,
    ) -> None:
        """
        Mark a candidate as actively streaming.

        For streaming requests: called once the connection is established and
        the stream starts flowing. The request is not finished at that point;
        mark_candidate_success must be called after the stream completes.

        Args:
            db: Database session
            candidate_id: Candidate ID
            status_code: HTTP status code (usually 200)
            concurrent_requests: Concurrent request count
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "streaming"
            candidate.status_code = status_code
            candidate.concurrent_requests = concurrent_requests
            # No finished_at for 'streaming': the request is still in flight.
            db.commit()

    @staticmethod
    def mark_candidate_success(
        db: Session,
        candidate_id: str,
        status_code: int,
        latency_ms: int,
        concurrent_requests: Optional[int] = None,
        extra_data: Optional[dict] = None,
    ) -> None:
        """
        Mark a candidate as successfully executed.

        Args:
            db: Database session
            candidate_id: Candidate ID
            status_code: HTTP status code
            latency_ms: Latency in milliseconds
            concurrent_requests: Concurrent request count
            extra_data: Extra data merged into the existing payload
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "success"
            candidate.status_code = status_code
            candidate.latency_ms = latency_ms
            candidate.concurrent_requests = concurrent_requests
            candidate.finished_at = datetime.now(timezone.utc)
            if extra_data:
                candidate.extra_data = {**(candidate.extra_data or {}), **extra_data}
            # Critical state change: commit immediately so the frontend sees
            # the success/failure outcome in real time.
            db.commit()

    @staticmethod
    def mark_candidate_failed(
        db: Session,
        candidate_id: str,
        error_type: str,
        error_message: str,
        status_code: Optional[int] = None,
        latency_ms: Optional[int] = None,
        concurrent_requests: Optional[int] = None,
        extra_data: Optional[dict] = None,
    ) -> None:
        """
        Mark a candidate as failed.

        Args:
            db: Database session
            candidate_id: Candidate ID
            error_type: Error type
            error_message: Error message
            status_code: HTTP status code (if any)
            latency_ms: Latency in milliseconds
            concurrent_requests: Concurrent request count
            extra_data: Extra data merged into the existing payload
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "failed"
            candidate.error_type = error_type
            candidate.error_message = error_message
            candidate.status_code = status_code
            candidate.latency_ms = latency_ms
            candidate.concurrent_requests = concurrent_requests
            candidate.finished_at = datetime.now(timezone.utc)
            if extra_data:
                candidate.extra_data = {**(candidate.extra_data or {}), **extra_data}
            # Critical state change: commit immediately so the frontend sees
            # the success/failure outcome in real time.
            db.commit()

    @staticmethod
    def mark_candidate_skipped(
        db: Session, candidate_id: str, skip_reason: Optional[str] = None
    ) -> None:
        """
        Mark a candidate as skipped.

        Args:
            db: Database session
            candidate_id: Candidate ID
            skip_reason: Reason the candidate was skipped
        """
        candidate = RequestCandidateService._get_candidate(db, candidate_id)
        if candidate:
            candidate.status = "skipped"
            candidate.skip_reason = skip_reason
            candidate.finished_at = datetime.now(timezone.utc)
            db.flush()  # Flush only; do not commit yet.
            get_batch_committer().mark_dirty(db)

    @staticmethod
    def get_candidates_by_request_id(db: Session, request_id: str) -> List[RequestCandidate]:
        """
        Get all candidate records for a request.

        Args:
            db: Database session
            request_id: Request ID
        Returns:
            Candidate records ordered by candidate_index
        """
        return (
            db.query(RequestCandidate)
            .filter(RequestCandidate.request_id == request_id)
            .order_by(RequestCandidate.candidate_index)
            .all()
        )

    @staticmethod
    def get_candidate_stats_by_provider(db: Session, provider_id: str, limit: int = 100) -> dict:
        """
        Get candidate statistics for a provider.

        Args:
            db: Database session
            provider_id: Provider ID
            limit: Maximum number of recent records to consider
        Returns:
            Statistics dictionary
        """
        candidates = (
            db.query(RequestCandidate)
            .filter(RequestCandidate.provider_id == provider_id)
            .order_by(RequestCandidate.created_at.desc())
            .limit(limit)
            .all()
        )
        total_candidates = len(candidates)
        success_count = sum(1 for c in candidates if c.status == "success")
        failed_count = sum(1 for c in candidates if c.status == "failed")
        skipped_count = sum(1 for c in candidates if c.status == "skipped")
        pending_count = sum(1 for c in candidates if c.status == "pending")
        available_count = sum(1 for c in candidates if c.status == "available")
        # Failure rate counts only completed candidates (success or failed).
        completed_count = success_count + failed_count
        failure_rate = (failed_count / completed_count * 100) if completed_count > 0 else 0
        return {
            "total_attempts": total_candidates,  # The frontend reads total_attempts.
            "success_count": success_count,
            "failed_count": failed_count,
            "skipped_count": skipped_count,
            "pending_count": pending_count,
            "available_count": available_count,  # Candidates not yet scheduled.
            "failure_rate": round(failure_rate, 2),
        }

View File

@@ -0,0 +1,193 @@
"""
封装请求执行逻辑,包含并发控制与链路追踪。
"""
import time
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.core.exceptions import ConcurrencyLimitError
from src.core.logger import logger
from src.services.health.monitor import health_monitor
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
from src.services.request.candidate import RequestCandidateService
@dataclass
class ExecutionContext:
    """Per-attempt execution context: routing identity plus runtime measurements.

    The identity fields are set up front by RequestExecutor.execute; the three
    optional fields at the end are filled in while the attempt runs.
    """

    candidate_id: str       # RequestCandidate row id for this attempt
    candidate_index: int    # position of the candidate within the request
    provider_id: str
    endpoint_id: str
    key_id: str
    user_id: Optional[str]
    api_key_id: Optional[str]
    is_cached_user: bool    # True when the candidate was chosen for cache affinity
    start_time: Optional[float] = None         # time.time() when the provider call began
    elapsed_ms: Optional[int] = None           # wall-clock latency of the provider call
    concurrent_requests: Optional[int] = None  # key-level concurrency observed inside the guard
@dataclass
class ExecutionResult:
    """Successful outcome of RequestExecutor.execute."""

    response: Any              # whatever request_func returned
    context: ExecutionContext  # execution context captured for this attempt
class ExecutionError(Exception):
    """Carries an execution failure together with the context it happened in.

    The exception message is taken from the underlying cause, so existing
    logging of str(error) keeps working unchanged.
    """

    def __init__(self, cause: Exception, context: ExecutionContext):
        # Preserve the original message for str()/logging.
        super().__init__(str(cause))
        self.cause = cause      # the underlying exception
        self.context = context  # execution context at the time of failure
class RequestExecutor:
    """Executes a single candidate attempt with concurrency control and tracing.

    Wraps the provider call (request_func) in the concurrency guard, records
    candidate/health state, and converts any failure into ExecutionError so the
    caller still receives the ExecutionContext.
    """

    def __init__(self, db: Session, concurrency_manager, adaptive_manager):
        # db: database session used for candidate/health bookkeeping.
        # concurrency_manager: provides concurrency_guard and get_current_concurrency.
        # adaptive_manager: updates learned per-key concurrency limits on success.
        self.db = db
        self.concurrency_manager = concurrency_manager
        self.adaptive_manager = adaptive_manager

    async def execute(
        self,
        *,
        candidate,
        candidate_id: str,
        candidate_index: int,
        user_api_key,
        request_func: Callable,
        request_id: Optional[str],  # NOTE(review): currently unused in this method
        api_format: Union[str, APIFormat],
        model_name: str,
        is_stream: bool = False,
    ) -> ExecutionResult:
        """Run one candidate attempt and return the response with its context.

        Raises:
            ExecutionError: wraps any underlying failure (including
                ConcurrencyLimitError) together with the ExecutionContext.
        """
        provider = candidate.provider
        endpoint = candidate.endpoint
        key = candidate.key
        is_cached_user = bool(candidate.is_cached)
        # Mark the candidate as started (committed immediately for the frontend).
        RequestCandidateService.mark_candidate_started(
            db=self.db,
            candidate_id=candidate_id,
        )
        context = ExecutionContext(
            candidate_id=candidate_id,
            candidate_index=candidate_index,
            provider_id=provider.id,
            endpoint_id=endpoint.id,
            key_id=key.id,
            user_id=user_api_key.user_id,
            api_key_id=user_api_key.id,
            is_cached_user=is_cached_user,
        )
        try:
            # Compute the dynamic cache-reservation ratio.
            reservation_manager = get_adaptive_reservation_manager()
            # Fetch the current concurrency to estimate load; best-effort only.
            try:
                _, current_key_concurrent = await self.concurrency_manager.get_current_concurrency(
                    endpoint_id=endpoint.id,
                    key_id=key.id,
                )
            except Exception as e:
                logger.debug(f"获取并发数失败(用于预留计算): {e}")
                current_key_concurrent = 0
            # Effective concurrency limit: learned (adaptive) when no fixed cap is set.
            effective_key_limit = (
                key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
            )
            reservation_result = reservation_manager.calculate_reservation(
                key=key,
                current_concurrent=current_key_concurrent,
                effective_limit=effective_key_limit,
            )
            dynamic_reservation_ratio = reservation_result.ratio
            logger.debug(f"[Executor] 动态预留: key={key.id[:8]}..., "
                        f"ratio={dynamic_reservation_ratio:.0%}, phase={reservation_result.phase}, "
                        f"confidence={reservation_result.confidence:.0%}")
            async with self.concurrency_manager.concurrency_guard(
                endpoint_id=endpoint.id,
                endpoint_max_concurrent=endpoint.max_concurrent,
                key_id=key.id,
                key_max_concurrent=effective_key_limit,
                is_cached_user=is_cached_user,
                cache_reservation_ratio=dynamic_reservation_ratio,
            ):
                # Observe the in-guard concurrency; best-effort (None on failure).
                try:
                    _, key_concurrent = await self.concurrency_manager.get_current_concurrency(
                        endpoint_id=endpoint.id,
                        key_id=key.id,
                    )
                except Exception as e:
                    logger.debug(f"获取并发数失败(guard 内): {e}")
                    key_concurrent = None
                context.concurrent_requests = key_concurrent
                context.start_time = time.time()
                response = await request_func(provider, endpoint, key)
                context.elapsed_ms = int((time.time() - context.start_time) * 1000)
                health_monitor.record_success(
                    db=self.db,
                    key_id=key.id,
                    response_time_ms=context.elapsed_ms,
                )
                # Adaptive mode (max_concurrent = NULL): feed the learner.
                if key.max_concurrent is None and key_concurrent is not None:
                    self.adaptive_manager.handle_success(
                        db=self.db,
                        key=key,
                        current_concurrent=key_concurrent,
                    )
                # Record a different candidate state depending on streaming.
                if is_stream:
                    # Streaming: mark as 'streaming' — the connection is up but
                    # the stream has not finished. 'success' is recorded later,
                    # once the stream completes (by _record_stream_stats).
                    RequestCandidateService.mark_candidate_streaming(
                        db=self.db,
                        candidate_id=candidate_id,
                        status_code=200,
                        concurrent_requests=key_concurrent,
                    )
                else:
                    # Non-streaming: mark as 'success' right away.
                    RequestCandidateService.mark_candidate_success(
                        db=self.db,
                        candidate_id=candidate_id,
                        status_code=200,
                        latency_ms=context.elapsed_ms,
                        concurrent_requests=key_concurrent,
                        extra_data={
                            "is_cached_user": is_cached_user,
                            "model_name": model_name,
                            "api_format": (
                                api_format.value if isinstance(api_format, APIFormat) else api_format
                            ),
                        },
                    )
                return ExecutionResult(response=response, context=context)
        except ConcurrencyLimitError as exc:
            raise ExecutionError(exc, context) from exc
        except Exception as exc:
            # Capture latency if the request had already started.
            context.elapsed_ms = (
                int((time.time() - context.start_time) * 1000)
                if context.start_time is not None
                else None
            )
            raise ExecutionError(exc, context) from exc

View File

@@ -0,0 +1,330 @@
"""
统一的请求结果和元数据结构
设计原则:
1. RequestMetadata: 描述请求执行的上下文Provider、Endpoint、Key、API格式等
2. RequestResult: 封装请求的完整结果(成功/失败、响应、元数据、费用等)
3. 确保 api_format 在整个链路中始终可用
使用场景:
- ProviderService 创建 RequestMetadata
- FallbackOrchestrator 在异常时补充 RequestMetadata
- ChatHandlerBase 使用 RequestResult 记录 Usage
- ChatAdapterBase 使用 RequestResult 处理异常响应
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Dict, Optional
class RequestStatus(Enum):
    """Request status."""

    SUCCESS = "success"
    FAILED = "failed"
    PARTIAL = "partial"  # Streaming request that partially succeeded
@dataclass
class RequestMetadata:
    """
    Request metadata - describes the execution context of a request.

    Required fields:
    - api_format: API format; must be determined when the request starts
    - provider: provider name
    - model: model name

    Optional fields:
    - provider_id, provider_endpoint_id, provider_api_key_id: provider tracking info
    - provider_request_headers, provider_response_headers: request/response headers
    - attempt_id: request attempt ID
    - original_model: the model name the user originally requested (before mapping)
    """
    # Required fields - should be determined when the request starts
    api_format: str
    provider: str = "unknown"
    model: str = "unknown"
    # Provider tracking info
    provider_id: Optional[str] = None
    provider_endpoint_id: Optional[str] = None
    provider_api_key_id: Optional[str] = None
    # Request/response headers
    provider_request_headers: Dict[str, str] = field(default_factory=dict)
    provider_response_headers: Dict[str, str] = field(default_factory=dict)
    # Other metadata
    attempt_id: Optional[str] = None
    original_model: Optional[str] = None  # original model name (used for pricing)
    # Provider response metadata (extra info the provider returned, e.g. Gemini modelVersion)
    response_metadata: Dict[str, Any] = field(default_factory=dict)

    def with_provider_info(
        self,
        provider: str,
        provider_id: str,
        provider_endpoint_id: str,
        provider_api_key_id: str,
    ) -> "RequestMetadata":
        """Return a new RequestMetadata with provider routing info filled in.

        Uses dataclasses.replace so every field is carried over automatically:
        adding a field to the dataclass can no longer be silently dropped here.
        """
        from dataclasses import replace  # local import: keeps module imports untouched
        return replace(
            self,
            provider=provider,
            provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id,
        )

    def with_response_headers(self, headers: Dict[str, str]) -> "RequestMetadata":
        """Return a new RequestMetadata carrying the given response headers."""
        from dataclasses import replace  # local import: keeps module imports untouched
        return replace(self, provider_response_headers=headers)
@dataclass
class UsageInfo:
    """Token usage counters."""

    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0
@dataclass
class CostInfo:
    """Cost breakdown in USD."""

    # Nominal (list-price) costs
    input_cost_usd: float = 0.0
    output_cost_usd: float = 0.0
    cache_creation_cost_usd: float = 0.0
    cache_read_cost_usd: float = 0.0
    cache_cost_usd: float = 0.0
    total_cost_usd: float = 0.0
    # Actual costs (after applying rate_multiplier)
    actual_input_cost_usd: float = 0.0
    actual_output_cost_usd: float = 0.0
    actual_cache_creation_cost_usd: float = 0.0
    actual_cache_read_cost_usd: float = 0.0
    actual_total_cost_usd: float = 0.0
@dataclass
class RequestResult:
    """
    Request result - encapsulates the complete outcome of a request.

    Used for:
    - Successful requests: carries response data, usage and cost
    - Failed requests: carries error information and status code
    - Streaming requests: carries the stream generator and metadata
    """
    # Status
    status: RequestStatus
    # Metadata (must always be present)
    metadata: RequestMetadata
    # Response payload
    response_data: Optional[Any] = None  # response data on success
    stream: Optional[AsyncIterator[str]] = None  # streaming response
    # Usage and cost
    usage: UsageInfo = field(default_factory=UsageInfo)
    cost: CostInfo = field(default_factory=CostInfo)
    # Error information
    status_code: int = 200
    error_message: Optional[str] = None
    error_type: Optional[str] = None
    # Timing
    response_time_ms: int = 0
    # Request info (kept for recording purposes)
    is_stream: bool = False
    request_headers: Dict[str, str] = field(default_factory=dict)
    request_body: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_success(self) -> bool:
        return self.status == RequestStatus.SUCCESS

    @property
    def is_failed(self) -> bool:
        return self.status == RequestStatus.FAILED

    @classmethod
    def success(
        cls,
        metadata: RequestMetadata,
        response_data: Any,
        usage: UsageInfo,
        response_time_ms: int,
        is_stream: bool = False,
    ) -> "RequestResult":
        """Create a successful request result."""
        return cls(
            status=RequestStatus.SUCCESS,
            metadata=metadata,
            response_data=response_data,
            usage=usage,
            status_code=200,
            response_time_ms=response_time_ms,
            is_stream=is_stream,
        )

    @classmethod
    def failed(
        cls,
        metadata: RequestMetadata,
        status_code: int,
        error_message: str,
        error_type: str,
        response_time_ms: int,
        is_stream: bool = False,
    ) -> "RequestResult":
        """Create a failed request result."""
        return cls(
            status=RequestStatus.FAILED,
            metadata=metadata,
            status_code=status_code,
            error_message=error_message,
            error_type=error_type,
            response_time_ms=response_time_ms,
            is_stream=is_stream,
        )

    @classmethod
    def from_exception(
        cls,
        exception: Exception,
        api_format: str,
        model: str,
        response_time_ms: int,
        is_stream: bool = False,
    ) -> "RequestResult":
        """Create a failed request result from an exception."""
        # Try to extract metadata attached to the exception, if any.
        existing_metadata = getattr(exception, "request_metadata", None)

        def get_meta_value(meta, key, default=None):
            """Read a value from metadata; supports both dict and object forms."""
            if meta is None:
                return default
            if isinstance(meta, dict):
                return meta.get(key, default)
            return getattr(meta, key, default)

        if existing_metadata:
            # The exception carries metadata: reuse it, but guarantee that
            # api_format is always set (fall back to the caller's value).
            metadata = RequestMetadata(
                api_format=get_meta_value(existing_metadata, "api_format") or api_format,
                provider=get_meta_value(existing_metadata, "provider", "unknown") or "unknown",
                model=get_meta_value(existing_metadata, "model", model) or model,
                provider_id=get_meta_value(existing_metadata, "provider_id"),
                provider_endpoint_id=get_meta_value(existing_metadata, "provider_endpoint_id"),
                provider_api_key_id=get_meta_value(existing_metadata, "provider_api_key_id"),
                provider_request_headers=get_meta_value(existing_metadata, "provider_request_headers", {}),
                provider_response_headers=get_meta_value(
                    existing_metadata, "provider_response_headers", {}
                ),
                attempt_id=get_meta_value(existing_metadata, "attempt_id"),
                original_model=get_meta_value(existing_metadata, "original_model"),
                response_metadata=get_meta_value(existing_metadata, "response_metadata", {}),
            )
        else:
            # Build a minimal metadata object.
            metadata = RequestMetadata(
                api_format=api_format,
                provider="unknown",
                model=model,
            )
        # Map the exception type to a status code and error type.
        from src.core.exceptions import (
            ProviderAuthException,
            ProviderNotAvailableException,
            ProviderRateLimitException,
            ProviderTimeoutException,
        )
        if isinstance(exception, ProviderAuthException):
            status_code = 503
            error_type = "provider_auth_error"
        elif isinstance(exception, ProviderRateLimitException):
            status_code = 429
            error_type = "rate_limit_exceeded"
        elif isinstance(exception, ProviderTimeoutException):
            status_code = 504
            error_type = "timeout_error"
        elif isinstance(exception, ProviderNotAvailableException):
            status_code = 503
            error_type = "provider_unavailable"
        else:
            status_code = 500
            error_type = "internal_error"
        return cls(
            status=RequestStatus.FAILED,
            metadata=metadata,
            status_code=status_code,
            error_message=str(exception),
            error_type=error_type,
            response_time_ms=response_time_ms,
            is_stream=is_stream,
        )
class StreamWithMetadata:
    """Wraps an async text stream together with its request metadata."""

    def __init__(
        self,
        stream: AsyncIterator[str],
        metadata: RequestMetadata,
        response_headers_container: Optional[Dict[str, Any]] = None,
    ):
        self.stream = stream
        self.metadata = metadata
        self.response_headers_container = response_headers_container
        # Guards against folding the response headers in more than once.
        self._metadata_updated = False

    def update_metadata_with_response_headers(self):
        """Fold the actual response headers into the metadata (idempotent)."""
        container = self.response_headers_container
        if not container or "headers" not in container:
            return
        if self._metadata_updated:
            return
        self.metadata = self.metadata.with_response_headers(container["headers"])
        self._metadata_updated = True

    def __aiter__(self):
        return self.stream

    async def __anext__(self):
        return await self.stream.__anext__()