mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 18:52:28 +08:00
Initial commit
This commit is contained in:
228
src/api/handlers/claude/adapter.py
Normal file
228
src/api/handlers/claude/adapter.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Claude Chat Adapter - 基于 ChatAdapterBase 的 Claude Chat API 适配器
|
||||
|
||||
处理 /v1/messages 端点的 Claude Chat 格式请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.logger import logger
|
||||
from src.core.optimization_utils import TokenCounter
|
||||
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest
|
||||
|
||||
|
||||
class ClaudeCapabilityDetector:
|
||||
"""Claude API 能力检测器"""
|
||||
|
||||
@staticmethod
|
||||
def detect_from_headers(
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
从 Claude 请求头检测能力需求
|
||||
|
||||
检测规则:
|
||||
- anthropic-beta: context-1m-xxx -> context_1m: True
|
||||
|
||||
Args:
|
||||
headers: 请求头字典
|
||||
request_body: 请求体(Claude 不使用,保留用于接口统一)
|
||||
"""
|
||||
requirements: Dict[str, bool] = {}
|
||||
|
||||
# 检查 anthropic-beta 请求头(大小写不敏感)
|
||||
beta_header = None
|
||||
for key, value in headers.items():
|
||||
if key.lower() == "anthropic-beta":
|
||||
beta_header = value
|
||||
break
|
||||
|
||||
if beta_header:
|
||||
# 检查是否包含 context-1m 标识
|
||||
if "context-1m" in beta_header.lower():
|
||||
requirements["context_1m"] = True
|
||||
|
||||
return requirements
|
||||
|
||||
|
||||
@register_adapter
|
||||
class ClaudeChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
Claude Chat API 适配器
|
||||
|
||||
处理 Claude Chat 格式的请求(/v1/messages 端点,进行格式验证)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE"
|
||||
name = "claude.chat"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.claude.handler import ClaudeChatHandler
|
||||
|
||||
return ClaudeChatHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["CLAUDE"])
|
||||
logger.info(f"[{self.name}] 初始化Chat模式适配器 | API格式: {self.allowed_api_formats}")
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-api-key)"""
|
||||
return request.headers.get("x-api-key")
|
||||
|
||||
def detect_capability_requirements(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""检测 Claude 请求中隐含的能力需求"""
|
||||
return ClaudeCapabilityDetector.detect_from_headers(headers)
|
||||
|
||||
# =========================================================================
|
||||
# Claude 特定的计费逻辑
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算 Claude 的总输入上下文(用于阶梯计费判定)
|
||||
|
||||
Claude 的总输入 = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
"""
|
||||
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""验证请求体"""
|
||||
try:
|
||||
if not isinstance(original_request_body, dict):
|
||||
raise ValueError("Request body must be a JSON object")
|
||||
|
||||
required_fields = ["model", "messages", "max_tokens"]
|
||||
missing_fields = [f for f in required_fields if f not in original_request_body]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
|
||||
|
||||
request = ClaudeMessagesRequest.model_validate(
|
||||
original_request_body,
|
||||
strict=False,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"请求体基本验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
|
||||
request = ClaudeMessagesRequest.model_construct(
|
||||
model=original_request_body.get("model"),
|
||||
max_tokens=original_request_body.get("max_tokens"),
|
||||
messages=original_request_body.get("messages", []),
|
||||
stream=original_request_body.get("stream", False),
|
||||
)
|
||||
return request
|
||||
|
||||
def _build_audit_metadata(self, _payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""构建 Claude Chat 特定的审计元数据"""
|
||||
role_counts: dict[str, int] = {}
|
||||
for message in request_obj.messages:
|
||||
role_counts[message.role] = role_counts.get(message.role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "claude_messages",
|
||||
"model": request_obj.model,
|
||||
"stream": bool(request_obj.stream),
|
||||
"max_tokens": request_obj.max_tokens,
|
||||
"temperature": getattr(request_obj, "temperature", None),
|
||||
"top_p": getattr(request_obj, "top_p", None),
|
||||
"top_k": getattr(request_obj, "top_k", None),
|
||||
"messages_count": len(request_obj.messages),
|
||||
"message_roles": role_counts,
|
||||
"stop_sequences": len(request_obj.stop_sequences or []),
|
||||
"tools_count": len(request_obj.tools or []),
|
||||
"system_present": bool(request_obj.system),
|
||||
"metadata_present": bool(request_obj.metadata),
|
||||
"thinking_enabled": bool(request_obj.thinking),
|
||||
}
|
||||
|
||||
|
||||
def build_claude_adapter(x_app_header: Optional[str]):
|
||||
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
||||
if x_app_header and x_app_header.lower() == "cli":
|
||||
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
|
||||
|
||||
return ClaudeCliAdapter()
|
||||
return ClaudeChatAdapter()
|
||||
|
||||
|
||||
class ClaudeTokenCountAdapter(ApiAdapter):
|
||||
"""计算 Claude 请求 Token 数的轻量适配器。"""
|
||||
|
||||
name = "claude.token_count"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-api-key 或 Authorization: Bearer)"""
|
||||
# 优先检查 x-api-key
|
||||
api_key = request.headers.get("x-api-key")
|
||||
if api_key:
|
||||
return api_key
|
||||
# 降级到 Authorization: Bearer
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
request = ClaudeTokenCountRequest.model_validate(payload, strict=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Token count payload invalid: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid token count payload") from e
|
||||
|
||||
token_counter = TokenCounter()
|
||||
total_tokens = 0
|
||||
|
||||
if request.system:
|
||||
if isinstance(request.system, str):
|
||||
total_tokens += token_counter.count_tokens(request.system, request.model)
|
||||
elif isinstance(request.system, list):
|
||||
for block in request.system:
|
||||
if hasattr(block, "text"):
|
||||
total_tokens += token_counter.count_tokens(block.text, request.model)
|
||||
|
||||
messages_dict = [
|
||||
msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in request.messages
|
||||
]
|
||||
total_tokens += token_counter.count_messages_tokens(messages_dict, request.model)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="claude_token_count",
|
||||
model=request.model,
|
||||
messages_count=len(request.messages),
|
||||
system_present=bool(request.system),
|
||||
tools_count=len(request.tools or []),
|
||||
thinking_enabled=bool(request.thinking),
|
||||
input_tokens=total_tokens,
|
||||
)
|
||||
|
||||
return JSONResponse({"input_tokens": total_tokens})
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClaudeChatAdapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
"build_claude_adapter",
|
||||
]
|
||||
Reference in New Issue
Block a user