Files
Aether/src/api/handlers/claude/adapter.py
2025-12-10 20:52:44 +08:00

229 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Claude Chat Adapter - 基于 ChatAdapterBase 的 Claude Chat API 适配器
处理 /v1/messages 端点的 Claude Chat 格式请求。
"""
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.core.optimization_utils import TokenCounter
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest
class ClaudeCapabilityDetector:
    """Detects capability requirements implied by a Claude API request."""

    @staticmethod
    def detect_from_headers(
        headers: Dict[str, str],
        request_body: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, bool]:
        """
        Detect capability requirements from Claude request headers.

        Detection rules:
        - anthropic-beta: context-1m-xxx -> context_1m: True

        Args:
            headers: Mapping of request header names to values. A missing or
                empty mapping yields no requirements.
            request_body: Request body (not used for Claude; kept so that all
                detectors share the same signature).

        Returns:
            A dict containing only the capabilities that were detected; empty
            when nothing was detected.
        """
        requirements: Dict[str, bool] = {}

        # Guard against a missing/empty header mapping instead of failing on
        # ``None.items()``.
        if not headers:
            return requirements

        # Header names are matched case-insensitively; the first matching
        # entry wins.
        beta_header = next(
            (value for key, value in headers.items() if key.lower() == "anthropic-beta"),
            None,
        )

        # The header value is searched as a whole for the context-1m marker.
        if beta_header and "context-1m" in beta_header.lower():
            requirements["context_1m"] = True

        return requirements
@register_adapter
class ClaudeChatAdapter(ChatAdapterBase):
    """
    Adapter for the Claude Chat API.

    Handles Claude Chat formatted requests (the /v1/messages endpoint) and
    validates the request format.
    """

    FORMAT_ID = "CLAUDE"
    name = "claude.chat"

    @property
    def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
        """Return the handler class, imported lazily to avoid a circular import."""
        from src.api.handlers.claude.handler import ClaudeChatHandler

        return ClaudeChatHandler

    def __init__(self, allowed_api_formats: Optional[list[str]] = None):
        """
        Initialise the adapter.

        Args:
            allowed_api_formats: API formats this adapter accepts; defaults to
                ``["CLAUDE"]`` when omitted or empty.
        """
        super().__init__(allowed_api_formats or ["CLAUDE"])
        logger.info(f"[{self.name}] 初始化Chat模式适配器 | API格式: {self.allowed_api_formats}")

    def extract_api_key(self, request: Request) -> Optional[str]:
        """Extract the API key from the request's ``x-api-key`` header."""
        return request.headers.get("x-api-key")

    def detect_capability_requirements(
        self,
        headers: Dict[str, str],
        request_body: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, bool]:
        """
        Detect the capability requirements implied by a Claude request.

        Args:
            headers: Request headers.
            request_body: Request body; forwarded to the detector so the call
                matches its full signature.

        Returns:
            The dict produced by ``ClaudeCapabilityDetector.detect_from_headers``.
        """
        return ClaudeCapabilityDetector.detect_from_headers(headers, request_body)

    # =========================================================================
    # Claude-specific billing logic
    # =========================================================================
    def compute_total_input_context(
        self,
        input_tokens: int,
        cache_read_input_tokens: int,
        cache_creation_input_tokens: int = 0,
    ) -> int:
        """
        Compute Claude's total input context (used to decide the tiered billing bracket).

        Claude's total input =
            input_tokens + cache_creation_input_tokens + cache_read_input_tokens
        """
        return input_tokens + cache_creation_input_tokens + cache_read_input_tokens

    def _validate_request_body(
        self, original_request_body: dict, path_params: Optional[dict] = None
    ):
        """
        Validate the request body and build the request model.

        Args:
            original_request_body: Parsed JSON body of the request.
            path_params: Accepted for signature compatibility; not used here.

        Returns:
            A ``ClaudeMessagesRequest`` instance.

        Raises:
            HTTPException: 400 when the body is not a JSON object, a required
                field is missing, or validation raises a ``ValueError``.
        """
        try:
            if not isinstance(original_request_body, dict):
                raise ValueError("Request body must be a JSON object")

            required_fields = ["model", "messages", "max_tokens"]
            missing_fields = [f for f in required_fields if f not in original_request_body]
            if missing_fields:
                raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

            request = ClaudeMessagesRequest.model_validate(
                original_request_body,
                strict=False,
            )
        except ValueError as e:
            # NOTE(review): Pydantic's ValidationError subclasses ValueError, so
            # schema failures presumably land here (400) rather than in the
            # lenient fallback below - confirm that this is the intent.
            logger.error(f"请求体基本验证失败: {str(e)}")
            # Chain the cause so the original validation error stays visible in
            # tracebacks.
            raise HTTPException(status_code=400, detail=str(e)) from e
        except Exception as e:
            # Any other failure is tolerated: build the model without
            # validation from the core fields only and keep processing.
            logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
            request = ClaudeMessagesRequest.model_construct(
                model=original_request_body.get("model"),
                max_tokens=original_request_body.get("max_tokens"),
                messages=original_request_body.get("messages", []),
                stream=original_request_body.get("stream", False),
            )
        return request

    def _build_audit_metadata(self, _payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
        """
        Build the Claude Chat specific audit metadata.

        Args:
            _payload: Raw request payload (unused).
            request_obj: The request object returned by ``_validate_request_body``.

        Returns:
            A dict summarising the request: model, sampling parameters, message
            and tool counts, and presence flags.
        """
        role_counts: dict[str, int] = {}
        for message in request_obj.messages:
            # Messages are normally model instances, but the lenient fallback in
            # _validate_request_body stores the raw dicts from the request body.
            if isinstance(message, dict):
                role = message.get("role")
            else:
                role = message.role
            if role is None:
                continue
            role_counts[role] = role_counts.get(role, 0) + 1

        return {
            "action": "claude_messages",
            "model": request_obj.model,
            "stream": bool(request_obj.stream),
            "max_tokens": request_obj.max_tokens,
            "temperature": getattr(request_obj, "temperature", None),
            "top_p": getattr(request_obj, "top_p", None),
            "top_k": getattr(request_obj, "top_k", None),
            "messages_count": len(request_obj.messages),
            "message_roles": role_counts,
            "stop_sequences": len(request_obj.stop_sequences or []),
            "tools_count": len(request_obj.tools or []),
            "system_present": bool(request_obj.system),
            "metadata_present": bool(request_obj.metadata),
            "thinking_enabled": bool(request_obj.thinking),
        }
def build_claude_adapter(x_app_header: Optional[str]):
    """
    Build the Chat adapter or the Claude Code (CLI) adapter from the x-app header.

    Args:
        x_app_header: Value of the ``x-app`` header, if any.

    Returns:
        A ``ClaudeCliAdapter`` when the header equals ``"cli"``
        (case-insensitive); otherwise a ``ClaudeChatAdapter``.
    """
    normalized_app = x_app_header.lower() if x_app_header else None
    if normalized_app != "cli":
        return ClaudeChatAdapter()

    # Imported only on the CLI path, exactly where it is needed.
    from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter

    return ClaudeCliAdapter()
class ClaudeTokenCountAdapter(ApiAdapter):
    """Lightweight adapter that computes the token count of a Claude request."""

    name = "claude.token_count"
    mode = ApiMode.STANDARD

    def extract_api_key(self, request: Request) -> Optional[str]:
        """
        Extract the API key from the request.

        ``x-api-key`` takes precedence; otherwise the credential of an
        ``Authorization: Bearer <key>`` header is used.

        Returns:
            The API key, or ``None`` when neither header supplies one.
        """
        # Prefer x-api-key.
        api_key = request.headers.get("x-api-key")
        if api_key:
            return api_key

        # Fall back to Authorization: Bearer.
        authorization = request.headers.get("authorization")
        if authorization and authorization.startswith("Bearer "):
            # Strip only the leading scheme; str.replace would also remove any
            # later "Bearer " occurrence inside the credential itself.
            return authorization.removeprefix("Bearer ")
        return None

    @staticmethod
    def _system_block_text(block: Any) -> Optional[str]:
        """Return the ``text`` of a system block given as a dict or an object, if any."""
        if isinstance(block, dict):
            return block.get("text")
        return getattr(block, "text", None)

    async def handle(self, context: ApiRequestContext):
        """
        Count the input tokens of a token-count request.

        Validates the JSON body as a ``ClaudeTokenCountRequest``, sums the
        token counts of the system prompt and of the messages, records audit
        metadata on the context, and returns the total.

        Returns:
            A ``JSONResponse`` of the form ``{"input_tokens": <total>}``.

        Raises:
            HTTPException: 400 when the payload fails validation.
        """
        payload = context.ensure_json_body()
        try:
            request = ClaudeTokenCountRequest.model_validate(payload, strict=False)
        except Exception as e:
            logger.error(f"Token count payload invalid: {e}")
            raise HTTPException(status_code=400, detail="Invalid token count payload") from e

        token_counter = TokenCounter()
        total_tokens = 0

        # The system prompt may be a plain string or a list of blocks.
        if request.system:
            if isinstance(request.system, str):
                total_tokens += token_counter.count_tokens(request.system, request.model)
            elif isinstance(request.system, list):
                for block in request.system:
                    # Accept both model objects and raw dicts; blocks without
                    # text contribute nothing.
                    text = self._system_block_text(block)
                    if text is not None:
                        total_tokens += token_counter.count_tokens(text, request.model)

        # Messages are passed to the counter as plain values; model instances
        # are dumped first, anything else is passed through unchanged.
        messages_dict = [
            msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in request.messages
        ]
        total_tokens += token_counter.count_messages_tokens(messages_dict, request.model)

        context.add_audit_metadata(
            action="claude_token_count",
            model=request.model,
            messages_count=len(request.messages),
            system_present=bool(request.system),
            tools_count=len(request.tools or []),
            thinking_enabled=bool(request.thinking),
            input_tokens=total_tokens,
        )
        return JSONResponse({"input_tokens": total_tokens})
# Names exported by ``from <module> import *``.
__all__ = [
    "ClaudeChatAdapter",
    "ClaudeTokenCountAdapter",
    "build_claude_adapter",
]