mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 10:12:27 +08:00
229 lines
8.4 KiB
Python
229 lines
8.4 KiB
Python
|
|
"""
|
|||
|
|
Claude Chat Adapter - 基于 ChatAdapterBase 的 Claude Chat API 适配器
|
|||
|
|
|
|||
|
|
处理 /v1/messages 端点的 Claude Chat 格式请求。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from typing import Any, Dict, Optional, Type
|
|||
|
|
|
|||
|
|
from fastapi import HTTPException, Request
|
|||
|
|
from fastapi.responses import JSONResponse
|
|||
|
|
|
|||
|
|
from src.api.base.adapter import ApiAdapter, ApiMode
|
|||
|
|
from src.api.base.context import ApiRequestContext
|
|||
|
|
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
|||
|
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
|||
|
|
from src.core.logger import logger
|
|||
|
|
from src.core.optimization_utils import TokenCounter
|
|||
|
|
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ClaudeCapabilityDetector:
    """Detects capability requirements implied by a Claude API request."""

    @staticmethod
    def detect_from_headers(
        headers: Dict[str, str],
        request_body: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, bool]:
        """
        Derive capability requirements from Claude request headers.

        Detection rules:
        - anthropic-beta contains "context-1m" (e.g. context-1m-xxx) -> context_1m: True

        Args:
            headers: Request header mapping. Header names and the beta value are
                compared case-insensitively. ``None`` or an empty mapping
                yields no requirements.
            request_body: Request body. Not used for Claude; accepted only to
                keep the detector signature uniform.

        Returns:
            A dict containing only the requirements that were detected; it is
            empty when nothing was detected.
        """
        requirements: Dict[str, bool] = {}
        if not headers:
            return requirements

        # Inspect every case-variant of the header instead of stopping at the
        # first one, so an empty first variant cannot hide a later value.
        for key, value in headers.items():
            if not isinstance(key, str) or key.lower() != "anthropic-beta":
                continue
            # Non-string values cannot carry the marker; skip them rather than crash.
            if isinstance(value, str) and "context-1m" in value.lower():
                requirements["context_1m"] = True
                break

        return requirements
|
|||
|
|
|
|||
|
|
|
|||
|
|
@register_adapter
class ClaudeChatAdapter(ChatAdapterBase):
    """
    Claude Chat API adapter.

    Handles Claude Chat format requests (the /v1/messages endpoint), including
    validation of the request format.
    """

    FORMAT_ID = "CLAUDE"
    name = "claude.chat"

    @property
    def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
        """Return the handler class; imported lazily to avoid a circular dependency."""
        from src.api.handlers.claude.handler import ClaudeChatHandler

        return ClaudeChatHandler

    def __init__(self, allowed_api_formats: Optional[list[str]] = None):
        """
        Initialise the adapter.

        Args:
            allowed_api_formats: API formats this adapter accepts; defaults to
                ``["CLAUDE"]`` when omitted or empty.
        """
        super().__init__(allowed_api_formats or ["CLAUDE"])
        logger.info(f"[{self.name}] 初始化Chat模式适配器 | API格式: {self.allowed_api_formats}")

    def extract_api_key(self, request: Request) -> Optional[str]:
        """Extract the API key from the request's ``x-api-key`` header."""
        return request.headers.get("x-api-key")

    def detect_capability_requirements(
        self,
        headers: Dict[str, str],
        request_body: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, bool]:
        """
        Detect capability requirements implied by a Claude request.

        Only the headers are inspected; ``request_body`` is accepted but not used.
        """
        return ClaudeCapabilityDetector.detect_from_headers(headers)

    # =========================================================================
    # Claude-specific billing logic
    # =========================================================================

    def compute_total_input_context(
        self,
        input_tokens: int,
        cache_read_input_tokens: int,
        cache_creation_input_tokens: int = 0,
    ) -> int:
        """
        Compute Claude's total input context (used to pick the tiered-billing bracket).

        Claude total input = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
        """
        return input_tokens + cache_creation_input_tokens + cache_read_input_tokens

    def _validate_request_body(
        self, original_request_body: dict, path_params: Optional[dict] = None
    ):
        """
        Validate the request body and return it as a ``ClaudeMessagesRequest``.

        Args:
            original_request_body: Decoded JSON request body.
            path_params: Not used by this implementation.

        Returns:
            The validated request, or an unvalidated request built from the
            raw ``model`` / ``max_tokens`` / ``messages`` / ``stream`` values
            when validation fails with a non-``ValueError`` exception.

        Raises:
            HTTPException: 400 when the body is not a JSON object, a required
                field is missing, or validation raises a ``ValueError``.
        """
        try:
            if not isinstance(original_request_body, dict):
                raise ValueError("Request body must be a JSON object")

            required_fields = ["model", "messages", "max_tokens"]
            missing_fields = [f for f in required_fields if f not in original_request_body]
            if missing_fields:
                raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")

            request = ClaudeMessagesRequest.model_validate(
                original_request_body,
                strict=False,
            )
        except ValueError as e:
            # NOTE(review): under Pydantic v2, ValidationError subclasses
            # ValueError, so schema failures from model_validate end up here
            # (HTTP 400) rather than in the lenient fallback below - confirm
            # that this is the intended behaviour.
            logger.error(f"请求体基本验证失败: {str(e)}")
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            # Best-effort fallback: keep processing with an unvalidated model.
            # model_construct performs no validation, so `messages` stays
            # whatever the client sent (typically a list of plain dicts).
            logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
            request = ClaudeMessagesRequest.model_construct(
                model=original_request_body.get("model"),
                max_tokens=original_request_body.get("max_tokens"),
                messages=original_request_body.get("messages", []),
                stream=original_request_body.get("stream", False),
            )
        return request

    def _build_audit_metadata(self, _payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
        """
        Build the Claude Chat specific audit metadata.

        Tolerates requests produced by the unvalidated fallback of
        ``_validate_request_body``: messages may be plain dicts or ``None``,
        and optional fields may be absent from the object.

        Args:
            _payload: Raw request payload; unused.
            request_obj: The request object returned by ``_validate_request_body``.

        Returns:
            A flat dict of audit fields describing the request.
        """

        def optional(field: str) -> Any:
            # Optional fields may be missing on an unvalidated request object.
            return getattr(request_obj, field, None)

        messages = optional("messages") or []

        role_counts: dict[str, int] = {}
        for message in messages:
            # Validated messages expose `.role`; unvalidated ones are raw dicts.
            if isinstance(message, dict):
                role = message.get("role")
            else:
                role = getattr(message, "role", None)
            if role is None:
                role = "unknown"
            role_counts[role] = role_counts.get(role, 0) + 1

        return {
            "action": "claude_messages",
            "model": request_obj.model,
            "stream": bool(request_obj.stream),
            "max_tokens": request_obj.max_tokens,
            "temperature": optional("temperature"),
            "top_p": optional("top_p"),
            "top_k": optional("top_k"),
            "messages_count": len(messages),
            "message_roles": role_counts,
            "stop_sequences": len(optional("stop_sequences") or []),
            "tools_count": len(optional("tools") or []),
            "system_present": bool(optional("system")),
            "metadata_present": bool(optional("metadata")),
            "thinking_enabled": bool(optional("thinking")),
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_claude_adapter(x_app_header: Optional[str]):
    """
    Build the Chat or the Claude Code adapter according to the x-app header.

    Args:
        x_app_header: Value of the ``x-app`` header, or ``None`` when absent.

    Returns:
        A ``ClaudeCliAdapter`` when the header equals "cli" (case-insensitive),
        otherwise a ``ClaudeChatAdapter``.
    """
    wants_cli = bool(x_app_header) and x_app_header.lower() == "cli"
    if not wants_cli:
        return ClaudeChatAdapter()

    # Imported only on the CLI path (presumably to avoid a circular import - confirm).
    from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter

    return ClaudeCliAdapter()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ClaudeTokenCountAdapter(ApiAdapter):
    """Lightweight adapter that computes the token count of a Claude request."""

    name = "claude.token_count"
    mode = ApiMode.STANDARD

    def extract_api_key(self, request: Request) -> Optional[str]:
        """
        Extract the API key from the request.

        ``x-api-key`` takes precedence; otherwise the credential of an
        ``Authorization: Bearer <token>`` header is used.

        Returns:
            The key, or ``None`` when neither header supplies one.
        """
        # Prefer x-api-key.
        api_key = request.headers.get("x-api-key")
        if api_key:
            return api_key

        # Fall back to Authorization: Bearer.
        authorization = request.headers.get("authorization")
        if authorization and authorization.startswith("Bearer "):
            # Strip only the leading scheme; str.replace would also delete any
            # later "Bearer " occurrence inside the credential itself.
            return authorization.removeprefix("Bearer ")
        return None

    async def handle(self, context: ApiRequestContext):
        """
        Count the input tokens of a Claude request and return them as JSON.

        The total is the system prompt tokens (a string, or the ``text`` of
        each block in a list; blocks without a ``text`` attribute are skipped)
        plus the tokens of all messages.

        Args:
            context: Request context providing the JSON body and receiving the
                audit metadata.

        Returns:
            A ``JSONResponse`` of the form ``{"input_tokens": <int>}``.

        Raises:
            HTTPException: 400 when the payload fails ``ClaudeTokenCountRequest``
                validation.
        """
        payload = context.ensure_json_body()

        try:
            request = ClaudeTokenCountRequest.model_validate(payload, strict=False)
        except Exception as e:
            logger.error(f"Token count payload invalid: {e}")
            raise HTTPException(status_code=400, detail="Invalid token count payload") from e

        token_counter = TokenCounter()
        total_tokens = 0

        if request.system:
            if isinstance(request.system, str):
                total_tokens += token_counter.count_tokens(request.system, request.model)
            elif isinstance(request.system, list):
                for block in request.system:
                    if hasattr(block, "text"):
                        total_tokens += token_counter.count_tokens(block.text, request.model)

        # The message counter is fed plain dicts, so dump model instances first.
        messages_dict = [
            msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in request.messages
        ]
        total_tokens += token_counter.count_messages_tokens(messages_dict, request.model)

        context.add_audit_metadata(
            action="claude_token_count",
            model=request.model,
            messages_count=len(request.messages),
            system_present=bool(request.system),
            tools_count=len(request.tools or []),
            thinking_enabled=bool(request.thinking),
            input_tokens=total_tokens,
        )

        return JSONResponse({"input_tokens": total_tokens})
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Names exported by `from <this module> import *`.
__all__ = ["ClaudeChatAdapter", "ClaudeTokenCountAdapter", "build_claude_adapter"]
|