Files
Aether/src/services/request/executor.py
2025-12-10 20:52:44 +08:00

194 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
封装请求执行逻辑,包含并发控制与链路追踪。
"""
import time
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.core.exceptions import ConcurrencyLimitError
from src.core.logger import logger
from src.services.health.monitor import health_monitor
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
from src.services.request.candidate import RequestCandidateService
@dataclass
class ExecutionContext:
    """Per-attempt execution metadata for a single request candidate."""
    # Identity of the candidate being executed and its position in the
    # candidate (retry) chain.
    candidate_id: str
    candidate_index: int
    # Upstream routing identifiers: provider -> endpoint -> key.
    provider_id: str
    endpoint_id: str
    key_id: str
    # Caller identity copied from the user API key (Optional per declared type).
    user_id: Optional[str]
    api_key_id: Optional[str]
    # Whether this caller takes the cached-user reservation path in the guard.
    is_cached_user: bool
    # Filled in during execution: wall-clock start (time.time()), total
    # latency in milliseconds, and the key-level concurrent request count
    # observed inside the concurrency guard (None when the lookup failed).
    start_time: Optional[float] = None
    elapsed_ms: Optional[int] = None
    concurrent_requests: Optional[int] = None
@dataclass
class ExecutionResult:
    """Successful outcome: the upstream response plus its execution context."""
    # Raw response object returned by the upstream request function.
    response: Any
    # Execution metadata (timings, routing ids, observed concurrency).
    context: ExecutionContext
class ExecutionError(Exception):
    """Failure raised while executing a candidate.

    Carries the original exception (``cause``) plus the ``ExecutionContext``
    so callers can attribute the failure to a specific provider/endpoint/key
    attempt. The exception message mirrors ``str(cause)``.
    """

    def __init__(self, cause: Exception, context: ExecutionContext):
        # Keep the wrapped exception and its attempt context available to
        # callers, then initialize Exception with the cause's message.
        self.cause = cause
        self.context = context
        super().__init__(str(cause))
class RequestExecutor:
    """Execute a single request candidate with concurrency control and tracing.

    Responsibilities:
    - mark the candidate's lifecycle state in the DB (started / streaming /
      success) via ``RequestCandidateService``,
    - compute a dynamic cache-reservation ratio for the concurrency guard,
    - run the upstream call inside ``concurrency_guard``,
    - record health and adaptive-concurrency statistics on success.
    """

    def __init__(self, db: Session, concurrency_manager, adaptive_manager):
        """Store collaborators.

        Args:
            db: SQLAlchemy session used for candidate / health bookkeeping.
            concurrency_manager: Exposes ``get_current_concurrency`` and the
                async ``concurrency_guard`` context manager.
            adaptive_manager: Learns per-key concurrency limits from successes.
        """
        self.db = db
        self.concurrency_manager = concurrency_manager
        self.adaptive_manager = adaptive_manager

    async def execute(
        self,
        *,
        candidate,
        candidate_id: str,
        candidate_index: int,
        user_api_key,
        request_func: Callable,
        request_id: Optional[str],
        api_format: Union[str, APIFormat],
        model_name: str,
        is_stream: bool = False,
    ) -> ExecutionResult:
        """Run ``request_func(provider, endpoint, key)`` for one candidate.

        Args:
            candidate: Object exposing ``provider``, ``endpoint``, ``key``
                and ``is_cached`` attributes.
            candidate_id: Identifier used for candidate lifecycle tracking.
            candidate_index: Position of this candidate in the retry chain.
            user_api_key: Caller credential; its ``user_id`` / ``id`` are
                copied into the execution context.
            request_func: Awaitable callable ``(provider, endpoint, key)``
                performing the actual upstream request.
            request_id: Trace id of the overall request (not used directly
                in this method's body).
            api_format: API format, ``APIFormat`` enum or raw string.
            model_name: Requested model name, recorded on success.
            is_stream: When True, the candidate is marked ``streaming``
                instead of ``success``; final success is recorded elsewhere
                after the stream completes.

        Returns:
            ExecutionResult with the upstream response and the context.

        Raises:
            ExecutionError: Wrapping any failure (including
                ``ConcurrencyLimitError``) together with the context.
        """
        provider = candidate.provider
        endpoint = candidate.endpoint
        key = candidate.key
        is_cached_user = bool(candidate.is_cached)
        # Mark the candidate as started before any upstream work.
        RequestCandidateService.mark_candidate_started(
            db=self.db,
            candidate_id=candidate_id,
        )
        context = ExecutionContext(
            candidate_id=candidate_id,
            candidate_index=candidate_index,
            provider_id=provider.id,
            endpoint_id=endpoint.id,
            key_id=key.id,
            user_id=user_api_key.user_id,
            api_key_id=user_api_key.id,
            is_cached_user=is_cached_user,
        )
        try:
            # Compute the dynamic reservation ratio for cached users.
            reservation_manager = get_adaptive_reservation_manager()
            # Current key concurrency feeds the load estimate; best-effort,
            # falls back to 0 when the lookup fails.
            try:
                _, current_key_concurrent = await self.concurrency_manager.get_current_concurrency(
                    endpoint_id=endpoint.id,
                    key_id=key.id,
                )
            except Exception as e:
                logger.debug(f"获取并发数失败(用于预留计算): {e}")
                current_key_concurrent = 0
            # Effective key concurrency limit: the learned (adaptive) limit
            # when max_concurrent is NULL, otherwise the fixed limit.
            effective_key_limit = (
                key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
            )
            reservation_result = reservation_manager.calculate_reservation(
                key=key,
                current_concurrent=current_key_concurrent,
                effective_limit=effective_key_limit,
            )
            dynamic_reservation_ratio = reservation_result.ratio
            logger.debug(f"[Executor] 动态预留: key={key.id[:8]}..., "
                         f"ratio={dynamic_reservation_ratio:.0%}, phase={reservation_result.phase}, "
                         f"confidence={reservation_result.confidence:.0%}")
            async with self.concurrency_manager.concurrency_guard(
                endpoint_id=endpoint.id,
                endpoint_max_concurrent=endpoint.max_concurrent,
                key_id=key.id,
                key_max_concurrent=effective_key_limit,
                is_cached_user=is_cached_user,
                cache_reservation_ratio=dynamic_reservation_ratio,
            ):
                # Re-read concurrency inside the guard for reporting;
                # best-effort, None when unavailable.
                try:
                    _, key_concurrent = await self.concurrency_manager.get_current_concurrency(
                        endpoint_id=endpoint.id,
                        key_id=key.id,
                    )
                except Exception as e:
                    # Fixed: message previously lacked the opening paren.
                    logger.debug(f"获取并发数失败(guard 内): {e}")
                    key_concurrent = None
                context.concurrent_requests = key_concurrent
                context.start_time = time.time()
                response = await request_func(provider, endpoint, key)
                context.elapsed_ms = int((time.time() - context.start_time) * 1000)
                health_monitor.record_success(
                    db=self.db,
                    key_id=key.id,
                    response_time_ms=context.elapsed_ms,
                )
                # Adaptive mode (max_concurrent is NULL): feed the learner
                # with the observed concurrency.
                if key.max_concurrent is None and key_concurrent is not None:
                    self.adaptive_manager.handle_success(
                        db=self.db,
                        key=key,
                        current_concurrent=key_concurrent,
                    )
                # Mark candidate state; streaming vs non-streaming differ.
                if is_stream:
                    # Streaming: connection is established but the body has
                    # not finished; per the original notes, final success is
                    # recorded by _record_stream_stats after the stream ends.
                    RequestCandidateService.mark_candidate_streaming(
                        db=self.db,
                        candidate_id=candidate_id,
                        status_code=200,
                        concurrent_requests=key_concurrent,
                    )
                else:
                    # Non-streaming: record success immediately with latency
                    # and request metadata.
                    RequestCandidateService.mark_candidate_success(
                        db=self.db,
                        candidate_id=candidate_id,
                        status_code=200,
                        latency_ms=context.elapsed_ms,
                        concurrent_requests=key_concurrent,
                        extra_data={
                            "is_cached_user": is_cached_user,
                            "model_name": model_name,
                            "api_format": (
                                api_format.value if isinstance(api_format, APIFormat) else api_format
                            ),
                        },
                    )
                return ExecutionResult(response=response, context=context)
        except ConcurrencyLimitError as exc:
            # Concurrency rejection happens before the call starts, so no
            # latency is recorded; the caller sees it via ExecutionError.
            raise ExecutionError(exc, context) from exc
        except Exception as exc:
            # Any other failure: capture latency only if the upstream call
            # had actually started.
            context.elapsed_ms = (
                int((time.time() - context.start_time) * 1000)
                if context.start_time is not None
                else None
            )
            raise ExecutionError(exc, context) from exc