mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 10:42:29 +08:00
194 lines
7.0 KiB
Python
194 lines
7.0 KiB
Python
|
|
"""
|
|||
|
|
封装请求执行逻辑,包含并发控制与链路追踪。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Any, Callable, Optional, Union
|
|||
|
|
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from src.core.enums import APIFormat
|
|||
|
|
from src.core.exceptions import ConcurrencyLimitError
|
|||
|
|
from src.core.logger import logger
|
|||
|
|
from src.services.health.monitor import health_monitor
|
|||
|
|
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
|||
|
|
from src.services.request.candidate import RequestCandidateService
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ExecutionContext:
    """Trace context for a single execution attempt against one candidate.

    Created before the upstream call; timing/concurrency fields are filled
    in as execution progresses.
    """

    # Identity of the candidate attempt and its position in the retry order.
    candidate_id: str
    candidate_index: int
    # Routing target: provider / endpoint / key chosen for this attempt.
    provider_id: str
    endpoint_id: str
    key_id: str
    # Caller identity taken from the user API key; may be absent.
    user_id: Optional[str]
    api_key_id: Optional[str]
    # Whether this caller counts as a "cached user" for concurrency
    # reservation purposes (see concurrency_guard usage in the executor).
    is_cached_user: bool
    # Filled during execution:
    start_time: Optional[float] = None  # wall-clock start, time.time()
    elapsed_ms: Optional[int] = None  # total upstream latency in milliseconds
    concurrent_requests: Optional[int] = None  # key-level concurrency observed at start
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ExecutionResult:
    """Successful outcome of an execution: the upstream response plus its trace context."""

    # Upstream response object; shape depends on the request_func supplied
    # by the caller — not constrained here.
    response: Any
    # The filled-in trace context for this attempt.
    context: ExecutionContext
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ExecutionError(Exception):
    """Wraps any failure raised during execution, pairing the original
    exception with the ExecutionContext of the attempt that failed."""

    def __init__(self, cause: Exception, context: ExecutionContext):
        # Surface the underlying error's message as this exception's message.
        message = str(cause)
        super().__init__(message)
        # Keep the raw exception and the execution trace for upstream handlers.
        self.context = context
        self.cause = cause
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RequestExecutor:
    """Executes a single upstream request attempt for one routing candidate.

    Wraps the call in a concurrency guard, records latency and key-health
    metrics, feeds adaptive concurrency learning, and marks the candidate's
    lifecycle state (started / streaming / success) for request tracing.
    """

    def __init__(self, db: Session, concurrency_manager, adaptive_manager):
        # SQLAlchemy session used for candidate tracking and health records.
        self.db = db
        # Exposes get_current_concurrency() and the async concurrency_guard().
        self.concurrency_manager = concurrency_manager
        # Receives handle_success() so a key's learned limit can adapt
        # when the key runs in adaptive mode (max_concurrent is None).
        self.adaptive_manager = adaptive_manager

    async def execute(
        self,
        *,
        candidate,
        candidate_id: str,
        candidate_index: int,
        user_api_key,
        request_func: Callable,
        request_id: Optional[str],
        api_format: Union[str, APIFormat],
        model_name: str,
        is_stream: bool = False,
    ) -> ExecutionResult:
        """Run ``request_func`` against the candidate's provider/endpoint/key.

        Args:
            candidate: Routing candidate; must expose ``provider``,
                ``endpoint``, ``key`` and ``is_cached`` (project type).
            candidate_id: Trace id of this candidate attempt.
            candidate_index: Position of this candidate in the retry sequence.
            user_api_key: Caller credential; ``user_id`` and ``id`` are read
                into the trace context.
            request_func: Awaitable invoked as
                ``await request_func(provider, endpoint, key)``.
            request_id: Upstream request id (not used within this method).
            api_format: API dialect; recorded in the candidate's extra_data.
            model_name: Model name; recorded in the candidate's extra_data.
            is_stream: If True, mark the candidate ``streaming`` instead of
                ``success`` (final success is recorded after the stream ends).

        Returns:
            ExecutionResult carrying the upstream response and filled context.

        Raises:
            ExecutionError: Wraps any underlying exception (including
                ConcurrencyLimitError) together with the execution context.
        """
        provider = candidate.provider
        endpoint = candidate.endpoint
        key = candidate.key
        is_cached_user = bool(candidate.is_cached)

        # Mark this candidate attempt as started for tracing.
        RequestCandidateService.mark_candidate_started(
            db=self.db,
            candidate_id=candidate_id,
        )

        context = ExecutionContext(
            candidate_id=candidate_id,
            candidate_index=candidate_index,
            provider_id=provider.id,
            endpoint_id=endpoint.id,
            key_id=key.id,
            user_id=user_api_key.user_id,
            api_key_id=user_api_key.id,
            is_cached_user=is_cached_user,
        )

        try:
            # Compute the dynamic cache-reservation ratio for this key.
            reservation_manager = get_adaptive_reservation_manager()
            # Current key concurrency feeds the load estimate; best-effort —
            # fall back to 0 rather than failing the request.
            try:
                _, current_key_concurrent = await self.concurrency_manager.get_current_concurrency(
                    endpoint_id=endpoint.id,
                    key_id=key.id,
                )
            except Exception as e:
                logger.debug(f"获取并发数失败(用于预留计算): {e}")
                current_key_concurrent = 0

            # Effective concurrency limit: learned value in adaptive mode
            # (max_concurrent is NULL), otherwise the fixed configured value.
            effective_key_limit = (
                key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
            )

            reservation_result = reservation_manager.calculate_reservation(
                key=key,
                current_concurrent=current_key_concurrent,
                effective_limit=effective_key_limit,
            )
            dynamic_reservation_ratio = reservation_result.ratio

            logger.debug(f"[Executor] 动态预留: key={key.id[:8]}..., "
                         f"ratio={dynamic_reservation_ratio:.0%}, phase={reservation_result.phase}, "
                         f"confidence={reservation_result.confidence:.0%}")

            async with self.concurrency_manager.concurrency_guard(
                endpoint_id=endpoint.id,
                endpoint_max_concurrent=endpoint.max_concurrent,
                key_id=key.id,
                key_max_concurrent=effective_key_limit,
                is_cached_user=is_cached_user,
                cache_reservation_ratio=dynamic_reservation_ratio,
            ):
                # Snapshot key concurrency inside the guard for trace data;
                # failure here must not abort the request.
                try:
                    _, key_concurrent = await self.concurrency_manager.get_current_concurrency(
                        endpoint_id=endpoint.id,
                        key_id=key.id,
                    )
                except Exception as e:
                    logger.debug(f"获取并发数失败(guard 内): {e}")
                    key_concurrent = None

                context.concurrent_requests = key_concurrent
                context.start_time = time.time()

                response = await request_func(provider, endpoint, key)

                context.elapsed_ms = int((time.time() - context.start_time) * 1000)

                health_monitor.record_success(
                    db=self.db,
                    key_id=key.id,
                    response_time_ms=context.elapsed_ms,
                )

                # Adaptive mode (max_concurrent is NULL): feed the success
                # back so the learned concurrency limit can adjust.
                if key.max_concurrent is None and key_concurrent is not None:
                    self.adaptive_manager.handle_success(
                        db=self.db,
                        key=key,
                        current_concurrent=key_concurrent,
                    )

                # Mark a different candidate state depending on streaming.
                if is_stream:
                    # Streaming: the connection is established but the stream
                    # has not finished; the final "success" state is recorded
                    # later by _record_stream_stats once the stream completes.
                    RequestCandidateService.mark_candidate_streaming(
                        db=self.db,
                        candidate_id=candidate_id,
                        status_code=200,
                        concurrent_requests=key_concurrent,
                    )
                else:
                    # Non-streaming: the response is complete, mark success now.
                    RequestCandidateService.mark_candidate_success(
                        db=self.db,
                        candidate_id=candidate_id,
                        status_code=200,
                        latency_ms=context.elapsed_ms,
                        concurrent_requests=key_concurrent,
                        extra_data={
                            "is_cached_user": is_cached_user,
                            "model_name": model_name,
                            "api_format": (
                                api_format.value if isinstance(api_format, APIFormat) else api_format
                            ),
                        },
                    )

                return ExecutionResult(response=response, context=context)

        except ConcurrencyLimitError as exc:
            raise ExecutionError(exc, context) from exc
        except Exception as exc:
            # Record elapsed time only if the failure happened after the
            # timer started; otherwise leave it unset.
            context.elapsed_ms = (
                int((time.time() - context.start_time) * 1000)
                if context.start_time is not None
                else None
            )
            raise ExecutionError(exc, context) from exc
|