"""
Claude Chat Adapter - Claude Chat API adapter built on ChatAdapterBase.

Handles Claude Chat format requests sent to the /v1/messages endpoint.
"""

from typing import Any, Dict, Optional, Type

from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse

from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.core.optimization_utils import TokenCounter
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest


def _field(obj: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` from either a model-style object or a plain dict.

    Request objects built via ``model_construct`` (see
    ``ClaudeChatAdapter._validate_request_body``) keep nested values as raw
    dicts, so readers must tolerate both shapes.
    """
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


class ClaudeCapabilityDetector:
    """Detect capability requirements implied by a Claude API request."""

    @staticmethod
    def detect_from_headers(
        headers: Dict[str, str],
        request_body: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, bool]:
        """
        Detect capability requirements from Claude request headers.

        Rules:
        - an ``anthropic-beta`` header containing ``context-1m`` -> ``context_1m: True``

        Args:
            headers: Request headers; the ``anthropic-beta`` lookup is case-insensitive.
            request_body: Request body (not used for Claude; kept so the signature
                matches other detectors).

        Returns:
            A dict of capability flags; empty when nothing special is required.
        """
        requirements: Dict[str, bool] = {}

        # Header names are case-insensitive, so scan instead of indexing directly.
        beta_header = None
        for key, value in headers.items():
            if key.lower() == "anthropic-beta":
                beta_header = value
                break

        if beta_header and "context-1m" in beta_header.lower():
            requirements["context_1m"] = True

        return requirements


@register_adapter
class ClaudeChatAdapter(ChatAdapterBase):
    """
    Claude Chat API adapter.

    Handles Claude Chat format requests (/v1/messages endpoint) and validates
    the incoming request body.
    """

    FORMAT_ID = "CLAUDE"
    name = "claude.chat"

    @property
    def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
        """Return the handler class; imported lazily to avoid a circular import."""
        from src.api.handlers.claude.handler import ClaudeChatHandler

        return ClaudeChatHandler

    def __init__(self, allowed_api_formats: Optional[list[str]] = None):
        # A missing or empty list falls back to the Claude format only.
        super().__init__(allowed_api_formats or ["CLAUDE"])
        logger.info(f"[{self.name}] 初始化Chat模式适配器 | API格式: {self.allowed_api_formats}")

    def extract_api_key(self, request: Request) -> Optional[str]:
        """Extract the API key from the ``x-api-key`` header."""
        return request.headers.get("x-api-key")

    def detect_capability_requirements(
        self,
        headers: Dict[str, str],
        request_body: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, bool]:
        """Detect capability requirements implied by a Claude request."""
        return ClaudeCapabilityDetector.detect_from_headers(headers, request_body)

    # =========================================================================
    # Claude-specific billing logic
    # =========================================================================

    def compute_total_input_context(
        self,
        input_tokens: int,
        cache_read_input_tokens: int,
        cache_creation_input_tokens: int = 0,
    ) -> int:
        """
        Compute Claude's total input context (used to select a tiered price).

        Claude total input = input_tokens + cache_creation_input_tokens
        + cache_read_input_tokens.
        """
        return input_tokens + cache_creation_input_tokens + cache_read_input_tokens

    def _validate_request_body(
        self, original_request_body: dict, path_params: Optional[dict] = None
    ):
        """
        Validate the raw request body and return a ``ClaudeMessagesRequest``.

        Args:
            original_request_body: Parsed JSON body of the request.
            path_params: Unused here; accepted for interface compatibility.

        Returns:
            A validated request or, if validation fails with a non-``ValueError``
            exception, an unvalidated ``model_construct`` instance whose nested
            values (e.g. ``messages``) are still raw dicts.

        Raises:
            HTTPException: 400 when the body is not an object, lacks a required
                field, or fails model validation with a ``ValueError``.
        """
        try:
            if not isinstance(original_request_body, dict):
                raise ValueError("Request body must be a JSON object")

            required_fields = ["model", "messages", "max_tokens"]
            missing_fields = [f for f in required_fields if f not in original_request_body]
            if missing_fields:
                raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

            request = ClaudeMessagesRequest.model_validate(
                original_request_body,
                strict=False,
            )
        except ValueError as e:
            # NOTE: pydantic's ValidationError subclasses ValueError, so schema
            # violations from model_validate are also reported as 400 here.
            logger.error(f"请求体基本验证失败: {str(e)}")
            raise HTTPException(status_code=400, detail=str(e)) from e
        except Exception as e:
            # Any other failure is tolerated: build an unvalidated request from
            # the raw fields so processing can continue.
            logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
            request = ClaudeMessagesRequest.model_construct(
                model=original_request_body.get("model"),
                max_tokens=original_request_body.get("max_tokens"),
                messages=original_request_body.get("messages", []),
                stream=original_request_body.get("stream", False),
            )

        return request

    def _build_audit_metadata(self, _payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
        """
        Build Claude Chat-specific audit metadata.

        Works for both validated requests and the ``model_construct`` fallback
        (where messages are raw dicts and optional fields may be unset).
        """
        messages = _field(request_obj, "messages") or []

        role_counts: Dict[Any, int] = {}
        for message in messages:
            role = _field(message, "role", "unknown")
            role_counts[role] = role_counts.get(role, 0) + 1

        return {
            "action": "claude_messages",
            "model": _field(request_obj, "model"),
            "stream": bool(_field(request_obj, "stream")),
            "max_tokens": _field(request_obj, "max_tokens"),
            "temperature": _field(request_obj, "temperature"),
            "top_p": _field(request_obj, "top_p"),
            "top_k": _field(request_obj, "top_k"),
            "messages_count": len(messages),
            "message_roles": role_counts,
            "stop_sequences": len(_field(request_obj, "stop_sequences") or []),
            "tools_count": len(_field(request_obj, "tools") or []),
            "system_present": bool(_field(request_obj, "system")),
            "metadata_present": bool(_field(request_obj, "metadata")),
            "thinking_enabled": bool(_field(request_obj, "thinking")),
        }


def build_claude_adapter(x_app_header: Optional[str]):
    """Build a Chat or Claude Code (CLI) adapter based on the ``x-app`` header."""
    if x_app_header and x_app_header.lower() == "cli":
        from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter

        return ClaudeCliAdapter()

    return ClaudeChatAdapter()


class ClaudeTokenCountAdapter(ApiAdapter):
    """Lightweight adapter that counts input tokens for a Claude request."""

    name = "claude.token_count"
    mode = ApiMode.STANDARD

    def extract_api_key(self, request: Request) -> Optional[str]:
        """
        Extract the API key from ``x-api-key`` or ``Authorization: Bearer <key>``.

        ``x-api-key`` takes precedence. For the Authorization header only the
        leading scheme is stripped (the scheme is matched case-insensitively),
        so a token that itself contains "Bearer " is not mangled.
        """
        api_key = request.headers.get("x-api-key")
        if api_key:
            return api_key

        authorization = request.headers.get("authorization")
        if authorization:
            scheme, _, token = authorization.strip().partition(" ")
            if scheme.lower() == "bearer":
                token = token.strip()
                return token or None

        return None

    async def handle(self, context: ApiRequestContext):
        """Count the request's input tokens and return ``{"input_tokens": N}``.

        Raises:
            HTTPException: 400 when the payload is not a valid token-count request.
        """
        payload = context.ensure_json_body()

        try:
            request = ClaudeTokenCountRequest.model_validate(payload, strict=False)
        except Exception as e:
            logger.error(f"Token count payload invalid: {e}")
            raise HTTPException(status_code=400, detail="Invalid token count payload") from e

        token_counter = TokenCounter()
        total_tokens = 0

        if request.system:
            if isinstance(request.system, str):
                total_tokens += token_counter.count_tokens(request.system, request.model)
            elif isinstance(request.system, list):
                for block in request.system:
                    # System blocks may be model objects or plain dicts.
                    text = _field(block, "text")
                    if isinstance(text, str) and text:
                        total_tokens += token_counter.count_tokens(text, request.model)

        messages_dict = [
            msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in request.messages
        ]
        total_tokens += token_counter.count_messages_tokens(messages_dict, request.model)

        context.add_audit_metadata(
            action="claude_token_count",
            model=request.model,
            messages_count=len(request.messages),
            system_present=bool(request.system),
            tools_count=len(request.tools or []),
            thinking_enabled=bool(request.thinking),
            input_tokens=total_tokens,
        )

        return JSONResponse({"input_tokens": total_tokens})


__all__ = [
    "ClaudeChatAdapter",
    "ClaudeTokenCountAdapter",
    "build_claude_adapter",
]