mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
1253 lines
51 KiB
Python
1253 lines
51 KiB
Python
|
|
"""
|
|||
|
|
通用端点检查执行器(adapter check_endpoint 复用)
|
|||
|
|
|
|||
|
|
目标:
|
|||
|
|
- 统一日志输出格式
|
|||
|
|
- 统一错误处理逻辑
|
|||
|
|
- 将适配器差异收敛到:URL / headers / body 构建
|
|||
|
|
- 集成用量统计和费用计算
|
|||
|
|
|
|||
|
|
重构架构 - 分离关注点:
|
|||
|
|
- HttpRequestExecutor: 专门负责HTTP请求执行
|
|||
|
|
- UsageCalculator: 专门负责Token计数和费用计算
|
|||
|
|
- ErrorHandler: 统一错误处理
|
|||
|
|
- EndpointCheckOrchestrator: 协调整个流程
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Any, AsyncIterator, Dict, Iterable, Optional, Union, List
|
|||
|
|
from abc import ABC, abstractmethod
|
|||
|
|
import time
|
|||
|
|
import uuid
|
|||
|
|
import json
|
|||
|
|
from functools import lru_cache
|
|||
|
|
import asyncio
|
|||
|
|
from collections import defaultdict
|
|||
|
|
|
|||
|
|
import httpx
|
|||
|
|
|
|||
|
|
from src.core.logger import logger
|
|||
|
|
|
|||
|
|
|
|||
|
|
_SENSITIVE_HEADER_KEYS = {
|
|||
|
|
"authorization",
|
|||
|
|
"x-api-key",
|
|||
|
|
"x-goog-api-key",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _redact_headers(headers: Dict[str, str]) -> Dict[str, str]:
|
|||
|
|
redacted: Dict[str, str] = {}
|
|||
|
|
for key, value in headers.items():
|
|||
|
|
if key.lower() in _SENSITIVE_HEADER_KEYS:
|
|||
|
|
redacted[key] = "***"
|
|||
|
|
else:
|
|||
|
|
redacted[key] = value
|
|||
|
|
return redacted
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _truncate_repr(value: Any, limit: int = 1200) -> str:
|
|||
|
|
try:
|
|||
|
|
text = repr(value)
|
|||
|
|
except Exception:
|
|||
|
|
text = f"<unreprable {type(value)!r}>"
|
|||
|
|
if len(text) > limit:
|
|||
|
|
return text[:limit] + "...(truncated)"
|
|||
|
|
return text
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_safe_headers(
    base_headers: Dict[str, str],
    extra_headers: Optional[Dict[str, str]],
    protected_keys: Iterable[str],
) -> Dict[str, str]:
    """Merge ``extra_headers`` into a copy of ``base_headers``.

    Entries of ``extra_headers`` whose name matches one of ``protected_keys``
    (compared case-insensitively) are dropped, so they cannot override the
    protected headers. ``base_headers`` itself is never mutated.

    Returns:
        A new dict holding the base headers plus the permitted extras.
    """
    merged = dict(base_headers)
    if extra_headers:
        blocked = {name.lower() for name in protected_keys}
        for name, content in extra_headers.items():
            if name.lower() not in blocked:
                merged[name] = content
    return merged
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 保持向后兼容的run_endpoint_check函数(使用新架构)
|
|||
|
|
async def run_endpoint_check(
    *,
    client: httpx.AsyncClient,  # kept for backward compatibility; not used here
    url: str,
    headers: Dict[str, str],
    json_body: Dict[str, Any],
    api_format: str,
    provider_name: Optional[str] = None,
    model_name: Optional[str] = None,
    api_key_id: Optional[str] = None,
    provider_id: Optional[str] = None,
    db: Optional[Any] = None,  # Session object, imported lazily where needed
    user: Optional[Any] = None,  # User object
) -> Dict[str, Any]:
    """Run an endpoint check and return the result as a plain dict.

    The arguments are packed into an ``EndpointCheckRequest`` (with a fresh
    8-character request id) and handed to ``EndpointCheckOrchestrator``. The
    orchestrator's result is then flattened into a dict for existing callers.

    Returns:
        A dict that always holds ``status_code``, ``headers``,
        ``response_time_ms`` and ``request_id``. ``response``, ``error`` and
        ``usage`` are added only when the corresponding result field is
        truthy.
    """
    check_request = EndpointCheckRequest(
        url=url,
        headers=headers,
        json_body=json_body,
        api_format=api_format,
        provider_name=provider_name,
        model_name=model_name,
        api_key_id=api_key_id,
        provider_id=provider_id,
        db=db,
        user=user,
        request_id=str(uuid.uuid4())[:8],
    )

    # The orchestrator drives the whole check.
    outcome = await EndpointCheckOrchestrator().execute_check(check_request)

    payload: Dict[str, Any] = {
        "status_code": outcome.status_code,
        "headers": outcome.headers,
        "response_time_ms": outcome.response_time_ms,
        "request_id": outcome.request_id,
    }

    # Optional sections appear only when they carry something.
    optional_sections = (
        ("response", outcome.response_data),
        ("error", outcome.error_message),
        ("usage", outcome.usage_data),
    )
    for key, section in optional_sections:
        if section:
            payload[key] = section

    return payload
|
|||
|
|
|
|||
|
|
def _endpoint_check_resolve_tokens(
    *,
    provider_name: Optional[str],
    api_format: Optional[str],
    request_data: Dict[str, Any],
    response_data: Optional[Dict[str, Any]],
    input_tokens: Optional[int],
    output_tokens: Optional[int],
    cache_creation_input_tokens: Optional[int],
    cache_read_input_tokens: Optional[int],
) -> tuple:
    """Fill in whichever token counts the caller did not supply.

    Returns ``(input, output, cache_creation, cache_read, api_format)``. The
    four counts are never ``None``. ``api_format`` is defaulted to
    ``"OPENAI"`` only when counts had to be derived.
    """
    supplied = (
        input_tokens,
        output_tokens,
        cache_creation_input_tokens,
        cache_read_input_tokens,
    )
    if any(count is None for count in supplied):
        logger.info("[endpoint_check] Calculating tokens from response data")

        # A usage object reported by the provider has the highest priority.
        usage_info = {}
        if response_data and isinstance(response_data, dict):
            usage_info = response_data.get("usage", {})

        if not api_format:
            api_format = "OPENAI"

        logger.info(f"[endpoint_check] Detected API format: {api_format}")

        if usage_info:
            logger.info(f"[endpoint_check] Found usage field in response: {usage_info}")
            # provider_name may be None at runtime; fall back to api_format so
            # the extractor always receives a string identifier.
            api_identifier = provider_name or api_format
            extracted_input, extracted_output, extracted_cache_creation, extracted_cache_read = (
                _extract_tokens_from_response(api_identifier, response_data)
            )

            input_tokens = input_tokens or extracted_input
            output_tokens = output_tokens or extracted_output
            cache_creation_input_tokens = cache_creation_input_tokens or extracted_cache_creation
            cache_read_input_tokens = cache_read_input_tokens or extracted_cache_read
        else:
            # No usage object: estimate from the request/response text.
            logger.warning("[endpoint_check] No usage field found in response, using fallback counting")
            try:
                fallback_input, fallback_output, fallback_cache_creation, fallback_cache_read = (
                    _fallback_token_counting(request_data, response_data)
                )

                input_tokens = input_tokens or fallback_input
                output_tokens = output_tokens or fallback_output
                cache_creation_input_tokens = cache_creation_input_tokens or fallback_cache_creation
                cache_read_input_tokens = cache_read_input_tokens or fallback_cache_read
            except Exception as e:
                logger.error(f"[endpoint_check] Fallback token counting failed: {e}")
                # Use minimum values so the request is still accounted for.
                # NOTE(review): the scraped source lost its indentation; this
                # block is assumed to belong to the except branch - confirm.
                input_tokens = input_tokens or 1
                output_tokens = output_tokens or 1
                cache_creation_input_tokens = cache_creation_input_tokens or 0
                cache_read_input_tokens = cache_read_input_tokens or 0

    # A provider may send explicit nulls; never hand None to the arithmetic
    # and comparisons performed by the caller.
    return (
        input_tokens or 0,
        output_tokens or 0,
        cache_creation_input_tokens or 0,
        cache_read_input_tokens or 0,
        api_format,
    )


def _endpoint_check_track_candidate(
    candidate_service: Any,
    *,
    db: Any,
    user: Any,
    user_api_key: Any,
    provider_endpoint: Any,
    provider_id: str,
    api_key_id: str,
    model_name: str,
    api_format: Optional[str],
    request_id: str,
    status_code: int,
    response_time_ms: int,
    error_message: Optional[str],
) -> Optional[Any]:
    """Create and finalise the RequestCandidate row used by monitoring.

    Best effort: any failure is logged as a warning and ``None`` is returned,
    so usage recording is not affected.
    """
    try:
        candidate = candidate_service.create_candidate(
            db=db,
            request_id=f"test_{request_id}",
            candidate_index=0,  # a test request has exactly one candidate
            user_id=user.id if user else None,
            api_key_id=user_api_key.id if user_api_key else None,
            provider_id=provider_id,
            endpoint_id=provider_endpoint.id if provider_endpoint else None,
            key_id=api_key_id,
            status="available",
            extra_data={"model_name": model_name, "request_type": "endpoint_test"},
        )

        # Mark as started right away, then record the final outcome.
        candidate_service.mark_candidate_started(db, candidate.id)

        if status_code == 200:
            candidate_service.mark_candidate_success(
                db=db,
                candidate_id=candidate.id,
                status_code=status_code,
                latency_ms=response_time_ms,
                extra_data={"model_name": model_name, "api_format": api_format},
            )
        else:
            candidate_service.mark_candidate_failed(
                db=db,
                candidate_id=candidate.id,
                error_type="http_error" if status_code > 0 else "network_error",
                error_message=error_message or "Unknown error",
                status_code=status_code,
                latency_ms=response_time_ms,
                extra_data={"model_name": model_name, "api_format": api_format},
            )

        logger.info(
            f"[endpoint_check] RequestCandidate created | request_id=test_{request_id}, candidate_id={candidate.id}"
        )
        return candidate
    except Exception as e:
        logger.warning(f"[endpoint_check] Failed to create RequestCandidate: {e}")
        return None


async def _calculate_and_record_usage(
    *,
    db: Any,
    user: Any,
    provider_name: str,
    provider_id: str,
    api_key_id: str,
    model_name: str,
    request_data: Dict[str, Any],
    response_data: Optional[Dict[str, Any]],
    request_id: str,
    response_time_ms: int,
    request_headers: Dict[str, str],
    response_headers: Optional[Dict[str, str]] = None,
    status_code: int = 0,
    error_message: Optional[str] = None,
    # Token counts may be passed in directly by the caller.
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    cache_creation_input_tokens: Optional[int] = None,
    cache_read_input_tokens: Optional[int] = None,
    api_format: Optional[str] = None,
) -> Dict[str, Any]:
    """Compute and persist usage for an endpoint test request (legacy helper).

    The upstream call is made with the provider's API key, but the usage
    record is attributed to the user who ran the test. Token counts passed by
    the caller win; missing ones are taken from the response ``usage`` object
    or, failing that, estimated.

    Returns:
        A dict with token counts, ``total_tokens``, ``status_code`` and
        ``api_format``. On success it also carries the cost fields,
        ``usage_id``, ``request_id`` and ``candidate_id``. If recording fails
        it carries ``error`` instead. If the provider API key cannot be found
        the result is just ``{"error": "Provider API Key not found"}``.
    """
    from src.services.usage.service import UsageService
    from src.services.request.candidate import RequestCandidateService
    from src.models.database import ApiKey, ProviderAPIKey, ProviderEndpoint

    # Provider API key (not the user's API key).
    provider_api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
    if not provider_api_key:
        logger.warning(f"Provider API Key not found for usage calculation: {api_key_id}")
        return {"error": "Provider API Key not found"}

    # Provider endpoint, when the key is bound to one.
    provider_endpoint = None
    if provider_api_key.endpoint_id:
        provider_endpoint = (
            db.query(ProviderEndpoint)
            .filter(ProviderEndpoint.id == provider_api_key.endpoint_id)
            .first()
        )

    # User API key: used only to attribute the record to the tester.
    user_api_key = None
    if user:
        try:
            user_api_key = db.query(ApiKey).filter(ApiKey.user_id == user.id).first()
            logger.info(f"[endpoint_check] User API Key found: {user_api_key.id if user_api_key else None}")
        except Exception as e:
            logger.warning(f"[endpoint_check] Failed to get user API Key: {e}")
            user_api_key = None

    (
        input_tokens,
        output_tokens,
        cache_creation_input_tokens,
        cache_read_input_tokens,
        api_format,
    ) = _endpoint_check_resolve_tokens(
        provider_name=provider_name,
        api_format=api_format,
        request_data=request_data,
        response_data=response_data,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
    )

    logger.info(f"[endpoint_check] Final token count | input={input_tokens}, output={output_tokens}, "
                f"cache_creation={cache_creation_input_tokens}, cache_read={cache_read_input_tokens}")

    try:
        logger.info(f"[endpoint_check] Recording usage | provider={provider_name}, model={model_name}, "
                    f"tokens=({input_tokens}+{output_tokens}), status={status_code}, "
                    f"user_api_key_id={user_api_key.id if user_api_key else None}, "
                    f"provider_endpoint_id={provider_endpoint.id if provider_endpoint else None}")

        usage_record = await UsageService.record_usage_async(
            db=db,
            user=user,
            api_key=user_api_key,  # attributed to the tester's API key
            provider=provider_name,
            model=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            request_type="endpoint_test",  # marks the record as a test request
            api_format=api_format,
            is_stream=False,
            response_time_ms=response_time_ms,
            first_byte_time_ms=response_time_ms,
            status_code=status_code,
            error_message=error_message,
            request_headers=request_headers,
            response_headers=response_headers,
            request_body=request_data,
            response_body=response_data,
            request_id=f"test_{request_id}",
            provider_id=provider_id,
            provider_endpoint_id=provider_endpoint.id if provider_endpoint else None,
            provider_api_key_id=api_key_id,  # the provider key actually used
            status="completed" if status_code == 200 else "failed",
            use_tiered_pricing=True,
            target_model=model_name,
        )

        total_cost = float(usage_record.total_cost_usd) if usage_record.total_cost_usd else 0.0
        actual_cost = float(usage_record.actual_total_cost_usd) if usage_record.actual_total_cost_usd else 0.0
        cache_cost = float(usage_record.cache_cost_usd) if usage_record.cache_cost_usd else 0.0

        # Zero cost with non-zero tokens usually means pricing is missing, so
        # report a default price. This affects the returned figures only; the
        # stored record is not modified here.
        if total_cost == 0.0 and (input_tokens > 0 or output_tokens > 0):
            logger.warning("[endpoint_check] Cost calculation returned 0, using fallback pricing")
            fallback_price_per_1m = 1.0  # $1 per 1M tokens ($0.001 per 1K)
            total_cost = ((input_tokens + output_tokens) / 1_000_000) * fallback_price_per_1m
            actual_cost = total_cost

        logger.info(f"[endpoint_check] Usage recorded successfully | "
                    f"usage_id={usage_record.id}, total_cost=${total_cost:.6f}, "
                    f"actual_cost=${actual_cost:.6f}")

        # RequestCandidate row for the monitoring/tracing API (best effort).
        candidate = _endpoint_check_track_candidate(
            RequestCandidateService,
            db=db,
            user=user,
            user_api_key=user_api_key,
            provider_endpoint=provider_endpoint,
            provider_id=provider_id,
            api_key_id=api_key_id,
            model_name=model_name,
            api_format=api_format,
            request_id=request_id,
            status_code=status_code,
            response_time_ms=response_time_ms,
            error_message=error_message,
        )

        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cache_creation_input_tokens": cache_creation_input_tokens,
            "cache_read_input_tokens": cache_read_input_tokens,
            "total_tokens": input_tokens + output_tokens,
            "total_cost_usd": total_cost,
            "actual_total_cost_usd": actual_cost,
            "cache_cost_usd": cache_cost,
            "status_code": status_code,
            "usage_id": str(usage_record.id),
            "api_format": api_format,
            "request_id": f"test_{request_id}",  # returned for tracing
            "candidate_id": str(candidate.id) if candidate else None,
        }

    except Exception as e:
        logger.error(f"Failed to record usage for endpoint check: {e}")
        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cache_creation_input_tokens": cache_creation_input_tokens,
            "cache_read_input_tokens": cache_read_input_tokens,
            "total_tokens": input_tokens + output_tokens,
            "error": str(e),
            "status_code": status_code,
            "api_format": api_format,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_tokens_from_response(api_identifier: str, response_data: Optional[Dict[str, Any]]) -> tuple[int, int, int, int]:
|
|||
|
|
"""
|
|||
|
|
从响应中提取Token计数信息
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
api_identifier: API标识符(api_format或provider_name)
|
|||
|
|
response_data: 响应数据
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
tuple[int, int, int, int]: (input_tokens, output_tokens, cache_creation_input_tokens, cache_read_input_tokens)
|
|||
|
|
"""
|
|||
|
|
if not response_data:
|
|||
|
|
return 0, 0, 0, 0
|
|||
|
|
|
|||
|
|
api_identifier_lower = api_identifier.lower()
|
|||
|
|
usage_info = response_data.get("usage", {})
|
|||
|
|
|
|||
|
|
if not usage_info:
|
|||
|
|
return 0, 0, 0, 0
|
|||
|
|
|
|||
|
|
input_tokens = 0
|
|||
|
|
output_tokens = 0
|
|||
|
|
cache_creation_input_tokens = 0
|
|||
|
|
cache_read_input_tokens = 0
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 基于adapter名字进行更精确的检测
|
|||
|
|
if "claude" in api_identifier_lower:
|
|||
|
|
# Claude格式 - 支持claude.chat和claude.token_count等adapter
|
|||
|
|
input_tokens = usage_info.get("input_tokens", 0)
|
|||
|
|
output_tokens = usage_info.get("output_tokens", 0)
|
|||
|
|
cache_read_input_tokens = usage_info.get("cache_read_input_tokens", 0)
|
|||
|
|
|
|||
|
|
# 尝试提取cache creation tokens
|
|||
|
|
try:
|
|||
|
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
|||
|
|
cache_creation_input_tokens = extract_cache_creation_tokens(usage_info)
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"[endpoint_check] Failed to extract cache creation tokens: {e}")
|
|||
|
|
|
|||
|
|
elif "openai" in api_identifier_lower:
|
|||
|
|
# OpenAI格式
|
|||
|
|
input_tokens = usage_info.get("prompt_tokens", 0) or usage_info.get("input_tokens", 0)
|
|||
|
|
output_tokens = usage_info.get("completion_tokens", 0) or usage_info.get("output_tokens", 0)
|
|||
|
|
cache_creation_input_tokens = 0
|
|||
|
|
cache_read_input_tokens = 0
|
|||
|
|
|
|||
|
|
elif "gemini" in api_identifier_lower or "google" in api_identifier_lower:
|
|||
|
|
# Gemini格式 - 使用与OpenAI类似的字段名
|
|||
|
|
input_tokens = usage_info.get("prompt_tokens", 0) or usage_info.get("input_tokens", 0)
|
|||
|
|
output_tokens = usage_info.get("completion_tokens", 0) or usage_info.get("output_tokens", 0)
|
|||
|
|
cache_creation_input_tokens = 0
|
|||
|
|
cache_read_input_tokens = 0
|
|||
|
|
|
|||
|
|
# Fallback: 尝试其他可能的provider名称匹配
|
|||
|
|
elif "anthropic" in api_identifier_lower:
|
|||
|
|
# Anthropic/Claude的其他别名
|
|||
|
|
input_tokens = usage_info.get("input_tokens", 0)
|
|||
|
|
output_tokens = usage_info.get("output_tokens", 0)
|
|||
|
|
cache_read_input_tokens = usage_info.get("cache_read_input_tokens", 0)
|
|||
|
|
try:
|
|||
|
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
|||
|
|
cache_creation_input_tokens = extract_cache_creation_tokens(usage_info)
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"[endpoint_check] Failed to extract cache creation tokens: {e}")
|
|||
|
|
|
|||
|
|
else:
|
|||
|
|
# 默认情况:尝试通用提取
|
|||
|
|
logger.warning(f"[endpoint_check] Unknown API identifier: {api_identifier}, using generic token extraction")
|
|||
|
|
input_tokens = usage_info.get("input_tokens", 0) or usage_info.get("prompt_tokens", 0)
|
|||
|
|
output_tokens = usage_info.get("output_tokens", 0) or usage_info.get("completion_tokens", 0)
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"[endpoint_check] Error extracting tokens from response: {e}")
|
|||
|
|
return 0, 0, 0, 0
|
|||
|
|
|
|||
|
|
logger.info(f"[endpoint_check] Tokens extracted from response | "
|
|||
|
|
f"api_identifier={api_identifier}, "
|
|||
|
|
f"input={input_tokens}, output={output_tokens}, "
|
|||
|
|
f"cache_creation={cache_creation_input_tokens}, cache_read={cache_read_input_tokens}")
|
|||
|
|
|
|||
|
|
return input_tokens, output_tokens, cache_creation_input_tokens, cache_read_input_tokens
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _fallback_token_counting(request_data: Dict[str, Any], response_data: Optional[Dict[str, Any]]) -> tuple[int, int, int, int]:
    """Roughly estimate token counts when the response carries no usage data.

    Input is estimated from the request's ``messages`` (or ``contents``).
    Output is estimated from the text found in a Claude-style (``content``),
    OpenAI-style (``choices``) or Gemini-style (``candidates``) response.
    Malformed or text-less payloads fall back to the minimum of 1 instead of
    raising.

    Returns:
        tuple[int, int, int, int]: (input_tokens, output_tokens,
        cache_creation_input_tokens, cache_read_input_tokens). Both cache
        counts are always 0; input and output are at least 1.
    """

    def _estimate(text: str) -> int:
        # NOTE(review): this is whitespace-separated words // 4, not
        # characters // 4 - kept as is; confirm the intended heuristic.
        return max(1, len(text.split()) // 4)

    def _response_text(data: Dict[str, Any]) -> Optional[str]:
        """Return the response text, or None when no known shape matches."""
        # Claude format
        if "content" in data:
            content = data["content"]
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                # Claude content is usually a list of typed blocks.
                return "".join(
                    block["text"]
                    for block in content
                    if isinstance(block, dict)
                    and block.get("type") == "text"
                    and isinstance(block.get("text"), str)
                )
            return str(content)

        # OpenAI format
        choices = data.get("choices")
        if choices:
            first = choices[0] if isinstance(choices, list) else None
            message = first.get("message") if isinstance(first, dict) else None
            if not isinstance(message, dict):
                return None
            content = message.get("content")
            if isinstance(content, str):
                return content
            # content is null for tool-call replies and may be a list of
            # parts; neither supports .split().
            return str(content) if content else ""

        # Gemini format
        candidates = data.get("candidates")
        if candidates:
            first = candidates[0] if isinstance(candidates, list) else None
            body = first.get("content") if isinstance(first, dict) else None
            parts = body.get("parts") if isinstance(body, dict) else None
            if not isinstance(parts, list):
                return None
            return "".join(
                part["text"]
                for part in parts
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )

        return None

    # Estimate input tokens
    messages = None
    if isinstance(request_data, dict):
        messages = request_data.get("messages", request_data.get("contents", []))
    input_tokens = _estimate(str(messages)) if messages else 1

    # Estimate output tokens (minimum 1)
    output_tokens = 1
    if response_data and isinstance(response_data, dict):
        output_text = _response_text(response_data)
        if output_text is not None:
            output_tokens = _estimate(output_text)

    logger.info(f"[endpoint_check] Fallback token count | input={input_tokens}, output={output_tokens}")
    return input_tokens, output_tokens, 0, 0
|
|||
|
|
|
|||
|
|
# =========================================================================
|
|||
|
|
# 重构后的架构类 - 分离关注点
|
|||
|
|
# =========================================================================
|
|||
|
|
|
|||
|
|
@dataclass
class EndpointCheckRequest:
    """Everything needed to run one endpoint check."""

    # --- HTTP request to send ---
    url: str
    headers: Dict[str, str]
    json_body: Dict[str, Any]

    # API format label, e.g. used in log prefixes and token extraction.
    api_format: str

    # --- optional bookkeeping metadata ---
    provider_name: Optional[str] = None
    model_name: Optional[str] = None
    api_key_id: Optional[str] = None
    provider_id: Optional[str] = None
    db: Optional[Any] = None
    user: Optional[Any] = None

    # Short id used to correlate logs; generated downstream when missing.
    request_id: Optional[str] = None

    # Timeout in seconds (presumably; confirm against the consumer).
    timeout: float = 30.0
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class EndpointCheckResult:
    """Outcome of one endpoint check."""

    # --- always present ---
    status_code: int
    headers: Dict[str, str]
    response_time_ms: int
    request_id: str

    # --- filled in depending on how the check went ---
    response_data: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None
    usage_data: Optional[Dict[str, Any]] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class HttpRequestExecutor:
    """HTTP request executor - responsible only for the network call."""

    def __init__(self, timeout: float = 30.0):
        # Passed as ``timeout`` to every ``httpx.AsyncClient`` created here.
        self.timeout = timeout

    async def execute(self, request: EndpointCheckRequest) -> EndpointCheckResult:
        """POST ``request.json_body`` to ``request.url`` and wrap the outcome.

        A 200 response yields a result carrying the parsed JSON body (or
        ``None`` when the body is not valid JSON). Any other status code, and
        any exception raised while sending, is delegated to
        ``ErrorHandler.handle_error``.

        NOTE(review): ``request.timeout`` is not consulted here; only the
        executor's own ``timeout`` is used - confirm that is intended.
        """
        # perf_counter is monotonic, so the measured latency cannot be skewed
        # by wall-clock adjustments.
        started = time.perf_counter()
        request_id = request.request_id or str(uuid.uuid4())[:8]

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(
                    url=request.url,
                    json=request.json_body,
                    headers=request.headers,
                )
                response_time_ms = int((time.perf_counter() - started) * 1000)

            if response.status_code == 200:
                try:
                    response_data = response.json()
                    logger.debug(f"[{request.api_format}] check_endpoint | response | json={_truncate_repr(response_data)}")
                except Exception:
                    response_data = None
                    logger.debug(f"[{request.api_format}] check_endpoint | response | invalid json")

                return EndpointCheckResult(
                    status_code=response.status_code,
                    headers=dict(response.headers),
                    response_time_ms=response_time_ms,
                    request_id=request_id,
                    response_data=response_data,
                )

            # Non-200: hand over to the unified error handler.
            error_body = response.text[:500] if response.text else "(empty)"
            logger.debug(f"[{request.api_format}] check_endpoint | response | error={error_body}")

            # Attach the real request so consumers of the exception can read
            # ``error.request`` without hitting a None.
            http_error = httpx.HTTPStatusError(
                f"HTTP {response.status_code}: {error_body}",
                request=response.request,
                response=response,
            )
            return await ErrorHandler.handle_error(http_error, request)

        except Exception as e:
            # Boundary catch: every failure becomes an EndpointCheckResult.
            return await ErrorHandler.handle_error(e, request)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UsageCalculator:
    """Usage calculator - responsible for token counting."""

    @staticmethod
    def calculate_tokens(request: "EndpointCheckRequest", result: "EndpointCheckResult") -> tuple[int, int, int, int]:
        """Compute token counts for a finished check.

        The provider-reported usage is used when an API identifier and a
        response body are available; otherwise counts are estimated.

        Returns:
            tuple[int, int, int, int]: (input_tokens, output_tokens,
            cache_creation_input_tokens, cache_read_input_tokens)
        """
        # api_format is preferred over provider_name.
        api_identifier = request.api_format or request.provider_name

        if not api_identifier or not result.response_data:
            return UsageCalculator._fallback_token_counting(request.json_body, result.response_data)

        # NOTE(review): when the response has no usage object this returns
        # zeros rather than an estimate - confirm whether a fallback is wanted.
        return _extract_tokens_from_response(api_identifier, result.response_data)

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Estimate tokens as whitespace-separated words // 4, minimum 1."""
        return max(1, len(text.split()) // 4)

    @staticmethod
    def _response_text(response_data: Dict[str, Any]) -> Optional[str]:
        """Return the response text, or None when no known shape matches."""
        # Claude format
        if "content" in response_data:
            content = response_data["content"]
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                return "".join(
                    block["text"]
                    for block in content
                    if isinstance(block, dict)
                    and block.get("type") == "text"
                    and isinstance(block.get("text"), str)
                )
            return str(content)

        # OpenAI format
        choices = response_data.get("choices")
        if choices:
            first = choices[0] if isinstance(choices, list) else None
            message = first.get("message") if isinstance(first, dict) else None
            if not isinstance(message, dict):
                return None
            content = message.get("content")
            if isinstance(content, str):
                return content
            # content is null for tool-call replies and may be a list of
            # parts; neither supports .split().
            return str(content) if content else ""

        # Gemini format
        candidates = response_data.get("candidates")
        if candidates:
            first = candidates[0] if isinstance(candidates, list) else None
            body = first.get("content") if isinstance(first, dict) else None
            parts = body.get("parts") if isinstance(body, dict) else None
            if not isinstance(parts, list):
                return None
            return "".join(
                part["text"]
                for part in parts
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )

        return None

    @staticmethod
    def _fallback_token_counting(request_data: Dict[str, Any], response_data: Optional[Dict[str, Any]]) -> tuple[int, int, int, int]:
        """Roughly estimate token counts from the request and response text.

        Malformed or text-less payloads yield the minimum of 1 instead of
        raising. Cache counts are always 0.
        """
        # Estimate input tokens
        messages = None
        if isinstance(request_data, dict):
            messages = request_data.get("messages", request_data.get("contents", []))
        input_tokens = UsageCalculator._estimate_tokens(str(messages)) if messages else 1

        # Estimate output tokens (minimum 1)
        output_tokens = 1
        if response_data and isinstance(response_data, dict):
            output_text = UsageCalculator._response_text(response_data)
            if output_text is not None:
                output_tokens = UsageCalculator._estimate_tokens(output_text)

        return input_tokens, output_tokens, 0, 0
|
|||
|
|
|
|||
|
|
class AsyncBatchUsageRecorder:
    """Asynchronous usage recorder that batches pending records.

    Records are queued in memory and flushed either when ``batch_size`` is
    reached or ``flush_interval`` seconds after a record was queued.

    NOTE(review): ``_flush_batch`` currently only logs each record; nothing
    visible here writes to a database.
    """

    def __init__(self, batch_size: int = 10, flush_interval: float = 2.0):
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.pending_records: List[Dict[str, Any]] = []
        self._flush_task: Optional[asyncio.Task] = None
        # Guards pending_records. asyncio.Lock is NOT reentrant, so
        # _flush_batch must only run while the caller holds it and must never
        # acquire it itself.
        self._lock = asyncio.Lock()
        # Cleared by close(); not consulted anywhere else in this class.
        self._running = True

    async def add_record(self, usage_data: Dict[str, Any]) -> None:
        """Queue a usage record, flushing at once when the batch is full."""
        async with self._lock:
            self.pending_records.append(usage_data)

            if len(self.pending_records) >= self.batch_size:
                await self._flush_batch()
            else:
                # Otherwise make sure a timed flush is scheduled.
                self._ensure_flush_task()

    def _ensure_flush_task(self) -> None:
        """Start the timed flush task unless one is already pending."""
        if self._flush_task is None or self._flush_task.done():
            self._flush_task = asyncio.create_task(self._periodic_flush())

    async def _periodic_flush(self) -> None:
        """Sleep for ``flush_interval`` seconds, then flush what is queued."""
        try:
            await asyncio.sleep(self.flush_interval)
            async with self._lock:
                if self.pending_records:
                    await self._flush_batch()
        except asyncio.CancelledError:
            # Cancelled by flush(); nothing left to do.
            pass
        except Exception as e:
            logger.error(f"[AsyncBatchUsageRecorder] Periodic flush failed: {e}")

    async def _flush_batch(self) -> None:
        """Flush all queued records. The caller MUST hold ``self._lock``."""
        if not self.pending_records:
            return

        # Detach the batch so the queue is empty while it is processed.
        records_to_flush, self.pending_records = self.pending_records, []

        try:
            # Bulk insert logic could go here; for now each record is only
            # logged.
            for record in records_to_flush:
                logger.debug(f"[AsyncBatchUsageRecorder] Flushing usage record: {record.get('request_id', 'unknown')}")

            logger.info(f"[AsyncBatchUsageRecorder] Flushed {len(records_to_flush)} usage records")
        except Exception as e:
            logger.error(f"[AsyncBatchUsageRecorder] Failed to flush batch: {e}")
            # Re-queue the failed records. The lock is already held by the
            # caller; acquiring it again here would deadlock.
            self.pending_records.extend(records_to_flush)

    async def flush(self) -> None:
        """Flush everything queued now and cancel any pending timed flush."""
        async with self._lock:
            await self._flush_batch()

        pending_task, self._flush_task = self._flush_task, None
        if pending_task is not None:
            pending_task.cancel()
            try:
                await pending_task
            except asyncio.CancelledError:
                pass

    async def close(self) -> None:
        """Mark the recorder as stopped and flush what is still queued."""
        self._running = False
        await self.flush()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 全局批处理器实例(单例)
|
|||
|
|
# Lazily created, process-wide recorder instance.
_global_batch_recorder: Optional[AsyncBatchUsageRecorder] = None


def get_batch_recorder() -> AsyncBatchUsageRecorder:
    """Return the shared ``AsyncBatchUsageRecorder``, creating it on first use."""
    global _global_batch_recorder
    recorder = _global_batch_recorder
    if recorder is None:
        recorder = _global_batch_recorder = AsyncBatchUsageRecorder()
    return recorder
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =========================================================================
|
|||
|
|
# 统一错误处理机制
|
|||
|
|
# =========================================================================
|
|||
|
|
|
|||
|
|
class EndpointCheckError(Exception):
    """Base class for endpoint check errors.

    Carries the message, a machine-readable ``error_type``, a
    ``status_code`` (500 by default) and a ``details`` dict, which is an
    empty dict when none (or an empty one) is given.
    """

    def __init__(self, message: str, error_type: str, status_code: int = 500, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message, self.error_type, self.status_code = message, error_type, status_code
        self.details = details if details else {}
|
|||
|
|
|
|||
|
|
class NetworkError(EndpointCheckError):
    """Network request error (``error_type="network_error"``, status code 0)."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(
            message,
            error_type="network_error",
            status_code=0,
            details=details,
        )
|
|||
|
|
|
|||
|
|
class AuthenticationError(EndpointCheckError):
    """Authentication error (``error_type="authentication_error"``, status code 401)."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(
            message,
            error_type="authentication_error",
            status_code=401,
            details=details,
        )
|
|||
|
|
|
|||
|
|
class RateLimitError(EndpointCheckError):
    """Rate limit error (``error_type="rate_limit_error"``, status code 429)."""

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(
            message,
            error_type="rate_limit_error",
            status_code=429,
            details=details,
        )
|
|||
|
|
|
|||
|
|
class UpstreamError(EndpointCheckError):
    """Upstream service failure (``error_type="upstream_error"``); the caller supplies the status code."""

    def __init__(self, message: str, status_code: int, details: Optional[Dict[str, Any]] = None):
        super().__init__(
            message,
            error_type="upstream_error",
            status_code=status_code,
            details=details,
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ErrorHandler:
    """Unified error handler.

    Converts an exception raised during an endpoint check into an
    ``EndpointCheckResult`` whose ``response_data`` carries an ``error_type``
    and a ``retryable`` flag.
    """

    @staticmethod
    def _request_id(request: EndpointCheckRequest) -> str:
        """Return the request's id, or a fresh 8-character id when it has none."""
        return request.request_id or str(uuid.uuid4())[:8]

    @staticmethod
    async def handle_error(error: Exception, request: EndpointCheckRequest) -> EndpointCheckResult:
        """Dispatch ``error`` to the matching handler and return its result.

        Args:
            error: The exception to classify.
            request: The endpoint-check request that was being executed.

        Returns:
            An ``EndpointCheckResult`` describing the failure.
        """
        # httpx.TimeoutException derives from httpx.RequestError (through
        # TransportError), so it must be tested first; otherwise the timeout
        # branch can never be reached.
        if isinstance(error, httpx.TimeoutException):
            return ErrorHandler._handle_timeout_error(error, request)
        if isinstance(error, httpx.RequestError):
            return ErrorHandler._handle_network_error(error, request)
        if isinstance(error, httpx.HTTPStatusError):
            return ErrorHandler._handle_http_status_error(error, request)
        if isinstance(error, EndpointCheckError):
            return ErrorHandler._handle_business_error(error, request)
        if isinstance(error, (ValueError, TypeError)):
            return ErrorHandler._handle_validation_error(error, request)
        return ErrorHandler._handle_unknown_error(error, request)

    @staticmethod
    def _handle_network_error(error: httpx.RequestError, request: EndpointCheckRequest) -> EndpointCheckResult:
        """Build the result for a (non-timeout) network error; always retryable."""
        error_message = f"Network error: {str(error)}"
        logger.warning(f"[{request.api_format}] Network error: {error}")

        # Sub-classify from the error text.
        lowered = str(error).lower()
        if "connect" in lowered:
            error_type = "connection_failed"
            error_message = "Connection failed to upstream service"
        elif "timeout" in lowered:
            error_type = "timeout"
            error_message = "Request timeout"
        else:
            error_type = "network_error"

        return EndpointCheckResult(
            status_code=0,
            headers={},
            response_time_ms=0,
            request_id=ErrorHandler._request_id(request),
            error_message=error_message,
            response_data={
                "error_type": error_type,
                "original_error": str(error),
                "retryable": True
            }
        )

    @staticmethod
    def _handle_timeout_error(error: httpx.TimeoutException, request: EndpointCheckRequest) -> EndpointCheckResult:
        """Build the result for a timeout; the response time is reported as the configured timeout."""
        logger.warning(f"[{request.api_format}] Request timeout: {error}")
        return EndpointCheckResult(
            status_code=0,
            headers={},
            response_time_ms=int(request.timeout * 1000),  # seconds -> milliseconds
            request_id=ErrorHandler._request_id(request),
            error_message="Request timeout",
            response_data={
                "error_type": "timeout",
                "original_error": str(error),
                "retryable": True,
                "timeout_seconds": request.timeout
            }
        )

    @staticmethod
    def _handle_http_status_error(error: httpx.HTTPStatusError, request: EndpointCheckRequest) -> EndpointCheckResult:
        """Build the result for an HTTP status error, classified by status code."""
        logger.warning(f"[{request.api_format}] HTTP error: {error.response.status_code} - {error.response.text[:200]}")

        # Classify by status code: 401 and other 4xx are not retryable,
        # 429 and 5xx are.
        status_code = error.response.status_code
        if status_code == 401:
            error_type = "authentication_error"
            error_message = "Authentication failed"
            retryable = False
        elif status_code == 429:
            error_type = "rate_limit_error"
            error_message = "Rate limit exceeded"
            retryable = True
        elif 400 <= status_code < 500:
            error_type = "client_error"
            error_message = f"Client error: {status_code}"
            retryable = False
        elif 500 <= status_code < 600:
            error_type = "server_error"
            error_message = f"Server error: {status_code}"
            retryable = True
        else:
            error_type = "http_error"
            error_message = f"HTTP error: {status_code}"
            retryable = status_code >= 500

        return EndpointCheckResult(
            status_code=status_code,
            headers=dict(error.response.headers),
            response_time_ms=0,
            request_id=ErrorHandler._request_id(request),
            error_message=error_message,
            response_data={
                "error_type": error_type,
                "http_status": status_code,
                # Keep at most the first 500 characters of the body.
                "response_body": error.response.text[:500] if error.response.text else "",
                "retryable": retryable
            }
        )

    @staticmethod
    def _handle_business_error(error: EndpointCheckError, request: EndpointCheckRequest) -> EndpointCheckResult:
        """Build the result for an ``EndpointCheckError``; retryable for status >= 500 or 429."""
        logger.warning(f"[{request.api_format}] Business error: {error.error_type} - {error.message}")
        return EndpointCheckResult(
            status_code=error.status_code,
            headers={},
            response_time_ms=0,
            request_id=ErrorHandler._request_id(request),
            error_message=error.message,
            response_data={
                "error_type": error.error_type,
                "details": error.details,
                "retryable": error.status_code >= 500 or error.status_code == 429
            }
        )

    @staticmethod
    def _handle_validation_error(error: Union[ValueError, TypeError], request: EndpointCheckRequest) -> EndpointCheckResult:
        """Build a non-retryable status-400 result for a ``ValueError`` / ``TypeError``."""
        logger.warning(f"[{request.api_format}] Validation error: {error}")
        return EndpointCheckResult(
            status_code=400,
            headers={},
            response_time_ms=0,
            request_id=ErrorHandler._request_id(request),
            error_message=f"Validation error: {str(error)}",
            response_data={
                "error_type": "validation_error",
                "original_error": str(error),
                "retryable": False
            }
        )

    @staticmethod
    def _handle_unknown_error(error: Exception, request: EndpointCheckRequest) -> EndpointCheckResult:
        """Build a non-retryable status-500 result for any other exception and log its traceback."""
        logger.error(f"[{request.api_format}] Unknown error: {type(error).__name__}: {error}")
        import traceback
        # Format from the exception object itself so the traceback is correct
        # even when this handler is not called from inside an ``except`` block
        # (``traceback.format_exc()`` would print "NoneType: None" there).
        formatted = "".join(traceback.format_exception(type(error), error, error.__traceback__))
        logger.error(f"[{request.api_format}] Traceback: {formatted}")

        return EndpointCheckResult(
            status_code=500,
            headers={},
            response_time_ms=0,
            request_id=ErrorHandler._request_id(request),
            error_message="Internal server error",
            response_data={
                "error_type": "internal_error",
                "original_error": str(error),
                "retryable": False
            }
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =========================================================================
|
|||
|
|
# 配置化支持
|
|||
|
|
# =========================================================================
|
|||
|
|
|
|||
|
|
@dataclass
class EndpointCheckConfig:
    """Configuration for endpoint checks.

    Instances can be built directly, from environment variables
    (``from_env``) or from a plain mapping (``from_dict``).
    """
    # Performance
    timeout: float = 30.0
    max_retries: int = 3
    retry_delay: float = 1.0

    # Caching
    api_format_cache_size: int = 512
    request_cache_size: int = 128

    # Batch recording
    enable_batch_recording: bool = True
    batch_size: int = 10
    batch_flush_interval: float = 2.0

    # Logging
    enable_detailed_logging: bool = False
    enable_structured_logging: bool = True

    # Usage calculation
    enable_usage_calculation: bool = True
    enable_fallback_token_counting: bool = True

    # Error handling
    enable_error_classification: bool = True
    retry_on_server_errors: bool = True
    retry_on_timeouts: bool = True

    @classmethod
    def from_env(cls) -> 'EndpointCheckConfig':
        """Create a configuration from ``ENDPOINT_CHECK_*`` environment variables.

        Unset variables fall back to the same values as the field defaults.
        Boolean variables are true only when their value equals ``"true"``
        (case-insensitive). ``request_cache_size`` has no environment override
        and keeps its default.

        Raises:
            ValueError: If a numeric variable cannot be parsed.
        """
        import os

        def env_flag(name: str, default: str) -> bool:
            # Anything other than "true" (in any letter case) counts as False.
            return os.getenv(name, default).lower() == 'true'

        return cls(
            timeout=float(os.getenv('ENDPOINT_CHECK_TIMEOUT', '30.0')),
            max_retries=int(os.getenv('ENDPOINT_CHECK_MAX_RETRIES', '3')),
            retry_delay=float(os.getenv('ENDPOINT_CHECK_RETRY_DELAY', '1.0')),
            api_format_cache_size=int(os.getenv('ENDPOINT_CHECK_CACHE_SIZE', '512')),
            enable_batch_recording=env_flag('ENDPOINT_CHECK_BATCH_RECORDING', 'true'),
            batch_size=int(os.getenv('ENDPOINT_CHECK_BATCH_SIZE', '10')),
            batch_flush_interval=float(os.getenv('ENDPOINT_CHECK_BATCH_INTERVAL', '2.0')),
            enable_detailed_logging=env_flag('ENDPOINT_CHECK_DETAILED_LOGGING', 'false'),
            enable_structured_logging=env_flag('ENDPOINT_CHECK_STRUCTURED_LOGGING', 'true'),
            enable_usage_calculation=env_flag('ENDPOINT_CHECK_USAGE_CALCULATION', 'true'),
            enable_fallback_token_counting=env_flag('ENDPOINT_CHECK_FALLBACK_COUNTING', 'true'),
            enable_error_classification=env_flag('ENDPOINT_CHECK_ERROR_CLASSIFICATION', 'true'),
            retry_on_server_errors=env_flag('ENDPOINT_CHECK_RETRY_SERVER_ERRORS', 'true'),
            retry_on_timeouts=env_flag('ENDPOINT_CHECK_RETRY_TIMEOUTS', 'true'),
        )

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'EndpointCheckConfig':
        """Create a configuration from a mapping, ignoring unknown keys.

        Only keys that name a dataclass field are used. Filtering with
        ``hasattr`` would also let through method names such as ``from_env``
        and then fail inside the constructor.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in config_dict.items() if k in known})
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ConfigurableEndpointChecker:
    """Endpoint checker whose timeout, retry and logging behaviour is driven by an ``EndpointCheckConfig``."""

    def __init__(self, config: Optional[EndpointCheckConfig] = None):
        """Wire up executor, usage calculator and orchestrator.

        Args:
            config: Configuration to use; a default ``EndpointCheckConfig`` when omitted.
        """
        self.config = config or EndpointCheckConfig()
        self.executor = HttpRequestExecutor(timeout=self.config.timeout)
        self.usage_calculator = UsageCalculator()
        self.orchestrator = EndpointCheckOrchestrator(
            executor=self.executor,
            usage_calculator=self.usage_calculator
        )

        # Apply the cache-related settings.
        self._apply_cache_config()

    def _apply_cache_config(self) -> None:
        """Report the cache configuration (no cache is configured here; this only logs)."""
        logger.info(f"[ConfigurableEndpointChecker] Cache config applied: api_format_cache_size={self.config.api_format_cache_size}")

    async def check_endpoint(self, request: EndpointCheckRequest) -> EndpointCheckResult:
        """Run an endpoint check according to the configuration.

        The request's ``timeout`` is overwritten with the configured timeout.
        When the first attempt fails in a retryable way and ``max_retries`` is
        positive, the check is retried with exponential back-off.

        Args:
            request: The check to execute (mutated: ``timeout`` is set).

        Returns:
            The result of the last executed attempt.
        """
        request.timeout = self.config.timeout

        if self.config.enable_structured_logging:
            self._log_structured_start(request)

        result = await self.orchestrator.execute_check(request)

        if self.config.max_retries > 0 and self._should_retry(result):
            result = await self._retry_check(request, result)

        if self.config.enable_structured_logging:
            self._log_structured_result(request, result)

        return result

    def _should_retry(self, result: EndpointCheckResult) -> bool:
        """Decide from ``result.response_data`` and the configuration whether to retry."""
        if not self.config.enable_error_classification or not result.response_data:
            return False

        error_type = result.response_data.get("error_type", "")
        retryable = result.response_data.get("retryable", False)

        if error_type == "timeout" and self.config.retry_on_timeouts:
            return True
        if error_type in ["server_error", "network_error", "connection_failed"] and self.config.retry_on_server_errors:
            return retryable

        return False

    async def _retry_check(self, request: EndpointCheckRequest, last_result: EndpointCheckResult) -> EndpointCheckResult:
        """Retry the check up to ``max_retries`` times with exponential back-off.

        Args:
            request: The check to re-execute.
            last_result: The failed result that triggered the retries.

        Returns:
            The result of the most recent attempt. When ``max_retries`` is 0
            no attempt is made and ``last_result`` is returned unchanged.
        """
        for attempt in range(self.config.max_retries):
            if self.config.enable_structured_logging:
                self._log_structured_retry(request, attempt + 1, last_result)

            # Exponential back-off: retry_delay, 2 * retry_delay, 4 * retry_delay, ...
            await asyncio.sleep(self.config.retry_delay * (2 ** attempt))

            # Keep the newest result so that both the next retry log entry and
            # the final return value reflect the latest attempt, not the
            # original failure.
            last_result = await self.orchestrator.execute_check(request)

            succeeded = last_result.status_code == 200
            if succeeded or not self._should_retry(last_result):
                # Only a 200 counts as a retry success; a non-retryable
                # failure simply ends the loop.
                if succeeded and self.config.enable_structured_logging:
                    self._log_structured_retry_success(request, attempt + 1)
                return last_result

        # Every retry failed in a retryable way.
        if self.config.enable_structured_logging:
            self._log_structured_retry_failed(request)
        return last_result

    def _base_log_entry(self, event: str, request: EndpointCheckRequest) -> Dict[str, Any]:
        """Return the fields shared by every structured log entry."""
        return {
            "event": event,
            "timestamp": time.time(),
            "request_id": request.request_id,
            "provider": request.provider_name,
            "model": request.model_name,
        }

    def _log_structured_start(self, request: EndpointCheckRequest) -> None:
        """Log the structured "check started" entry (info level)."""
        log_entry = self._base_log_entry("endpoint_check_start", request)
        log_entry["url"] = request.url
        log_entry["config"] = {
            "timeout": self.config.timeout,
            "max_retries": self.config.max_retries,
            "enable_batch_recording": self.config.enable_batch_recording,
            "enable_usage_calculation": self.config.enable_usage_calculation,
        }
        logger.info(f"[{request.api_format}] {json.dumps(log_entry)}")

    def _log_structured_result(self, request: EndpointCheckRequest, result: EndpointCheckResult) -> None:
        """Log the structured "check complete" entry (info level)."""
        log_entry = self._base_log_entry("endpoint_check_complete", request)
        log_entry["status_code"] = result.status_code
        log_entry["response_time_ms"] = result.response_time_ms
        log_entry["error_message"] = result.error_message
        log_entry["has_usage_data"] = result.usage_data is not None

        if result.response_data and "error_type" in result.response_data:
            log_entry["error_type"] = result.response_data["error_type"]
            log_entry["retryable"] = result.response_data.get("retryable", False)

        logger.info(f"[{request.api_format}] {json.dumps(log_entry)}")

    def _log_structured_retry(self, request: EndpointCheckRequest, attempt: int, last_result: EndpointCheckResult) -> None:
        """Log a structured retry entry (warning level); ``attempt`` is 1-based."""
        log_entry = self._base_log_entry("endpoint_check_retry", request)
        log_entry["attempt"] = attempt
        log_entry["last_status_code"] = last_result.status_code
        log_entry["last_error"] = last_result.error_message
        log_entry["retry_delay"] = self.config.retry_delay * (2 ** (attempt - 1))
        logger.warning(f"[{request.api_format}] {json.dumps(log_entry)}")

    def _log_structured_retry_success(self, request: EndpointCheckRequest, attempt: int) -> None:
        """Log the structured "retry succeeded" entry (info level)."""
        log_entry = self._base_log_entry("endpoint_check_retry_success", request)
        log_entry["attempts"] = attempt
        logger.info(f"[{request.api_format}] {json.dumps(log_entry)}")

    def _log_structured_retry_failed(self, request: EndpointCheckRequest) -> None:
        """Log the structured "all retries failed" entry (error level)."""
        log_entry = self._base_log_entry("endpoint_check_retry_failed", request)
        log_entry["max_attempts"] = self.config.max_retries
        logger.error(f"[{request.api_format}] {json.dumps(log_entry)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Module-wide configured checker instance; starts out unset.
_global_configured_checker: Optional[ConfigurableEndpointChecker] = None
|
|||
|
|
|
|||
|
|
def get_configured_checker(config: Optional[EndpointCheckConfig] = None) -> ConfigurableEndpointChecker:
    """Return the shared ``ConfigurableEndpointChecker``.

    A new checker is built (and replaces the shared one) when none exists yet
    or whenever an explicit ``config`` is passed; without a config the
    settings come from ``EndpointCheckConfig.from_env()``.
    """
    global _global_configured_checker
    must_build = config is not None or _global_configured_checker is None
    if must_build:
        effective_config = config if config else EndpointCheckConfig.from_env()
        _global_configured_checker = ConfigurableEndpointChecker(effective_config)
    return _global_configured_checker
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
class EndpointCheckOrchestrator:
|
|||
|
|
"""端点检查协调器 - 协调整个流程"""
|
|||
|
|
|
|||
|
|
def __init__(self, executor: Optional[HttpRequestExecutor] = None,
             usage_calculator: Optional[UsageCalculator] = None):
    """Store the collaborators, creating default instances for any that are omitted."""
    self.executor = executor if executor else HttpRequestExecutor()
    self.usage_calculator = usage_calculator if usage_calculator else UsageCalculator()
|
|||
|
|
|
|||
|
|
async def execute_check(self, request: EndpointCheckRequest) -> EndpointCheckResult:
    """Run the full endpoint-check flow: HTTP request, then usage recording.

    Usage is calculated and recorded only when the request carries both a
    database handle and a user. A failure while calculating usage is logged
    and swallowed so that the HTTP result is still returned.

    Args:
        request: The endpoint check to execute.

    Returns:
        The HTTP result; ``usage_data`` is set when usage recording succeeded.
    """
    logger.info(f"[{request.api_format}] Starting endpoint check | "
                f"provider={request.provider_name}, model={request.model_name}")

    # 1. Execute the HTTP request.
    result = await self.executor.execute(request)

    # 2. Usage needs both a database connection and user information.
    if not (request.db and request.user):
        return result

    try:
        input_tokens, output_tokens, cache_creation_input_tokens, cache_read_input_tokens = \
            self.usage_calculator.calculate_tokens(request, result)

        result.usage_data = await _calculate_and_record_usage(
            db=request.db,
            user=request.user,
            provider_name=request.provider_name or "unknown",
            provider_id=request.provider_id or "unknown",
            api_key_id=request.api_key_id or "unknown",
            model_name=request.model_name or "unknown",
            request_data=request.json_body,
            response_data=result.response_data,
            request_id=result.request_id,
            response_time_ms=result.response_time_ms,
            request_headers=request.headers,
            response_headers=result.headers,
            status_code=result.status_code,
            error_message=result.error_message,
            # Pass the already-computed token counts straight through.
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            api_format=request.api_format,
        )

        # Log the value that was just stored. Referencing an undefined
        # ``usage_data`` name here would raise NameError and make every
        # successful calculation show up as a failure below.
        logger.info(f"[{request.api_format}] Usage calculated successfully: {result.usage_data}")
    except Exception as e:
        # Best effort: usage recording must never fail the check itself.
        logger.error(f"[{request.api_format}] Failed to calculate usage: {e}")
        import traceback
        logger.error(f"[{request.api_format}] Usage calculation traceback: {traceback.format_exc()}")

    return result
|